From 7a872cfca7619350aafb5e6b72b7ed00f429d1ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 5 Aug 2024 15:09:19 +0200 Subject: [PATCH 01/31] Remove unused empty constructor --- .../elasticsearch/compute/aggregation/BooleanState.java | 4 ---- .../elasticsearch/compute/aggregation/DoubleState.java | 4 ---- .../org/elasticsearch/compute/aggregation/FloatState.java | 4 ---- .../org/elasticsearch/compute/aggregation/IntState.java | 4 ---- .../org/elasticsearch/compute/aggregation/LongState.java | 4 ---- .../compute/aggregation/CountAggregatorFunction.java | 2 +- .../org/elasticsearch/compute/aggregation/X-State.java.st | 8 -------- 7 files changed, 1 insertion(+), 29 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanState.java index 7d225c7c06a72..ba4d133dee553 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanState.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanState.java @@ -18,10 +18,6 @@ final class BooleanState implements AggregatorState { private boolean value; private boolean seen; - BooleanState() { - this(false); - } - BooleanState(boolean init) { this.value = init; } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleState.java index f1c92c685bcab..90ecc2c1d3c03 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleState.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleState.java @@ -18,10 +18,6 @@ final class DoubleState implements AggregatorState { private double value; private boolean seen; - DoubleState() { - this(0); - } - DoubleState(double init) { this.value = init; } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FloatState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FloatState.java index 81bdd39e51b6e..6f608271b6e42 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FloatState.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FloatState.java @@ -18,10 +18,6 @@ final class FloatState implements AggregatorState { private float value; private boolean seen; - FloatState() { - this(0); - } - FloatState(float init) { this.value = init; } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/IntState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/IntState.java index e7db40eccf9c8..c539c576ef36d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/IntState.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/IntState.java @@ -18,10 +18,6 @@ final class IntState implements AggregatorState { private int value; private boolean seen; - IntState() { - this(0); - } - IntState(int init) { this.value = init; } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LongState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LongState.java index da78b649782d5..e9d97dcfe7fc1 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LongState.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LongState.java @@ -18,10 +18,6 @@ final class LongState implements AggregatorState { private long value; private boolean seen; - LongState() { - this(0); - } - LongState(long init) { this.value = init; } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountAggregatorFunction.java index 13a4204edfd8f..c32f6f4703a79 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountAggregatorFunction.java @@ -52,7 +52,7 @@ public static List intermediateStateDesc() { private final boolean countAll; public static CountAggregatorFunction create(List inputChannels) { - return new CountAggregatorFunction(inputChannels, new LongState()); + return new CountAggregatorFunction(inputChannels, new LongState(0)); } private CountAggregatorFunction(List channels, LongState state) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-State.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-State.java.st index 2d2d706c9454f..7e0949c86faaa 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-State.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-State.java.st @@ -18,14 +18,6 @@ final class $Type$State implements AggregatorState { private $type$ value; private boolean seen; - $Type$State() { -$if(boolean)$ - this(false); -$else$ - this(0); -$endif$ - } - $Type$State($type$ init) { this.value = init; } From 87db6acca4620a9cfe21b6c2193cf6804d0b425f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 5 Aug 2024 15:18:15 +0200 Subject: [PATCH 02/31] Added fallible single states --- x-pack/plugin/esql/compute/build.gradle | 26 ++++++++ .../aggregation/BooleanFallibleState.java | 62 +++++++++++++++++++ .../aggregation/DoubleFallibleState.java | 62 +++++++++++++++++++ .../aggregation/FloatFallibleState.java | 62 +++++++++++++++++++ .../compute/aggregation/IntFallibleState.java | 62 +++++++++++++++++++ .../aggregation/LongFallibleState.java | 62 +++++++++++++++++++ .../aggregation/X-FallibleState.java.st | 62 +++++++++++++++++++ 7 files changed, 398 insertions(+) create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanFallibleState.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleFallibleState.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FloatFallibleState.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/IntFallibleState.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LongFallibleState.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-FallibleState.java.st diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index d31a7e629003e..d4795ce8d129c 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -433,6 +433,32 @@ tasks.named('stringTemplates').configure { it.inputFile = stateInputFile it.outputFile = "org/elasticsearch/compute/aggregation/DoubleState.java" } + File fallibleStateInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-FallibleState.java.st") + template { + it.properties = booleanProperties + it.inputFile = fallibleStateInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/BooleanFallibleState.java" + } + template { + it.properties = intProperties + it.inputFile = fallibleStateInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/IntFallibleState.java" + } + template { + it.properties = longProperties + it.inputFile = fallibleStateInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/LongFallibleState.java" + } + template { + it.properties = floatProperties + it.inputFile = fallibleStateInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/FloatFallibleState.java" + } + template { + it.properties = doubleProperties + it.inputFile = fallibleStateInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/DoubleFallibleState.java" + } // block lookups File lookupInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/data/X-Lookup.java.st") template { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanFallibleState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanFallibleState.java new file mode 100644 index 0000000000000..073f31c390a6f --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanFallibleState.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * Aggregator state for a single boolean. + * It stores a third boolean to store if the aggregation failed. + * This class is generated. Do not edit it. + */ +final class BooleanFallibleState implements AggregatorState { + private boolean value; + private boolean seen; + private boolean failed; + + BooleanFallibleState(boolean init) { + this.value = init; + } + + boolean booleanValue() { + return value; + } + + void booleanValue(boolean value) { + this.value = value; + } + + boolean seen() { + return seen; + } + + void seen(boolean seen) { + this.seen = seen; + } + + boolean failed() { + return failed; + } + + void failed(boolean failed) { + this.failed = failed; + } + + /** Extracts an intermediate view of the contents of this state. */ + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + assert blocks.length >= offset + 3; + blocks[offset + 0] = driverContext.blockFactory().newConstantBooleanBlockWith(value, 1); + blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1); + blocks[offset + 2] = driverContext.blockFactory().newConstantBooleanBlockWith(failed, 1); + } + + @Override + public void close() {} +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleFallibleState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleFallibleState.java new file mode 100644 index 0000000000000..4cdeddec724bf --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleFallibleState.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * Aggregator state for a single double. + * It stores a third boolean to store if the aggregation failed. + * This class is generated. Do not edit it. + */ +final class DoubleFallibleState implements AggregatorState { + private double value; + private boolean seen; + private boolean failed; + + DoubleFallibleState(double init) { + this.value = init; + } + + double doubleValue() { + return value; + } + + void doubleValue(double value) { + this.value = value; + } + + boolean seen() { + return seen; + } + + void seen(boolean seen) { + this.seen = seen; + } + + boolean failed() { + return failed; + } + + void failed(boolean failed) { + this.failed = failed; + } + + /** Extracts an intermediate view of the contents of this state. */ + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + assert blocks.length >= offset + 3; + blocks[offset + 0] = driverContext.blockFactory().newConstantDoubleBlockWith(value, 1); + blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1); + blocks[offset + 2] = driverContext.blockFactory().newConstantBooleanBlockWith(failed, 1); + } + + @Override + public void close() {} +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FloatFallibleState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FloatFallibleState.java new file mode 100644 index 0000000000000..b050c86258dcd --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FloatFallibleState.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * Aggregator state for a single float. + * It stores a third boolean to store if the aggregation failed. + * This class is generated. Do not edit it. + */ +final class FloatFallibleState implements AggregatorState { + private float value; + private boolean seen; + private boolean failed; + + FloatFallibleState(float init) { + this.value = init; + } + + float floatValue() { + return value; + } + + void floatValue(float value) { + this.value = value; + } + + boolean seen() { + return seen; + } + + void seen(boolean seen) { + this.seen = seen; + } + + boolean failed() { + return failed; + } + + void failed(boolean failed) { + this.failed = failed; + } + + /** Extracts an intermediate view of the contents of this state. */ + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + assert blocks.length >= offset + 3; + blocks[offset + 0] = driverContext.blockFactory().newConstantFloatBlockWith(value, 1); + blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1); + blocks[offset + 2] = driverContext.blockFactory().newConstantBooleanBlockWith(failed, 1); + } + + @Override + public void close() {} +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/IntFallibleState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/IntFallibleState.java new file mode 100644 index 0000000000000..360f3fdb009e4 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/IntFallibleState.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * Aggregator state for a single int. + * It stores a third boolean to store if the aggregation failed. + * This class is generated. Do not edit it. + */ +final class IntFallibleState implements AggregatorState { + private int value; + private boolean seen; + private boolean failed; + + IntFallibleState(int init) { + this.value = init; + } + + int intValue() { + return value; + } + + void intValue(int value) { + this.value = value; + } + + boolean seen() { + return seen; + } + + void seen(boolean seen) { + this.seen = seen; + } + + boolean failed() { + return failed; + } + + void failed(boolean failed) { + this.failed = failed; + } + + /** Extracts an intermediate view of the contents of this state. */ + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + assert blocks.length >= offset + 3; + blocks[offset + 0] = driverContext.blockFactory().newConstantIntBlockWith(value, 1); + blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1); + blocks[offset + 2] = driverContext.blockFactory().newConstantBooleanBlockWith(failed, 1); + } + + @Override + public void close() {} +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LongFallibleState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LongFallibleState.java new file mode 100644 index 0000000000000..98669ef627d04 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LongFallibleState.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * Aggregator state for a single long. + * It stores a third boolean to store if the aggregation failed. + * This class is generated. Do not edit it. + */ +final class LongFallibleState implements AggregatorState { + private long value; + private boolean seen; + private boolean failed; + + LongFallibleState(long init) { + this.value = init; + } + + long longValue() { + return value; + } + + void longValue(long value) { + this.value = value; + } + + boolean seen() { + return seen; + } + + void seen(boolean seen) { + this.seen = seen; + } + + boolean failed() { + return failed; + } + + void failed(boolean failed) { + this.failed = failed; + } + + /** Extracts an intermediate view of the contents of this state. */ + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + assert blocks.length >= offset + 3; + blocks[offset + 0] = driverContext.blockFactory().newConstantLongBlockWith(value, 1); + blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1); + blocks[offset + 2] = driverContext.blockFactory().newConstantBooleanBlockWith(failed, 1); + } + + @Override + public void close() {} +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-FallibleState.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-FallibleState.java.st new file mode 100644 index 0000000000000..27609383e4f61 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-FallibleState.java.st @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * Aggregator state for a single $type$. + * It stores a third boolean to store if the aggregation failed. + * This class is generated. Do not edit it. + */ +final class $Type$FallibleState implements AggregatorState { + private $type$ value; + private boolean seen; + private boolean failed; + + $Type$FallibleState($type$ init) { + this.value = init; + } + + $type$ $type$Value() { + return value; + } + + void $type$Value($type$ value) { + this.value = value; + } + + boolean seen() { + return seen; + } + + void seen(boolean seen) { + this.seen = seen; + } + + boolean failed() { + return failed; + } + + void failed(boolean failed) { + this.failed = failed; + } + + /** Extracts an intermediate view of the contents of this state. */ + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + assert blocks.length >= offset + 3; + blocks[offset + 0] = driverContext.blockFactory().newConstant$Type$BlockWith(value, 1); + blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1); + blocks[offset + 2] = driverContext.blockFactory().newConstantBooleanBlockWith(failed, 1); + } + + @Override + public void close() {} +} From 17e9ea25b0544af7b80f6c75bb04ee83b8c8fe79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 5 Aug 2024 15:28:41 +0200 Subject: [PATCH 03/31] Added fallible array state --- x-pack/plugin/esql/compute/build.gradle | 26 +++ .../BooleanFallibleArrayState.java | 125 +++++++++++++ .../aggregation/DoubleFallibleArrayState.java | 124 +++++++++++++ .../aggregation/FloatFallibleArrayState.java | 124 +++++++++++++ .../aggregation/IntFallibleArrayState.java | 124 +++++++++++++ .../aggregation/LongFallibleArrayState.java | 130 ++++++++++++++ .../AbstractFallibleArrayState.java | 69 ++++++++ .../aggregation/X-FallibleArrayState.java.st | 166 ++++++++++++++++++ 8 files changed, 888 insertions(+) create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanFallibleArrayState.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleFallibleArrayState.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FloatFallibleArrayState.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/IntFallibleArrayState.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LongFallibleArrayState.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractFallibleArrayState.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-FallibleArrayState.java.st diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index d4795ce8d129c..136b3b5313c90 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -517,6 +517,32 @@ tasks.named('stringTemplates').configure { it.inputFile = arrayStateInputFile it.outputFile = "org/elasticsearch/compute/aggregation/FloatArrayState.java" } + File fallibleArrayStateInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-FallibleArrayState.java.st") + template { + it.properties = booleanProperties + it.inputFile = fallibleArrayStateInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/BooleanFallibleArrayState.java" + } + template { + it.properties = intProperties + it.inputFile = fallibleArrayStateInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/IntFallibleArrayState.java" + } + template { + it.properties = longProperties + it.inputFile = fallibleArrayStateInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/LongFallibleArrayState.java" + } + template { + it.properties = doubleProperties + it.inputFile = fallibleArrayStateInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/DoubleFallibleArrayState.java" + } + template { + it.properties = floatProperties + it.inputFile = fallibleArrayStateInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/FloatFallibleArrayState.java" + } File valuesAggregatorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st") template { it.properties = intProperties diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanFallibleArrayState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanFallibleArrayState.java new file mode 100644 index 0000000000000..6367fdfb6617e --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanFallibleArrayState.java @@ -0,0 +1,125 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +/** + * Aggregator state for an array of booleans, that also tracks failures. + * It is created in a mode where it won't track + * the {@code groupId}s that are sent to it and it is the + * responsibility of the caller to only fetch values for {@code groupId}s + * that it has sent using the {@code selected} parameter when building the + * results. This is fine when there are no {@code null} values in the input + * data. But once there are null values in the input data it is + * much more convenient to only send non-null values and + * the tracking built into the grouping code can't track that. In that case + * call {@link #enableGroupIdTracking} to transition the state into a mode + * where it'll track which {@code groupIds} have been written. + *

+ * This class is generated. Do not edit it. + *

+ */ +final class BooleanFallibleArrayState extends AbstractFallibleArrayState implements GroupingAggregatorState { + private final boolean init; + + private BitArray values; + private int size; + + BooleanFallibleArrayState(BigArrays bigArrays, boolean init) { + super(bigArrays); + this.values = new BitArray(1, bigArrays); + this.size = 1; + this.values.set(0, init); + this.init = init; + } + + boolean get(int groupId) { + return values.get(groupId); + } + + boolean getOrDefault(int groupId) { + return groupId < size ? values.get(groupId) : init; + } + + void set(int groupId, boolean value) { + ensureCapacity(groupId); + values.set(groupId, value); + trackGroupId(groupId); + } + + Block toValuesBlock(org.elasticsearch.compute.data.IntVector selected, DriverContext driverContext) { + if (false == trackingGroupIds() && false == anyFailure()) { + try (var builder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + builder.appendBoolean(i, values.get(selected.getInt(i))); + } + return builder.build().asBlock(); + } + } + try (BooleanBlock.Builder builder = driverContext.blockFactory().newBooleanBlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (hasValue(group) && !hasFailed(group)) { + builder.appendBoolean(values.get(group)); + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + private void ensureCapacity(int groupId) { + if (groupId >= size) { + values.fill(size, groupId + 1, init); + size = groupId + 1; + } + } + + /** Extracts an intermediate view of the contents of this state. */ + @Override + public void toIntermediate( + Block[] blocks, + int offset, + IntVector selected, + org.elasticsearch.compute.operator.DriverContext driverContext + ) { + assert blocks.length >= offset + 3; + try ( + var valuesBuilder = driverContext.blockFactory().newBooleanBlockBuilder(selected.getPositionCount()); + var hasValueBuilder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount()); + var hasFailedBuilder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount()) + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (group < size) { + valuesBuilder.appendBoolean(values.get(group)); + } else { + valuesBuilder.appendBoolean(false); // TODO can we just use null? + } + hasValueBuilder.appendBoolean(i, hasValue(group)); + hasFailedBuilder.appendBoolean(i, hasFailed(group)); + } + blocks[offset + 0] = valuesBuilder.build(); + blocks[offset + 1] = hasValueBuilder.build().asBlock(); + blocks[offset + 2] = hasFailedBuilder.build().asBlock(); + } + } + + @Override + public void close() { + Releasables.close(values, super::close); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleFallibleArrayState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleFallibleArrayState.java new file mode 100644 index 0000000000000..dd1d60f7bd246 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleFallibleArrayState.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +/** + * Aggregator state for an array of doubles, that also tracks failures. + * It is created in a mode where it won't track + * the {@code groupId}s that are sent to it and it is the + * responsibility of the caller to only fetch values for {@code groupId}s + * that it has sent using the {@code selected} parameter when building the + * results. This is fine when there are no {@code null} values in the input + * data. But once there are null values in the input data it is + * much more convenient to only send non-null values and + * the tracking built into the grouping code can't track that. In that case + * call {@link #enableGroupIdTracking} to transition the state into a mode + * where it'll track which {@code groupIds} have been written. + *

+ * This class is generated. Do not edit it. + *

+ */ +final class DoubleFallibleArrayState extends AbstractFallibleArrayState implements GroupingAggregatorState { + private final double init; + + private DoubleArray values; + + DoubleFallibleArrayState(BigArrays bigArrays, double init) { + super(bigArrays); + this.values = bigArrays.newDoubleArray(1, false); + this.values.set(0, init); + this.init = init; + } + + double get(int groupId) { + return values.get(groupId); + } + + double getOrDefault(int groupId) { + return groupId < values.size() ? values.get(groupId) : init; + } + + void set(int groupId, double value) { + ensureCapacity(groupId); + values.set(groupId, value); + trackGroupId(groupId); + } + + Block toValuesBlock(org.elasticsearch.compute.data.IntVector selected, DriverContext driverContext) { + if (false == trackingGroupIds() && false == anyFailure()) { + try (var builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + builder.appendDouble(i, values.get(selected.getInt(i))); + } + return builder.build().asBlock(); + } + } + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (hasValue(group) && !hasFailed(group)) { + builder.appendDouble(values.get(group)); + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + private void ensureCapacity(int groupId) { + if (groupId >= values.size()) { + long prevSize = values.size(); + values = bigArrays.grow(values, groupId + 1); + values.fill(prevSize, values.size(), init); + } + } + + /** Extracts an intermediate view of the contents of this state. */ + @Override + public void toIntermediate( + Block[] blocks, + int offset, + IntVector selected, + org.elasticsearch.compute.operator.DriverContext driverContext + ) { + assert blocks.length >= offset + 3; + try ( + var valuesBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var hasValueBuilder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount()); + var hasFailedBuilder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount()) + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (group < values.size()) { + valuesBuilder.appendDouble(values.get(group)); + } else { + valuesBuilder.appendDouble(0); // TODO can we just use null? + } + hasValueBuilder.appendBoolean(i, hasValue(group)); + hasFailedBuilder.appendBoolean(i, hasFailed(group)); + } + blocks[offset + 0] = valuesBuilder.build(); + blocks[offset + 1] = hasValueBuilder.build().asBlock(); + blocks[offset + 2] = hasFailedBuilder.build().asBlock(); + } + } + + @Override + public void close() { + Releasables.close(values, super::close); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FloatFallibleArrayState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FloatFallibleArrayState.java new file mode 100644 index 0000000000000..055cf345033c5 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FloatFallibleArrayState.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.FloatArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +/** + * Aggregator state for an array of floats, that also tracks failures. + * It is created in a mode where it won't track + * the {@code groupId}s that are sent to it and it is the + * responsibility of the caller to only fetch values for {@code groupId}s + * that it has sent using the {@code selected} parameter when building the + * results. This is fine when there are no {@code null} values in the input + * data. But once there are null values in the input data it is + * much more convenient to only send non-null values and + * the tracking built into the grouping code can't track that. In that case + * call {@link #enableGroupIdTracking} to transition the state into a mode + * where it'll track which {@code groupIds} have been written. + *

+ * This class is generated. Do not edit it. + *

+ */ +final class FloatFallibleArrayState extends AbstractFallibleArrayState implements GroupingAggregatorState { + private final float init; + + private FloatArray values; + + FloatFallibleArrayState(BigArrays bigArrays, float init) { + super(bigArrays); + this.values = bigArrays.newFloatArray(1, false); + this.values.set(0, init); + this.init = init; + } + + float get(int groupId) { + return values.get(groupId); + } + + float getOrDefault(int groupId) { + return groupId < values.size() ? values.get(groupId) : init; + } + + void set(int groupId, float value) { + ensureCapacity(groupId); + values.set(groupId, value); + trackGroupId(groupId); + } + + Block toValuesBlock(org.elasticsearch.compute.data.IntVector selected, DriverContext driverContext) { + if (false == trackingGroupIds() && false == anyFailure()) { + try (var builder = driverContext.blockFactory().newFloatVectorFixedBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + builder.appendFloat(i, values.get(selected.getInt(i))); + } + return builder.build().asBlock(); + } + } + try (FloatBlock.Builder builder = driverContext.blockFactory().newFloatBlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (hasValue(group) && !hasFailed(group)) { + builder.appendFloat(values.get(group)); + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + private void ensureCapacity(int groupId) { + if (groupId >= values.size()) { + long prevSize = values.size(); + values = bigArrays.grow(values, groupId + 1); + values.fill(prevSize, values.size(), init); + } + } + + /** Extracts an intermediate view of the contents of this state. */ + @Override + public void toIntermediate( + Block[] blocks, + int offset, + IntVector selected, + org.elasticsearch.compute.operator.DriverContext driverContext + ) { + assert blocks.length >= offset + 3; + try ( + var valuesBuilder = driverContext.blockFactory().newFloatBlockBuilder(selected.getPositionCount()); + var hasValueBuilder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount()); + var hasFailedBuilder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount()) + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (group < values.size()) { + valuesBuilder.appendFloat(values.get(group)); + } else { + valuesBuilder.appendFloat(0); // TODO can we just use null? + } + hasValueBuilder.appendBoolean(i, hasValue(group)); + hasFailedBuilder.appendBoolean(i, hasFailed(group)); + } + blocks[offset + 0] = valuesBuilder.build(); + blocks[offset + 1] = hasValueBuilder.build().asBlock(); + blocks[offset + 2] = hasFailedBuilder.build().asBlock(); + } + } + + @Override + public void close() { + Releasables.close(values, super::close); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/IntFallibleArrayState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/IntFallibleArrayState.java new file mode 100644 index 0000000000000..e45d84720ca1a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/IntFallibleArrayState.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.IntArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +/** + * Aggregator state for an array of ints, that also tracks failures. + * It is created in a mode where it won't track + * the {@code groupId}s that are sent to it and it is the + * responsibility of the caller to only fetch values for {@code groupId}s + * that it has sent using the {@code selected} parameter when building the + * results. This is fine when there are no {@code null} values in the input + * data. But once there are null values in the input data it is + * much more convenient to only send non-null values and + * the tracking built into the grouping code can't track that. In that case + * call {@link #enableGroupIdTracking} to transition the state into a mode + * where it'll track which {@code groupIds} have been written. + *

+ * This class is generated. Do not edit it. + *

+ */ +final class IntFallibleArrayState extends AbstractFallibleArrayState implements GroupingAggregatorState { + private final int init; + + private IntArray values; + + IntFallibleArrayState(BigArrays bigArrays, int init) { + super(bigArrays); + this.values = bigArrays.newIntArray(1, false); + this.values.set(0, init); + this.init = init; + } + + int get(int groupId) { + return values.get(groupId); + } + + int getOrDefault(int groupId) { + return groupId < values.size() ? values.get(groupId) : init; + } + + void set(int groupId, int value) { + ensureCapacity(groupId); + values.set(groupId, value); + trackGroupId(groupId); + } + + Block toValuesBlock(org.elasticsearch.compute.data.IntVector selected, DriverContext driverContext) { + if (false == trackingGroupIds() && false == anyFailure()) { + try (var builder = driverContext.blockFactory().newIntVectorFixedBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + builder.appendInt(i, values.get(selected.getInt(i))); + } + return builder.build().asBlock(); + } + } + try (IntBlock.Builder builder = driverContext.blockFactory().newIntBlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (hasValue(group) && !hasFailed(group)) { + builder.appendInt(values.get(group)); + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + private void ensureCapacity(int groupId) { + if (groupId >= values.size()) { + long prevSize = values.size(); + values = bigArrays.grow(values, groupId + 1); + values.fill(prevSize, values.size(), init); + } + } + + /** Extracts an intermediate view of the contents of this state. */ + @Override + public void toIntermediate( + Block[] blocks, + int offset, + IntVector selected, + org.elasticsearch.compute.operator.DriverContext driverContext + ) { + assert blocks.length >= offset + 3; + try ( + var valuesBuilder = driverContext.blockFactory().newIntBlockBuilder(selected.getPositionCount()); + var hasValueBuilder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount()); + var hasFailedBuilder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount()) + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (group < values.size()) { + valuesBuilder.appendInt(values.get(group)); + } else { + valuesBuilder.appendInt(0); // TODO can we just use null? + } + hasValueBuilder.appendBoolean(i, hasValue(group)); + hasFailedBuilder.appendBoolean(i, hasFailed(group)); + } + blocks[offset + 0] = valuesBuilder.build(); + blocks[offset + 1] = hasValueBuilder.build().asBlock(); + blocks[offset + 2] = hasFailedBuilder.build().asBlock(); + } + } + + @Override + public void close() { + Releasables.close(values, super::close); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LongFallibleArrayState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LongFallibleArrayState.java new file mode 100644 index 0000000000000..cb69579906871 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LongFallibleArrayState.java @@ -0,0 +1,130 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.LongArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +/** + * Aggregator state for an array of longs, that also tracks failures. + * It is created in a mode where it won't track + * the {@code groupId}s that are sent to it and it is the + * responsibility of the caller to only fetch values for {@code groupId}s + * that it has sent using the {@code selected} parameter when building the + * results. This is fine when there are no {@code null} values in the input + * data. But once there are null values in the input data it is + * much more convenient to only send non-null values and + * the tracking built into the grouping code can't track that. In that case + * call {@link #enableGroupIdTracking} to transition the state into a mode + * where it'll track which {@code groupIds} have been written. + *

+ * This class is generated. Do not edit it. + *

+ */ +final class LongFallibleArrayState extends AbstractFallibleArrayState implements GroupingAggregatorState { + private final long init; + + private LongArray values; + + LongFallibleArrayState(BigArrays bigArrays, long init) { + super(bigArrays); + this.values = bigArrays.newLongArray(1, false); + this.values.set(0, init); + this.init = init; + } + + long get(int groupId) { + return values.get(groupId); + } + + long getOrDefault(int groupId) { + return groupId < values.size() ? values.get(groupId) : init; + } + + void set(int groupId, long value) { + ensureCapacity(groupId); + values.set(groupId, value); + trackGroupId(groupId); + } + + void increment(int groupId, long value) { + ensureCapacity(groupId); + values.increment(groupId, value); + trackGroupId(groupId); + } + + Block toValuesBlock(org.elasticsearch.compute.data.IntVector selected, DriverContext driverContext) { + if (false == trackingGroupIds() && false == anyFailure()) { + try (var builder = driverContext.blockFactory().newLongVectorFixedBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + builder.appendLong(i, values.get(selected.getInt(i))); + } + return builder.build().asBlock(); + } + } + try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (hasValue(group) && !hasFailed(group)) { + builder.appendLong(values.get(group)); + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + private void ensureCapacity(int groupId) { + if (groupId >= values.size()) { + long prevSize = values.size(); + values = bigArrays.grow(values, groupId + 1); + values.fill(prevSize, values.size(), init); + } + } + + /** Extracts an intermediate view of the contents of this state. */ + @Override + public void toIntermediate( + Block[] blocks, + int offset, + IntVector selected, + org.elasticsearch.compute.operator.DriverContext driverContext + ) { + assert blocks.length >= offset + 3; + try ( + var valuesBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + var hasValueBuilder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount()); + var hasFailedBuilder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount()) + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (group < values.size()) { + valuesBuilder.appendLong(values.get(group)); + } else { + valuesBuilder.appendLong(0); // TODO can we just use null? + } + hasValueBuilder.appendBoolean(i, hasValue(group)); + hasFailedBuilder.appendBoolean(i, hasFailed(group)); + } + blocks[offset + 0] = valuesBuilder.build(); + blocks[offset + 1] = hasValueBuilder.build().asBlock(); + blocks[offset + 2] = hasFailedBuilder.build().asBlock(); + } + } + + @Override + public void close() { + Releasables.close(values, super::close); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractFallibleArrayState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractFallibleArrayState.java new file mode 100644 index 0000000000000..8a5aa7580d927 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractFallibleArrayState.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; + +public class AbstractFallibleArrayState implements Releasable { + protected final BigArrays bigArrays; + + private BitArray seen; + private BitArray failed; + + public AbstractFallibleArrayState(BigArrays bigArrays) { + this.bigArrays = bigArrays; + } + + final boolean hasValue(int groupId) { + return seen == null || seen.get(groupId); + } + + final boolean hasFailed(int groupId) { + return failed != null && failed.get(groupId); + } + + /** + * Switches this array state into tracking which group ids are set. This is + * idempotent and fast if already tracking so it's safe to, say, call it once + * for every block of values that arrives containing {@code null}. + */ + final void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + if (seen == null) { + seen = seenGroupIds.seenGroupIds(bigArrays); + } + } + + protected final void trackGroupId(int groupId) { + if (trackingGroupIds()) { + seen.set(groupId); + } + } + + protected final boolean trackingGroupIds() { + return seen != null; + } + + protected final boolean anyFailure() { + return failed != null; + } + + protected final void setFailed(int groupId) { + if (failed == null) { + failed = new BitArray(groupId + 1, bigArrays); + } + failed.set(groupId); + } + + @Override + public void close() { + Releasables.close(seen); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-FallibleArrayState.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-FallibleArrayState.java.st new file mode 100644 index 0000000000000..3c57ab948a79f --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-FallibleArrayState.java.st @@ -0,0 +1,166 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +$if(boolean)$ +import org.elasticsearch.common.util.BitArray; +$else$ +import org.elasticsearch.common.util.$Type$Array; +$endif$ +import org.elasticsearch.compute.data.Block; +$if(long)$ +import org.elasticsearch.compute.data.IntVector; +$endif$ +import org.elasticsearch.compute.data.$Type$Block; +$if(int)$ +import org.elasticsearch.compute.data.$Type$Vector; +$endif$ +$if(boolean||double||float)$ +import org.elasticsearch.compute.data.IntVector; +$endif$ +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +/** + * Aggregator state for an array of $type$s, that also tracks failures. + * It is created in a mode where it won't track + * the {@code groupId}s that are sent to it and it is the + * responsibility of the caller to only fetch values for {@code groupId}s + * that it has sent using the {@code selected} parameter when building the + * results. This is fine when there are no {@code null} values in the input + * data. But once there are null values in the input data it is + * much more convenient to only send non-null values and + * the tracking built into the grouping code can't track that. In that case + * call {@link #enableGroupIdTracking} to transition the state into a mode + * where it'll track which {@code groupIds} have been written. + *

+ * This class is generated. Do not edit it. + *

+ */ +final class $Type$FallibleArrayState extends AbstractFallibleArrayState implements GroupingAggregatorState { + private final $type$ init; + +$if(boolean)$ + private BitArray values; + private int size; + +$else$ + private $Type$Array values; +$endif$ + + $Type$FallibleArrayState(BigArrays bigArrays, $type$ init) { + super(bigArrays); +$if(boolean)$ + this.values = new BitArray(1, bigArrays); + this.size = 1; +$else$ + this.values = bigArrays.new$Type$Array(1, false); +$endif$ + this.values.set(0, init); + this.init = init; + } + + $type$ get(int groupId) { + return values.get(groupId); + } + + $type$ getOrDefault(int groupId) { +$if(boolean)$ + return groupId < size ? values.get(groupId) : init; +$else$ + return groupId < values.size() ? values.get(groupId) : init; +$endif$ + } + + void set(int groupId, $type$ value) { + ensureCapacity(groupId); + values.set(groupId, value); + trackGroupId(groupId); + } + +$if(long)$ + void increment(int groupId, long value) { + ensureCapacity(groupId); + values.increment(groupId, value); + trackGroupId(groupId); + } +$endif$ + + Block toValuesBlock(org.elasticsearch.compute.data.IntVector selected, DriverContext driverContext) { + if (false == trackingGroupIds() && false == anyFailure()) { + try (var builder = driverContext.blockFactory().new$Type$VectorFixedBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + builder.append$Type$(i, values.get(selected.getInt(i))); + } + return builder.build().asBlock(); + } + } + try ($Type$Block.Builder builder = driverContext.blockFactory().new$Type$BlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (hasValue(group) && !hasFailed(group)) { + builder.append$Type$(values.get(group)); + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + private void ensureCapacity(int groupId) { +$if(boolean)$ + if (groupId >= size) { + values.fill(size, groupId + 1, init); + size = groupId + 1; + } +$else$ + if (groupId >= values.size()) { + long prevSize = values.size(); + values = bigArrays.grow(values, groupId + 1); + values.fill(prevSize, values.size(), init); + } +$endif$ + } + + /** Extracts an intermediate view of the contents of this state. */ + @Override + public void toIntermediate( + Block[] blocks, + int offset, + IntVector selected, + org.elasticsearch.compute.operator.DriverContext driverContext + ) { + assert blocks.length >= offset + 3; + try ( + var valuesBuilder = driverContext.blockFactory().new$Type$BlockBuilder(selected.getPositionCount()); + var hasValueBuilder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount()); + var hasFailedBuilder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount()) + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (group < $if(boolean)$size$else$values.size()$endif$) { + valuesBuilder.append$Type$(values.get(group)); + } else { + valuesBuilder.append$Type$($if(boolean)$false$else$0$endif$); // TODO can we just use null? + } + hasValueBuilder.appendBoolean(i, hasValue(group)); + hasFailedBuilder.appendBoolean(i, hasFailed(group)); + } + blocks[offset + 0] = valuesBuilder.build(); + blocks[offset + 1] = hasValueBuilder.build().asBlock(); + blocks[offset + 2] = hasFailedBuilder.build().asBlock(); + } + } + + @Override + public void close() { + Releasables.close(values, super::close); + } +} From 318d5cdfeaf3f5b30d6c6773ba7cc5aa0df6c329 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 6 Aug 2024 14:56:21 +0200 Subject: [PATCH 04/31] Added custom warning object to aggregator function --- .../elasticsearch/compute/ann/Aggregator.java | 6 ++ ...AggregatorFunctionSupplierImplementer.java | 40 ++++++++-- .../compute/gen/AggregatorImplementer.java | 35 +++++++-- .../compute/gen/AggregatorProcessor.java | 37 +++++++++- .../org/elasticsearch/compute/gen/Types.java | 8 ++ .../SumLongAggregatorFunction.java | 11 ++- .../SumLongAggregatorFunctionSupplier.java | 15 +++- .../aggregation/SumLongAggregator.java | 5 +- .../compute/aggregation/Warnings.java | 74 +++++++++++++++++++ .../expression/function/aggregate/Sum.java | 4 +- 10 files changed, 215 insertions(+), 20 deletions(-) create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/Warnings.java diff --git a/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java b/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java index 69db6a1310c9e..444dbcc1b9e58 100644 --- a/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java +++ b/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java @@ -57,4 +57,10 @@ IntermediateState[] value() default {}; + /** + * Exceptions thrown by the `combine*(...)` methods to catch and convert + * into a warning and turn into a null value. + */ + Class[] warnExceptions() default {}; + } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java index 3f031db2978f9..e43a26e89cb48 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java @@ -10,6 +10,7 @@ import com.squareup.javapoet.ClassName; import com.squareup.javapoet.JavaFile; import com.squareup.javapoet.MethodSpec; +import com.squareup.javapoet.TypeName; import com.squareup.javapoet.TypeSpec; import org.elasticsearch.compute.ann.Aggregator; @@ -31,6 +32,7 @@ import static org.elasticsearch.compute.gen.Types.AGGREGATOR_FUNCTION_SUPPLIER; import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT; import static org.elasticsearch.compute.gen.Types.LIST_INTEGER; +import static org.elasticsearch.compute.gen.Types.STRING; /** * Implements "AggregationFunctionSupplier" from a class annotated with both @@ -40,6 +42,7 @@ public class AggregatorFunctionSupplierImplementer { private final TypeElement declarationType; private final AggregatorImplementer aggregatorImplementer; private final GroupingAggregatorImplementer groupingAggregatorImplementer; + private final boolean hasWarnings; private final List createParameters; private final ClassName implementation; @@ -47,11 +50,13 @@ public AggregatorFunctionSupplierImplementer( Elements elements, TypeElement declarationType, AggregatorImplementer aggregatorImplementer, - GroupingAggregatorImplementer groupingAggregatorImplementer + GroupingAggregatorImplementer groupingAggregatorImplementer, + boolean hasWarnings ) { this.declarationType = declarationType; this.aggregatorImplementer = aggregatorImplementer; this.groupingAggregatorImplementer = groupingAggregatorImplementer; + this.hasWarnings = hasWarnings; Set createParameters = new LinkedHashSet<>(); if (aggregatorImplementer != null) { @@ -86,6 +91,11 @@ private TypeSpec type() { builder.addModifiers(Modifier.PUBLIC, Modifier.FINAL); builder.addSuperinterface(AGGREGATOR_FUNCTION_SUPPLIER); + if (hasWarnings) { + builder.addField(TypeName.INT, "warningsLineNumber"); + builder.addField(TypeName.INT, "warningsColumnNumber"); + builder.addField(STRING, "warningsSourceText"); + } createParameters.stream().forEach(p -> p.declareField(builder)); builder.addMethod(ctor()); if (aggregatorImplementer != null) { @@ -100,6 +110,14 @@ private TypeSpec type() { private MethodSpec ctor() { MethodSpec.Builder builder = MethodSpec.constructorBuilder().addModifiers(Modifier.PUBLIC); + if (hasWarnings) { + builder.addParameter(TypeName.INT, "warningsLineNumber"); + builder.addParameter(TypeName.INT, "warningsColumnNumber"); + builder.addParameter(STRING, "warningsSourceText"); + builder.addStatement("this.warningsLineNumber = warningsLineNumber"); + builder.addStatement("this.warningsColumnNumber = warningsColumnNumber"); + builder.addStatement("this.warningsSourceText = warningsSourceText"); + } createParameters.stream().forEach(p -> p.buildCtor(builder)); return builder.build(); } @@ -114,14 +132,26 @@ private MethodSpec unsupportedNonGroupingAggregator() { } private MethodSpec aggregator() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("aggregator") - .addParameter(DRIVER_CONTEXT, "driverContext") - .returns(aggregatorImplementer.implementation()); + MethodSpec.Builder builder = MethodSpec.methodBuilder("aggregator"); builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC); + builder.addParameter(DRIVER_CONTEXT, "driverContext"); + builder.returns(aggregatorImplementer.implementation()); + + if (hasWarnings) { + builder.addStatement("var warnings = Warnings.createWarnings(driverContext.warningsMode(), " + + "warningsLineNumber, warningsColumnNumber, warningsSourceText)"); + } + builder.addStatement( "return $T.create($L)", aggregatorImplementer.implementation(), - Stream.concat(Stream.of("driverContext, channels"), aggregatorImplementer.createParameters().stream().map(Parameter::name)) + Stream.concat( + Stream.concat( + hasWarnings ? Stream.of("warnings") : Stream.of(), + Stream.of("driverContext, channels") + ), + aggregatorImplementer.createParameters().stream().map(Parameter::name) + ) .collect(Collectors.joining(", ")) ); diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java index b3d32a82cc7a9..f302737ea79f2 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java @@ -25,6 +25,7 @@ import javax.lang.model.element.ExecutableElement; import javax.lang.model.element.Modifier; import javax.lang.model.element.TypeElement; +import javax.lang.model.type.TypeMirror; import javax.lang.model.util.Elements; import static java.util.stream.Collectors.joining; @@ -40,6 +41,7 @@ import static org.elasticsearch.compute.gen.Types.BYTES_REF; import static org.elasticsearch.compute.gen.Types.BYTES_REF_BLOCK; import static org.elasticsearch.compute.gen.Types.BYTES_REF_VECTOR; +import static org.elasticsearch.compute.gen.Types.COMPUTE_WARNINGS; import static org.elasticsearch.compute.gen.Types.DOUBLE_BLOCK; import static org.elasticsearch.compute.gen.Types.DOUBLE_VECTOR; import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT; @@ -68,6 +70,7 @@ */ public class AggregatorImplementer { private final TypeElement declarationType; + private final List warnExceptions; private final ExecutableElement init; private final ExecutableElement combine; private final ExecutableElement combineValueCount; @@ -80,8 +83,14 @@ public class AggregatorImplementer { private final List intermediateState; private final List createParameters; - public AggregatorImplementer(Elements elements, TypeElement declarationType, IntermediateState[] interStateAnno) { + public AggregatorImplementer( + Elements elements, + TypeElement declarationType, + IntermediateState[] interStateAnno, + List warnExceptions + ) { this.declarationType = declarationType; + this.warnExceptions = warnExceptions; this.init = findRequiredMethod(declarationType, new String[] { "init", "initSingle" }, e -> true); this.stateType = choseStateType(); @@ -202,6 +211,11 @@ private TypeSpec type() { .initializer(initInterState()) .build() ); + + if (warnExceptions.isEmpty() == false) { + builder.addField(COMPUTE_WARNINGS, "warnings", Modifier.PRIVATE, Modifier.FINAL); + } + builder.addField(DRIVER_CONTEXT, "driverContext", Modifier.PRIVATE, Modifier.FINAL); builder.addField(stateType, "state", Modifier.PRIVATE, Modifier.FINAL); builder.addField(LIST_INTEGER, "channels", Modifier.PRIVATE, Modifier.FINAL); @@ -228,17 +242,22 @@ private TypeSpec type() { private MethodSpec create() { MethodSpec.Builder builder = MethodSpec.methodBuilder("create"); builder.addModifiers(Modifier.PUBLIC, Modifier.STATIC).returns(implementation); + if (warnExceptions.isEmpty() == false) { + builder.addParameter(COMPUTE_WARNINGS, "warnings"); + } builder.addParameter(DRIVER_CONTEXT, "driverContext"); builder.addParameter(LIST_INTEGER, "channels"); for (Parameter p : createParameters) { builder.addParameter(p.type(), p.name()); } if (createParameters.isEmpty()) { - builder.addStatement("return new $T(driverContext, channels, $L)", implementation, callInit()); + builder.addStatement("return new $T($LdriverContext, channels, $L)", implementation, + warnExceptions.isEmpty() ? "" : "warnings, ", callInit()); } else { builder.addStatement( - "return new $T(driverContext, channels, $L, $L)", + "return new $T($LdriverContext, channels, $L, $L)", implementation, + warnExceptions.isEmpty() ? "" : "warnings, ", callInit(), createParameters.stream().map(p -> p.name()).collect(joining(", ")) ); @@ -275,16 +294,22 @@ private CodeBlock initInterState() { private MethodSpec ctor() { MethodSpec.Builder builder = MethodSpec.constructorBuilder().addModifiers(Modifier.PUBLIC); + if (warnExceptions.isEmpty() == false) { + builder.addParameter(COMPUTE_WARNINGS, "warnings"); + } builder.addParameter(DRIVER_CONTEXT, "driverContext"); builder.addParameter(LIST_INTEGER, "channels"); builder.addParameter(stateType, "state"); + builder.addStatement("this.driverContext = driverContext"); + if (warnExceptions.isEmpty() == false) { + builder.addStatement("this.warnings = warnings"); + } builder.addStatement("this.channels = channels"); builder.addStatement("this.state = state"); for (Parameter p : createParameters()) { - builder.addParameter(p.type(), p.name()); - builder.addStatement("this.$N = $N", p.name(), p.name()); + p.buildCtor(builder); } return builder.build(); } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java index d07b24047b7e2..0adcd4d8a6cc3 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java @@ -10,11 +10,15 @@ import com.squareup.javapoet.JavaFile; import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.ConvertEvaluator; +import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.ann.MvEvaluator; import java.io.IOException; import java.io.Writer; +import java.util.ArrayList; import java.util.Collections; import java.util.IdentityHashMap; import java.util.List; @@ -27,9 +31,11 @@ import javax.annotation.processing.RoundEnvironment; import javax.lang.model.SourceVersion; import javax.lang.model.element.AnnotationMirror; +import javax.lang.model.element.AnnotationValue; import javax.lang.model.element.Element; import javax.lang.model.element.ExecutableElement; import javax.lang.model.element.TypeElement; +import javax.lang.model.type.TypeMirror; import javax.tools.Diagnostic; import javax.tools.JavaFileObject; @@ -80,9 +86,10 @@ public boolean process(Set set, RoundEnvironment roundEnv } for (TypeElement aggClass : annotatedClasses) { AggregatorImplementer implementer = null; + var warnExceptionsTypes = warnExceptions(aggClass); if (aggClass.getAnnotation(Aggregator.class) != null) { IntermediateState[] intermediateState = aggClass.getAnnotation(Aggregator.class).value(); - implementer = new AggregatorImplementer(env.getElementUtils(), aggClass, intermediateState); + implementer = new AggregatorImplementer(env.getElementUtils(), aggClass, intermediateState, warnExceptionsTypes); write(aggClass, "aggregator", implementer.sourceFile(), env); } GroupingAggregatorImplementer groupingAggregatorImplementer = null; @@ -104,7 +111,13 @@ public boolean process(Set set, RoundEnvironment roundEnv write( aggClass, "aggregator function supplier", - new AggregatorFunctionSupplierImplementer(env.getElementUtils(), aggClass, implementer, groupingAggregatorImplementer) + new AggregatorFunctionSupplierImplementer( + env.getElementUtils(), + aggClass, + implementer, + groupingAggregatorImplementer, + warnExceptionsTypes.isEmpty() == false + ) .sourceFile(), env ); @@ -133,4 +146,24 @@ public static void write(Object origination, String what, JavaFile file, Process throw new RuntimeException(e); } } + + private static List warnExceptions(Element aggregatorMethod) { + List result = new ArrayList<>(); + for (var mirror : aggregatorMethod.getAnnotationMirrors()) { + String annotationType = mirror.getAnnotationType().toString(); + if (annotationType.equals(Aggregator.class.getName()) + || annotationType.equals(GroupingAggregator.class.getName())) { + + for (var e : mirror.getElementValues().entrySet()) { + if (false == e.getKey().getSimpleName().toString().equals("warnExceptions")) { + continue; + } + for (var v : (List) e.getValue().getValue()) { + result.add((TypeMirror) ((AnnotationValue) v).getValue()); + } + } + } + } + return result; + } } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java index 3150741ddcb05..55a1c36895a5e 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java @@ -27,6 +27,8 @@ public class Types { private static final String OPERATOR_PACKAGE = PACKAGE + ".operator"; private static final String DATA_PACKAGE = PACKAGE + ".data"; + static final TypeName STRING = ClassName.get("java.lang", "String"); + static final TypeName LIST_INTEGER = ParameterizedTypeName.get(ClassName.get(List.class), TypeName.INT.box()); static final ClassName PAGE = ClassName.get(DATA_PACKAGE, "Page"); @@ -127,6 +129,12 @@ public class Types { ); static final ClassName WARNINGS = ClassName.get("org.elasticsearch.xpack.esql.expression.function", "Warnings"); + /** + * Warnings class used in compute module. + * It uses no external dependencies (Like Warnings and Source). + */ + static final ClassName COMPUTE_WARNINGS = ClassName.get("org.elasticsearch.compute.aggregation", "Warnings"); + static final ClassName SOURCE = ClassName.get("org.elasticsearch.xpack.esql.core.tree", "Source"); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java index 38d1b3de78265..8531da13d4c21 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java @@ -27,22 +27,25 @@ public final class SumLongAggregatorFunction implements AggregatorFunction { new IntermediateStateDesc("sum", ElementType.LONG), new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + private final Warnings warnings; + private final DriverContext driverContext; private final LongState state; private final List channels; - public SumLongAggregatorFunction(DriverContext driverContext, List channels, - LongState state) { + public SumLongAggregatorFunction(Warnings warnings, DriverContext driverContext, + List channels, LongState state) { this.driverContext = driverContext; + this.warnings = warnings; this.channels = channels; this.state = state; } - public static SumLongAggregatorFunction create(DriverContext driverContext, + public static SumLongAggregatorFunction create(Warnings warnings, DriverContext driverContext, List channels) { - return new SumLongAggregatorFunction(driverContext, channels, new LongState(SumLongAggregator.init())); + return new SumLongAggregatorFunction(warnings, driverContext, channels, new LongState(SumLongAggregator.init())); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionSupplier.java index b4d36aa526075..8a979111627fd 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionSupplier.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionSupplier.java @@ -15,15 +15,26 @@ * This class is generated. Do not edit it. */ public final class SumLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + int warningsLineNumber; + + int warningsColumnNumber; + + String warningsSourceText; + private final List channels; - public SumLongAggregatorFunctionSupplier(List channels) { + public SumLongAggregatorFunctionSupplier(int warningsLineNumber, int warningsColumnNumber, + String warningsSourceText, List channels) { + this.warningsLineNumber = warningsLineNumber; + this.warningsColumnNumber = warningsColumnNumber; + this.warningsSourceText = warningsSourceText; this.channels = channels; } @Override public SumLongAggregatorFunction aggregator(DriverContext driverContext) { - return SumLongAggregatorFunction.create(driverContext, channels); + var warnings = Warnings.createWarnings(driverContext.warningsMode(), warningsLineNumber, warningsColumnNumber, warningsSourceText); + return SumLongAggregatorFunction.create(warnings, driverContext, channels); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java index cd6a94e518be8..01db4b4912e36 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java @@ -11,7 +11,10 @@ import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; -@Aggregator({ @IntermediateState(name = "sum", type = "LONG"), @IntermediateState(name = "seen", type = "BOOLEAN") }) +@Aggregator( + value = { @IntermediateState(name = "sum", type = "LONG"), @IntermediateState(name = "seen", type = "BOOLEAN") }, + warnExceptions = ArithmeticException.class +) @GroupingAggregator class SumLongAggregator { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/Warnings.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/Warnings.java new file mode 100644 index 0000000000000..a78e9dd8d62af --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/Warnings.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.operator.DriverContext; +import static org.elasticsearch.common.logging.LoggerMessageFormat.format; + +import static org.elasticsearch.common.logging.HeaderWarning.addWarning; + +/** + * Utilities to collect warnings for running an executor. + */ +public class Warnings { + static final int MAX_ADDED_WARNINGS = 20; + + private final String location; + private final String first; + + private int addedWarnings; + + public static final Warnings NOOP_WARNINGS = new Warnings(-1, -2, "") { + @Override + public void registerException(Exception exception) { + // this space intentionally left blank + } + }; + + /** + * Create a new warnings object based on the given mode + * @param warningsMode The warnings collection strategy to use + * @param lineNumber The line number of the source text. Same as `source.getLineNumber()` + * @param columnNumber The column number of the source text. Same as `source.getColumnNumber()` + * @param sourceText The source text that caused the warning. Same as `source.text()` + * @return A warnings collector object + */ + public static Warnings createWarnings(DriverContext.WarningsMode warningsMode, int lineNumber, int columnNumber, String sourceText) { + switch (warningsMode) { + case COLLECT -> { + return new Warnings(lineNumber, columnNumber, sourceText); + } + case IGNORE -> { + return NOOP_WARNINGS; + } + } + throw new IllegalStateException("Unreachable"); + } + + public Warnings(int lineNumber, int columnNumber, String sourceText) { + location = format("Line {}:{}: ", lineNumber, columnNumber); + first = format( + null, + "{}evaluation of [{}] failed, treating result as null. Only first {} failures recorded.", + location, + sourceText, + MAX_ADDED_WARNINGS + ); + } + + public void registerException(Exception exception) { + if (addedWarnings < MAX_ADDED_WARNINGS) { + if (addedWarnings == 0) { + addWarning(first); + } + // location needs to be added to the exception too, since the headers are deduplicated + addWarning(location + exception.getClass().getName() + ": " + exception.getMessage()); + addedWarnings++; + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java index 4f85a15732a6f..edde2f7991a62 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java @@ -12,6 +12,7 @@ import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumIntAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.Warnings; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; @@ -83,7 +84,8 @@ public DataType dataType() { @Override protected AggregatorFunctionSupplier longSupplier(List inputChannels) { - return new SumLongAggregatorFunctionSupplier(inputChannels); + var location = source().source(); + return new SumLongAggregatorFunctionSupplier(location.getLineNumber(), location.getColumnNumber(), source().text(), inputChannels); } @Override From 4937082bff152fdb48b4491814926c7b9349ea87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 6 Aug 2024 16:20:36 +0200 Subject: [PATCH 05/31] Completed single aggregator implementation --- .../compute/gen/AggregatorImplementer.java | 118 ++++++++++++------ .../SumLongAggregatorFunction.java | 37 ++++-- .../SumLongGroupingAggregatorFunction.java | 10 +- .../aggregation/SumLongAggregator.java | 6 +- 4 files changed, 122 insertions(+), 49 deletions(-) diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java index f302737ea79f2..47aab77190288 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java @@ -21,6 +21,10 @@ import java.util.Arrays; import java.util.List; import java.util.Locale; +import java.util.Objects; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.lang.model.element.ExecutableElement; import javax.lang.model.element.Modifier; @@ -79,6 +83,7 @@ public class AggregatorImplementer { private final ClassName implementation; private final TypeName stateType; private final boolean stateTypeHasSeen; + private final boolean stateTypeHasFailed; private final boolean valuesIsBytesRef; private final List intermediateState; private final List createParameters; @@ -94,9 +99,12 @@ public AggregatorImplementer( this.init = findRequiredMethod(declarationType, new String[] { "init", "initSingle" }, e -> true); this.stateType = choseStateType(); - stateTypeHasSeen = elements.getAllMembers(elements.getTypeElement(stateType.toString())) + this.stateTypeHasSeen = elements.getAllMembers(elements.getTypeElement(stateType.toString())) .stream() .anyMatch(e -> e.toString().equals("seen()")); + this.stateTypeHasFailed = elements.getAllMembers(elements.getTypeElement(stateType.toString())) + .stream() + .anyMatch(e -> e.toString().equals("failed()")); this.combine = findRequiredMethod(declarationType, new String[] { "combine" }, e -> { if (e.getParameters().size() == 0) { @@ -135,7 +143,10 @@ private TypeName choseStateType() { if (false == initReturn.isPrimitive()) { return initReturn; } - return ClassName.get("org.elasticsearch.compute.aggregation", firstUpper(initReturn.toString()) + "State"); + if (warnExceptions.isEmpty()) { + return ClassName.get("org.elasticsearch.compute.aggregation", firstUpper(initReturn.toString()) + "State"); + } + return ClassName.get("org.elasticsearch.compute.aggregation", firstUpper(initReturn.toString()) + "FallibleState"); } static String valueType(ExecutableElement init, ExecutableElement combine) { @@ -391,20 +402,28 @@ private MethodSpec addRawBlock() { } private void combineRawInput(MethodSpec.Builder builder, String blockVariable) { + TypeName returnType = TypeName.get(combine.getReturnType()); + if (warnExceptions.isEmpty() == false) { + builder.beginControlFlow("try"); + } if (valuesIsBytesRef) { combineRawInputForBytesRef(builder, blockVariable); - return; - } - TypeName returnType = TypeName.get(combine.getReturnType()); - if (returnType.isPrimitive()) { + } else if (returnType.isPrimitive()) { combineRawInputForPrimitive(returnType, builder, blockVariable); - return; - } - if (returnType == TypeName.VOID) { + } else if (returnType == TypeName.VOID) { combineRawInputForVoid(builder, blockVariable); - return; + } else { + throw new IllegalArgumentException("combine must return void or a primitive"); + } + if (warnExceptions.isEmpty() == false) { + String catchPattern = "catch (" + + warnExceptions.stream().map(m -> "$T").collect(Collectors.joining(" | ")) + + " e)"; + builder.nextControlFlow(catchPattern, warnExceptions.stream().map(TypeName::get).toArray()); + builder.addStatement("warnings.registerException(e)"); + builder.addStatement("state.failed(true)"); + builder.endControlFlow(); } - throw new IllegalArgumentException("combine must return void or a primitive"); } private void combineRawInputForPrimitive(TypeName returnType, MethodSpec.Builder builder, String blockVariable) { @@ -448,15 +467,34 @@ private MethodSpec addIntermediateInput() { } builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType); } else if (hasPrimitiveState()) { - assert intermediateState.size() == 2; - assert intermediateState.get(1).name().equals("seen"); - builder.beginControlFlow("if (seen.getBoolean(0))"); - { - var state = intermediateState.get(0); - var s = "state.$L($T.combine(state.$L(), " + state.name() + "." + vectorAccessorName(state.elementType()) + "(0)))"; - builder.addStatement(s, primitiveStateMethod(), declarationType, primitiveStateMethod()); - builder.addStatement("state.seen(true)"); - builder.endControlFlow(); + if (warnExceptions.isEmpty()) { + assert intermediateState.size() == 2; + assert intermediateState.get(1).name().equals("seen"); + builder.beginControlFlow("if (seen.getBoolean(0))"); + { + var state = intermediateState.get(0); + var s = "state.$L($T.combine(state.$L(), " + state.name() + "." + vectorAccessorName(state.elementType()) + "(0)))"; + builder.addStatement(s, primitiveStateMethod(), declarationType, primitiveStateMethod()); + builder.addStatement("state.seen(true)"); + builder.endControlFlow(); + } + } else { + assert intermediateState.size() == 3; + assert intermediateState.get(1).name().equals("seen"); + assert intermediateState.get(2).name().equals("failed"); + builder.beginControlFlow("if (failed.getBoolean(0))"); + { + builder.addStatement("state.failed(true)"); + builder.addStatement("state.seen(true)"); + } + builder.nextControlFlow("else if (seen.getBoolean(0))"); + { + var state = intermediateState.get(0); + var s = "state.$L($T.combine(state.$L(), " + state.name() + "." + vectorAccessorName(state.elementType()) + "(0)))"; + builder.addStatement(s, primitiveStateMethod(), declarationType, primitiveStateMethod()); + builder.addStatement("state.seen(true)"); + builder.endControlFlow(); + } } } else { throw new IllegalArgumentException("Don't know how to combine intermediate input. Define combineIntermediate"); @@ -470,15 +508,15 @@ String intermediateStateRowAccess() { private String primitiveStateMethod() { switch (stateType.toString()) { - case "org.elasticsearch.compute.aggregation.BooleanState": + case "org.elasticsearch.compute.aggregation.BooleanState", "org.elasticsearch.compute.aggregation.BooleanFallibleState": return "booleanValue"; - case "org.elasticsearch.compute.aggregation.IntState": + case "org.elasticsearch.compute.aggregation.IntState", "org.elasticsearch.compute.aggregation.IntFallibleState": return "intValue"; - case "org.elasticsearch.compute.aggregation.LongState": + case "org.elasticsearch.compute.aggregation.LongState", "org.elasticsearch.compute.aggregation.LongFallibleState": return "longValue"; - case "org.elasticsearch.compute.aggregation.DoubleState": + case "org.elasticsearch.compute.aggregation.DoubleState", "org.elasticsearch.compute.aggregation.DoubleFallibleState": return "doubleValue"; - case "org.elasticsearch.compute.aggregation.FloatState": + case "org.elasticsearch.compute.aggregation.FloatState", "org.elasticsearch.compute.aggregation.FloatFallibleState": return "floatValue"; default: throw new IllegalArgumentException( @@ -505,8 +543,14 @@ private MethodSpec evaluateFinal() { .addParameter(BLOCK_ARRAY, "blocks") .addParameter(TypeName.INT, "offset") .addParameter(DRIVER_CONTEXT, "driverContext"); - if (stateTypeHasSeen) { - builder.beginControlFlow("if (state.seen() == false)"); + if (stateTypeHasSeen || stateTypeHasFailed) { + var condition = Stream.of( + stateTypeHasSeen ? "state.seen() == false" : null, + stateTypeHasFailed ? "state.failed()" : null + ) + .filter(Objects::nonNull) + .collect(joining(" || ")); + builder.beginControlFlow("if ($L)", condition); builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantNullBlock(1)", BLOCK); builder.addStatement("return"); builder.endControlFlow(); @@ -521,19 +565,19 @@ private MethodSpec evaluateFinal() { private void primitiveStateToResult(MethodSpec.Builder builder) { switch (stateType.toString()) { - case "org.elasticsearch.compute.aggregation.BooleanState": + case "org.elasticsearch.compute.aggregation.BooleanState", "org.elasticsearch.compute.aggregation.BooleanFallibleState": builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.booleanValue(), 1)"); return; - case "org.elasticsearch.compute.aggregation.IntState": + case "org.elasticsearch.compute.aggregation.IntState", "org.elasticsearch.compute.aggregation.IntFallibleState": builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantIntBlockWith(state.intValue(), 1)"); return; - case "org.elasticsearch.compute.aggregation.LongState": + case "org.elasticsearch.compute.aggregation.LongState", "org.elasticsearch.compute.aggregation.LongFallibleState": builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantLongBlockWith(state.longValue(), 1)"); return; - case "org.elasticsearch.compute.aggregation.DoubleState": + case "org.elasticsearch.compute.aggregation.DoubleState", "org.elasticsearch.compute.aggregation.DoubleFallibleState": builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantDoubleBlockWith(state.doubleValue(), 1)"); return; - case "org.elasticsearch.compute.aggregation.FloatState": + case "org.elasticsearch.compute.aggregation.FloatState", "org.elasticsearch.compute.aggregation.FloatFallibleState": builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantFloatBlockWith(state.floatValue(), 1)"); return; default: @@ -559,13 +603,11 @@ private MethodSpec close() { return builder.build(); } + private static final Pattern PRIMITIVE_STATE_PATTERN = Pattern.compile( + "org.elasticsearch.compute.aggregation.(Boolean|Int|Long|Double|Float)(Fallible)?State" + ); private boolean hasPrimitiveState() { - return switch (stateType.toString()) { - case "org.elasticsearch.compute.aggregation.BooleanState", "org.elasticsearch.compute.aggregation.IntState", - "org.elasticsearch.compute.aggregation.LongState", "org.elasticsearch.compute.aggregation.DoubleState", - "org.elasticsearch.compute.aggregation.FloatState" -> true; - default -> false; - }; + return PRIMITIVE_STATE_PATTERN.matcher(stateType.toString()).matches(); } record IntermediateStateDesc(String name, String elementType, boolean block) { diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java index 8531da13d4c21..649e478beb9bd 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java @@ -4,6 +4,7 @@ // 2.0. package org.elasticsearch.compute.aggregation; +import java.lang.ArithmeticException; import java.lang.Integer; import java.lang.Override; import java.lang.String; @@ -25,18 +26,19 @@ public final class SumLongAggregatorFunction implements AggregatorFunction { private static final List INTERMEDIATE_STATE_DESC = List.of( new IntermediateStateDesc("sum", ElementType.LONG), - new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + new IntermediateStateDesc("seen", ElementType.BOOLEAN), + new IntermediateStateDesc("failed", ElementType.BOOLEAN) ); private final Warnings warnings; private final DriverContext driverContext; - private final LongState state; + private final LongFallibleState state; private final List channels; public SumLongAggregatorFunction(Warnings warnings, DriverContext driverContext, - List channels, LongState state) { + List channels, LongFallibleState state) { this.driverContext = driverContext; this.warnings = warnings; this.channels = channels; @@ -45,7 +47,7 @@ public SumLongAggregatorFunction(Warnings warnings, DriverContext driverContext, public static SumLongAggregatorFunction create(Warnings warnings, DriverContext driverContext, List channels) { - return new SumLongAggregatorFunction(warnings, driverContext, channels, new LongState(SumLongAggregator.init())); + return new SumLongAggregatorFunction(warnings, driverContext, channels, new LongFallibleState(SumLongAggregator.init())); } public static List intermediateStateDesc() { @@ -71,7 +73,12 @@ public void addRawInput(Page page) { private void addRawVector(LongVector vector) { state.seen(true); for (int i = 0; i < vector.getPositionCount(); i++) { - state.longValue(SumLongAggregator.combine(state.longValue(), vector.getLong(i))); + try { + state.longValue(SumLongAggregator.combine(state.longValue(), vector.getLong(i))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.failed(true); + } } } @@ -84,7 +91,12 @@ private void addRawBlock(LongBlock block) { int start = block.getFirstValueIndex(p); int end = start + block.getValueCount(p); for (int i = start; i < end; i++) { - state.longValue(SumLongAggregator.combine(state.longValue(), block.getLong(i))); + try { + state.longValue(SumLongAggregator.combine(state.longValue(), block.getLong(i))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.failed(true); + } } } } @@ -105,7 +117,16 @@ public void addIntermediateInput(Page page) { } BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); assert seen.getPositionCount() == 1; - if (seen.getBoolean(0)) { + Block failedUncast = page.getBlock(channels.get(2)); + if (failedUncast.areAllValuesNull()) { + return; + } + BooleanVector failed = ((BooleanBlock) failedUncast).asVector(); + assert failed.getPositionCount() == 1; + if (failed.getBoolean(0)) { + state.failed(true); + state.seen(true); + } else if (seen.getBoolean(0)) { state.longValue(SumLongAggregator.combine(state.longValue(), sum.getLong(0))); state.seen(true); } @@ -118,7 +139,7 @@ public void evaluateIntermediate(Block[] blocks, int offset, DriverContext drive @Override public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { - if (state.seen() == false) { + if (state.seen() == false || state.failed()) { blocks[offset] = driverContext.blockFactory().newConstantNullBlock(1); return; } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java index 507aa343aa74e..774419e96666e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java @@ -27,7 +27,8 @@ public final class SumLongGroupingAggregatorFunction implements GroupingAggregatorFunction { private static final List INTERMEDIATE_STATE_DESC = List.of( new IntermediateStateDesc("sum", ElementType.LONG), - new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + new IntermediateStateDesc("seen", ElementType.BOOLEAN), + new IntermediateStateDesc("failed", ElementType.BOOLEAN) ); private final LongArrayState state; @@ -160,7 +161,12 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page return; } BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); - assert sum.getPositionCount() == seen.getPositionCount(); + Block failedUncast = page.getBlock(channels.get(2)); + if (failedUncast.areAllValuesNull()) { + return; + } + BooleanVector failed = ((BooleanBlock) failedUncast).asVector(); + assert sum.getPositionCount() == seen.getPositionCount() && sum.getPositionCount() == failed.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = Math.toIntExact(groups.getInt(groupPosition)); if (seen.getBoolean(groupPosition + positionOffset)) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java index 01db4b4912e36..754b916f3b8fe 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java @@ -12,7 +12,11 @@ import org.elasticsearch.compute.ann.IntermediateState; @Aggregator( - value = { @IntermediateState(name = "sum", type = "LONG"), @IntermediateState(name = "seen", type = "BOOLEAN") }, + value = { + @IntermediateState(name = "sum", type = "LONG"), + @IntermediateState(name = "seen", type = "BOOLEAN"), + @IntermediateState(name = "failed", type = "BOOLEAN") + }, warnExceptions = ArithmeticException.class ) @GroupingAggregator From 7935330c3fbd17016cfe00948d279d56862a8b39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 6 Aug 2024 16:40:41 +0200 Subject: [PATCH 06/31] Updated SumTests and format --- ...AggregatorFunctionSupplierImplementer.java | 14 +++--- .../compute/gen/AggregatorImplementer.java | 18 +++---- .../compute/gen/AggregatorProcessor.java | 9 +--- .../org/elasticsearch/compute/gen/Types.java | 1 - .../aggregation/SumLongAggregator.java | 3 +- .../compute/aggregation/Warnings.java | 2 +- .../expression/function/aggregate/Sum.java | 1 - .../expression/function/TestCaseSupplier.java | 22 +++++++++ .../function/aggregate/SumTests.java | 47 ++++++++----------- 9 files changed, 61 insertions(+), 56 deletions(-) diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java index e43a26e89cb48..e09ecf657bf01 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java @@ -138,21 +138,19 @@ private MethodSpec aggregator() { builder.returns(aggregatorImplementer.implementation()); if (hasWarnings) { - builder.addStatement("var warnings = Warnings.createWarnings(driverContext.warningsMode(), " + - "warningsLineNumber, warningsColumnNumber, warningsSourceText)"); + builder.addStatement( + "var warnings = Warnings.createWarnings(driverContext.warningsMode(), " + + "warningsLineNumber, warningsColumnNumber, warningsSourceText)" + ); } builder.addStatement( "return $T.create($L)", aggregatorImplementer.implementation(), Stream.concat( - Stream.concat( - hasWarnings ? Stream.of("warnings") : Stream.of(), - Stream.of("driverContext, channels") - ), + Stream.concat(hasWarnings ? Stream.of("warnings") : Stream.of(), Stream.of("driverContext, channels")), aggregatorImplementer.createParameters().stream().map(Parameter::name) - ) - .collect(Collectors.joining(", ")) + ).collect(Collectors.joining(", ")) ); return builder.build(); diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java index 47aab77190288..9119309269646 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java @@ -262,8 +262,12 @@ private MethodSpec create() { builder.addParameter(p.type(), p.name()); } if (createParameters.isEmpty()) { - builder.addStatement("return new $T($LdriverContext, channels, $L)", implementation, - warnExceptions.isEmpty() ? "" : "warnings, ", callInit()); + builder.addStatement( + "return new $T($LdriverContext, channels, $L)", + implementation, + warnExceptions.isEmpty() ? "" : "warnings, ", + callInit() + ); } else { builder.addStatement( "return new $T($LdriverContext, channels, $L, $L)", @@ -416,9 +420,7 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable) { throw new IllegalArgumentException("combine must return void or a primitive"); } if (warnExceptions.isEmpty() == false) { - String catchPattern = "catch (" - + warnExceptions.stream().map(m -> "$T").collect(Collectors.joining(" | ")) - + " e)"; + String catchPattern = "catch (" + warnExceptions.stream().map(m -> "$T").collect(Collectors.joining(" | ")) + " e)"; builder.nextControlFlow(catchPattern, warnExceptions.stream().map(TypeName::get).toArray()); builder.addStatement("warnings.registerException(e)"); builder.addStatement("state.failed(true)"); @@ -544,10 +546,7 @@ private MethodSpec evaluateFinal() { .addParameter(TypeName.INT, "offset") .addParameter(DRIVER_CONTEXT, "driverContext"); if (stateTypeHasSeen || stateTypeHasFailed) { - var condition = Stream.of( - stateTypeHasSeen ? "state.seen() == false" : null, - stateTypeHasFailed ? "state.failed()" : null - ) + var condition = Stream.of(stateTypeHasSeen ? "state.seen() == false" : null, stateTypeHasFailed ? "state.failed()" : null) .filter(Objects::nonNull) .collect(joining(" || ")); builder.beginControlFlow("if ($L)", condition); @@ -606,6 +605,7 @@ private MethodSpec close() { private static final Pattern PRIMITIVE_STATE_PATTERN = Pattern.compile( "org.elasticsearch.compute.aggregation.(Boolean|Int|Long|Double|Float)(Fallible)?State" ); + private boolean hasPrimitiveState() { return PRIMITIVE_STATE_PATTERN.matcher(stateType.toString()).matches(); } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java index 0adcd4d8a6cc3..9c21af1d75b20 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java @@ -10,11 +10,8 @@ import com.squareup.javapoet.JavaFile; import org.elasticsearch.compute.ann.Aggregator; -import org.elasticsearch.compute.ann.ConvertEvaluator; -import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; -import org.elasticsearch.compute.ann.MvEvaluator; import java.io.IOException; import java.io.Writer; @@ -117,8 +114,7 @@ public boolean process(Set set, RoundEnvironment roundEnv implementer, groupingAggregatorImplementer, warnExceptionsTypes.isEmpty() == false - ) - .sourceFile(), + ).sourceFile(), env ); } @@ -151,8 +147,7 @@ private static List warnExceptions(Element aggregatorMethod) { List result = new ArrayList<>(); for (var mirror : aggregatorMethod.getAnnotationMirrors()) { String annotationType = mirror.getAnnotationType().toString(); - if (annotationType.equals(Aggregator.class.getName()) - || annotationType.equals(GroupingAggregator.class.getName())) { + if (annotationType.equals(Aggregator.class.getName()) || annotationType.equals(GroupingAggregator.class.getName())) { for (var e : mirror.getElementValues().entrySet()) { if (false == e.getKey().getSimpleName().toString().equals("warnExceptions")) { diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java index 55a1c36895a5e..dd048460d2d13 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java @@ -135,7 +135,6 @@ public class Types { */ static final ClassName COMPUTE_WARNINGS = ClassName.get("org.elasticsearch.compute.aggregation", "Warnings"); - static final ClassName SOURCE = ClassName.get("org.elasticsearch.xpack.esql.core.tree", "Source"); static final ClassName BYTES_REF = ClassName.get("org.apache.lucene.util", "BytesRef"); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java index 754b916f3b8fe..178e4915022e1 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java @@ -15,8 +15,7 @@ value = { @IntermediateState(name = "sum", type = "LONG"), @IntermediateState(name = "seen", type = "BOOLEAN"), - @IntermediateState(name = "failed", type = "BOOLEAN") - }, + @IntermediateState(name = "failed", type = "BOOLEAN") }, warnExceptions = ArithmeticException.class ) @GroupingAggregator diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/Warnings.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/Warnings.java index a78e9dd8d62af..eb2255a4e349b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/Warnings.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/Warnings.java @@ -8,9 +8,9 @@ package org.elasticsearch.compute.aggregation; import org.elasticsearch.compute.operator.DriverContext; -import static org.elasticsearch.common.logging.LoggerMessageFormat.format; import static org.elasticsearch.common.logging.HeaderWarning.addWarning; +import static org.elasticsearch.common.logging.LoggerMessageFormat.format; /** * Utilities to collect warnings for running an executor. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java index edde2f7991a62..20c1a77a2301e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java @@ -12,7 +12,6 @@ import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumIntAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier; -import org.elasticsearch.compute.aggregation.Warnings; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java index 6652cca0c4527..b9754c1c57513 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java @@ -1418,6 +1418,28 @@ public TestCase withWarning(String warning) { ); } + public TestCase withWarnings(List warnings) { + String[] newWarnings; + if (expectedWarnings != null) { + newWarnings = Arrays.copyOf(expectedWarnings, expectedWarnings.length + warnings.size()); + for (int i = 0; i < warnings.size(); i++) { + newWarnings[expectedWarnings.length + i] = warnings.get(i); + } + } else { + newWarnings = warnings.toArray(String[]::new); + } + return new TestCase( + data, + evaluatorToString, + expectedType, + matcher, + newWarnings, + expectedTypeError, + foldingExceptionClass, + foldingExceptionMessage + ); + } + public TestCase withFoldingException(Class clazz, String message) { return new TestCase(data, evaluatorToString, expectedType, matcher, expectedWarnings, expectedTypeError, clazz, message); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java index 4f14dafc8b30d..88db21e69dc5f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java @@ -10,6 +10,7 @@ import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.elasticsearch.search.aggregations.metrics.CompensatedSum; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -35,13 +36,12 @@ public SumTests(@Name("TestCase") Supplier testCaseSu public static Iterable parameters() { var suppliers = new ArrayList(); - Stream.of(MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true) - // Longs currently throw on overflow. - // Restore after https://github.com/elastic/elasticsearch/issues/110437 - // MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), - // Doubles currently return +/-Infinity on overflow. - // Restore after https://github.com/elastic/elasticsearch/issues/111026 - // MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true) + Stream.of( + MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), + MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true) + // Doubles currently return +/-Infinity on overflow. + // Restore after https://github.com/elastic/elasticsearch/issues/111026 + // MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true) ).flatMap(List::stream).map(SumTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers)); suppliers.addAll( @@ -89,44 +89,37 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier return new TestCaseSupplier(List.of(fieldSupplier.type()), () -> { var fieldTypedData = fieldSupplier.get(); + var expectedWarnings = new ArrayList(); Object expected; try { expected = switch (fieldTypedData.type().widenSmallNumeric()) { case INTEGER -> fieldTypedData.multiRowData() .stream() - .map(v -> (Integer) v) - .collect(Collectors.summarizingInt(Integer::intValue)) - .getSum(); - case LONG -> fieldTypedData.multiRowData() - .stream() - .map(v -> (Long) v) - .collect(Collectors.summarizingLong(Long::longValue)) - .getSum(); + .map(v -> ((Integer) v).longValue()) + .reduce(Math::addExact) + .orElse(null); + case LONG -> fieldTypedData.multiRowData().stream().map(v -> (Long) v).reduce(Math::addExact).orElse(null); case DOUBLE -> { - var value = fieldTypedData.multiRowData() - .stream() - .map(v -> (Double) v) - .collect(Collectors.summarizingDouble(Double::doubleValue)) - .getSum(); - - if (Double.isInfinite(value) || Double.isNaN(value)) { - yield null; - } + var sum = new CompensatedSum(); + fieldTypedData.multiRowData().stream().map(v -> (Double) v).forEach(sum::add); - yield value; + yield Double.isFinite(sum.value()) ? sum.value() : null; } default -> throw new IllegalStateException("Unexpected value: " + fieldTypedData.type()); }; - } catch (Exception e) { + } catch (ArithmeticException e) { expected = null; + expectedWarnings.add("Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded."); + expectedWarnings.add("Line -1:-1: java.lang.ArithmeticException: long overflow"); } var dataType = fieldTypedData.type().isWholeNumber() == false || fieldTypedData.type() == UNSIGNED_LONG ? DataType.DOUBLE : DataType.LONG; - return new TestCaseSupplier.TestCase(List.of(fieldTypedData), "Sum[field=Attribute[channel=0]]", dataType, equalTo(expected)); + return new TestCaseSupplier.TestCase(List.of(fieldTypedData), "Sum[field=Attribute[channel=0]]", dataType, equalTo(expected)) + .withWarnings(expectedWarnings); }); } } From 35f70027923894c95cc4b7ad82bae008196c3bd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 6 Aug 2024 16:44:08 +0200 Subject: [PATCH 07/31] Update docs/changelog/111639.yaml --- docs/changelog/111639.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 docs/changelog/111639.yaml diff --git a/docs/changelog/111639.yaml b/docs/changelog/111639.yaml new file mode 100644 index 0000000000000..135e12a8065c5 --- /dev/null +++ b/docs/changelog/111639.yaml @@ -0,0 +1,6 @@ +pr: 111639 +summary: "ESQL: Add warnings capabilities to aggregators, and prevent overflow on\ + \ SUM aggregation" +area: ES|QL +type: bug +issues: [] From 0fb304060e8d977555129a5f216ee1d7f695394c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 6 Aug 2024 17:01:55 +0200 Subject: [PATCH 08/31] Fix tests compilation --- .../aggregation/SumLongAggregatorFunctionTests.java | 2 +- .../SumLongGroupingAggregatorFunctionTests.java | 2 +- .../compute/data/BlockSerializationTests.java | 10 +++++----- .../compute/operator/AggregationOperatorTests.java | 2 +- .../compute/operator/HashAggregationOperatorTests.java | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java index 7fd3cabb2c91e..b69a5441ae61b 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java @@ -33,7 +33,7 @@ protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { @Override protected AggregatorFunctionSupplier aggregatorFunction(List inputChannels) { - return new SumLongAggregatorFunctionSupplier(inputChannels); + return new SumLongAggregatorFunctionSupplier(-1, -2, "", inputChannels); } @Override diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunctionTests.java index f41a5cbef94fb..6caad2fd06118 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunctionTests.java @@ -23,7 +23,7 @@ public class SumLongGroupingAggregatorFunctionTests extends GroupingAggregatorFunctionTestCase { @Override protected AggregatorFunctionSupplier aggregatorFunction(List inputChannels) { - return new SumLongAggregatorFunctionSupplier(inputChannels); + return new SumLongAggregatorFunctionSupplier(-1, -2, "", inputChannels); } @Override diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockSerializationTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockSerializationTests.java index 8ca02b64f01ff..bc1beb3bd604a 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockSerializationTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockSerializationTests.java @@ -15,7 +15,7 @@ import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.common.util.PageCacheRecycler; -import org.elasticsearch.compute.aggregation.SumLongAggregatorFunction; +import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunction; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; @@ -238,7 +238,7 @@ public void testSimulateAggs() { Page page = new Page(blockFactory.newLongArrayVector(new long[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, 10).asBlock()); var bigArrays = BigArrays.NON_RECYCLING_INSTANCE; var params = new Object[] {}; - var function = SumLongAggregatorFunction.create(driverCtx, List.of(0)); + var function = MaxLongAggregatorFunction.create(driverCtx, List.of(0)); function.addRawInput(page); Block[] blocks = new Block[function.intermediateBlockCount()]; try { @@ -249,13 +249,13 @@ public void testSimulateAggs() { IntStream.range(0, blocks.length) .forEach(i -> EqualsHashCodeTestUtils.checkEqualsAndHashCode(blocks[i], unused -> deserBlocks[i])); - var inputChannels = IntStream.range(0, SumLongAggregatorFunction.intermediateStateDesc().size()).boxed().toList(); - try (var finalAggregator = SumLongAggregatorFunction.create(driverCtx, inputChannels)) { + var inputChannels = IntStream.range(0, MaxLongAggregatorFunction.intermediateStateDesc().size()).boxed().toList(); + try (var finalAggregator = MaxLongAggregatorFunction.create(driverCtx, inputChannels)) { finalAggregator.addIntermediateInput(new Page(deserBlocks)); Block[] finalBlocks = new Block[1]; finalAggregator.evaluateFinal(finalBlocks, 0, driverCtx); try (var finalBlock = (LongBlock) finalBlocks[0]) { - assertThat(finalBlock.getLong(0), is(55L)); + assertThat(finalBlock.getLong(0), is(10L)); } } } finally { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AggregationOperatorTests.java index 38d83fe894170..bb21f162383d5 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AggregationOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AggregationOperatorTests.java @@ -47,7 +47,7 @@ protected Operator.OperatorFactory simpleWithMode(AggregatorMode mode) { return new AggregationOperator.AggregationOperatorFactory( List.of( - new SumLongAggregatorFunctionSupplier(sumChannels).aggregatorFactory(mode), + new SumLongAggregatorFunctionSupplier(-1, -2, "", sumChannels).aggregatorFactory(mode), new MaxLongAggregatorFunctionSupplier(maxChannels).aggregatorFactory(mode) ), mode diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java index f2fa94c1feb08..9d376ed9b304e 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java @@ -55,7 +55,7 @@ protected Operator.OperatorFactory simpleWithMode(AggregatorMode mode) { return new HashAggregationOperator.HashAggregationOperatorFactory( List.of(new BlockHash.GroupSpec(0, ElementType.LONG)), List.of( - new SumLongAggregatorFunctionSupplier(sumChannels).groupingAggregatorFactory(mode), + new SumLongAggregatorFunctionSupplier(-1, -2, "", sumChannels).groupingAggregatorFactory(mode), new MaxLongAggregatorFunctionSupplier(maxChannels).groupingAggregatorFactory(mode) ), randomPageSize() From 0d33140847ff66de69a4a6e4c61f0bfc5d12d5f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Wed, 7 Aug 2024 12:56:35 +0200 Subject: [PATCH 09/31] Add warn exceptions in grouping aggregator --- .../compute/ann/GroupingAggregator.java | 6 + ...AggregatorFunctionSupplierImplementer.java | 16 ++- .../compute/gen/AggregatorImplementer.java | 7 +- .../compute/gen/AggregatorProcessor.java | 1 + .../gen/GroupingAggregatorImplementer.java | 136 +++++++++++++----- .../SumLongAggregatorFunction.java | 5 +- .../SumLongAggregatorFunctionSupplier.java | 3 +- .../SumLongGroupingAggregatorFunction.java | 60 ++++++-- 8 files changed, 177 insertions(+), 57 deletions(-) diff --git a/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/GroupingAggregator.java b/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/GroupingAggregator.java index 0216ea07e5c7c..8d81b60e20e4d 100644 --- a/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/GroupingAggregator.java +++ b/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/GroupingAggregator.java @@ -22,6 +22,12 @@ IntermediateState[] value() default {}; + /** + * Exceptions thrown by the `combine*(...)` methods to catch and convert + * into a warning and turn into a null value. + */ + Class[] warnExceptions() default {}; + /** * If {@code true} then the @timestamp LongVector will be appended to the input blocks of the aggregation function. */ diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java index e09ecf657bf01..f11ccbced6fbe 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java @@ -157,15 +157,23 @@ private MethodSpec aggregator() { } private MethodSpec groupingAggregator() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("groupingAggregator") - .addParameter(DRIVER_CONTEXT, "driverContext") - .returns(groupingAggregatorImplementer.implementation()); + MethodSpec.Builder builder = MethodSpec.methodBuilder("groupingAggregator"); builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC); + builder.addParameter(DRIVER_CONTEXT, "driverContext"); + builder.returns(groupingAggregatorImplementer.implementation()); + + if (hasWarnings) { + builder.addStatement( + "var warnings = Warnings.createWarnings(driverContext.warningsMode(), " + + "warningsLineNumber, warningsColumnNumber, warningsSourceText)" + ); + } + builder.addStatement( "return $T.create($L)", groupingAggregatorImplementer.implementation(), Stream.concat( - Stream.of("channels, driverContext"), + Stream.concat(hasWarnings ? Stream.of("warnings") : Stream.of(), Stream.of("channels, driverContext")), groupingAggregatorImplementer.createParameters().stream().map(Parameter::name) ).collect(Collectors.joining(", ")) ); diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java index 9119309269646..66036d54c7fa4 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java @@ -316,10 +316,10 @@ private MethodSpec ctor() { builder.addParameter(LIST_INTEGER, "channels"); builder.addParameter(stateType, "state"); - builder.addStatement("this.driverContext = driverContext"); if (warnExceptions.isEmpty() == false) { builder.addStatement("this.warnings = warnings"); } + builder.addStatement("this.driverContext = driverContext"); builder.addStatement("this.channels = channels"); builder.addStatement("this.state = state"); @@ -389,6 +389,11 @@ private MethodSpec addRawBlock() { builder.beginControlFlow("if (block.isNull(p))"); builder.addStatement("continue"); builder.endControlFlow(); + if (stateTypeHasFailed) { + builder.beginControlFlow("if (state.failed())"); + builder.addStatement("continue"); + builder.endControlFlow(); + } if (stateTypeHasSeen) { builder.addStatement("state.seen(true)"); } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java index 9c21af1d75b20..573c9676610d7 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java @@ -100,6 +100,7 @@ public boolean process(Set set, RoundEnvironment roundEnv env.getElementUtils(), aggClass, intermediateState, + warnExceptionsTypes, includeTimestamps ); write(aggClass, "grouping aggregator", groupingAggregatorImplementer.sourceFile(), env); diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index 79df41f304c06..b6af73377f32d 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -22,11 +22,13 @@ import java.util.List; import java.util.Locale; import java.util.function.Consumer; +import java.util.regex.Pattern; import java.util.stream.Collectors; import javax.lang.model.element.ExecutableElement; import javax.lang.model.element.Modifier; import javax.lang.model.element.TypeElement; +import javax.lang.model.type.TypeMirror; import javax.lang.model.util.Elements; import static java.util.stream.Collectors.joining; @@ -38,6 +40,7 @@ import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS; import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY; import static org.elasticsearch.compute.gen.Types.BYTES_REF; +import static org.elasticsearch.compute.gen.Types.COMPUTE_WARNINGS; import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT; import static org.elasticsearch.compute.gen.Types.ELEMENT_TYPE; import static org.elasticsearch.compute.gen.Types.GROUPING_AGGREGATOR_FUNCTION; @@ -63,6 +66,7 @@ */ public class GroupingAggregatorImplementer { private final TypeElement declarationType; + private final List warnExceptions; private final ExecutableElement init; private final ExecutableElement combine; private final ExecutableElement combineStates; @@ -79,9 +83,11 @@ public GroupingAggregatorImplementer( Elements elements, TypeElement declarationType, IntermediateState[] interStateAnno, + List warnExceptions, boolean includeTimestampVector ) { this.declarationType = declarationType; + this.warnExceptions = warnExceptions; this.init = findRequiredMethod(declarationType, new String[] { "init", "initGrouping" }, e -> true); this.stateType = choseStateType(); @@ -129,7 +135,10 @@ private TypeName choseStateType() { } String head = initReturn.toString().substring(0, 1).toUpperCase(Locale.ROOT); String tail = initReturn.toString().substring(1); - return ClassName.get("org.elasticsearch.compute.aggregation", head + tail + "ArrayState"); + if (warnExceptions.isEmpty()) { + return ClassName.get("org.elasticsearch.compute.aggregation", head + tail + "ArrayState"); + } + return ClassName.get("org.elasticsearch.compute.aggregation", head + tail + "FallibleArrayState"); } public JavaFile sourceFile() { @@ -154,6 +163,9 @@ private TypeSpec type() { .build() ); builder.addField(stateType, "state", Modifier.PRIVATE, Modifier.FINAL); + if (warnExceptions.isEmpty() == false) { + builder.addField(COMPUTE_WARNINGS, "warnings", Modifier.PRIVATE, Modifier.FINAL); + } builder.addField(LIST_INTEGER, "channels", Modifier.PRIVATE, Modifier.FINAL); builder.addField(DRIVER_CONTEXT, "driverContext", Modifier.PRIVATE, Modifier.FINAL); @@ -182,17 +194,26 @@ private TypeSpec type() { private MethodSpec create() { MethodSpec.Builder builder = MethodSpec.methodBuilder("create"); builder.addModifiers(Modifier.PUBLIC, Modifier.STATIC).returns(implementation); + if (warnExceptions.isEmpty() == false) { + builder.addParameter(COMPUTE_WARNINGS, "warnings"); + } builder.addParameter(LIST_INTEGER, "channels"); builder.addParameter(DRIVER_CONTEXT, "driverContext"); for (Parameter p : createParameters) { builder.addParameter(p.type(), p.name()); } if (createParameters.isEmpty()) { - builder.addStatement("return new $T(channels, $L, driverContext)", implementation, callInit()); + builder.addStatement( + "return new $T($Lchannels, $L, driverContext)", + implementation, + warnExceptions.isEmpty() ? "" : "warnings, ", + callInit() + ); } else { builder.addStatement( - "return new $T(channels, $L, driverContext, $L)", + "return new $T($Lchannels, $L, driverContext, $L)", implementation, + warnExceptions.isEmpty() ? "" : "warnings, ", callInit(), createParameters.stream().map(p -> p.name()).collect(joining(", ")) ); @@ -235,9 +256,15 @@ private CodeBlock initInterState() { private MethodSpec ctor() { MethodSpec.Builder builder = MethodSpec.constructorBuilder().addModifiers(Modifier.PUBLIC); + if (warnExceptions.isEmpty() == false) { + builder.addParameter(COMPUTE_WARNINGS, "warnings"); + } builder.addParameter(LIST_INTEGER, "channels"); builder.addParameter(stateType, "state"); builder.addParameter(DRIVER_CONTEXT, "driverContext"); + if (warnExceptions.isEmpty() == false) { + builder.addStatement("this.warnings = warnings"); + } builder.addStatement("this.channels = channels"); builder.addStatement("this.state = state"); builder.addStatement("this.driverContext = driverContext"); @@ -349,6 +376,12 @@ private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) { builder.addStatement("int groupId = Math.toIntExact(groups.getInt(groupPosition))"); } + if (warnExceptions.isEmpty() == false) { + builder.beginControlFlow("if (state.hasFailed(groupId))"); + builder.addStatement("continue"); + builder.endControlFlow(); + } + if (valuesIsBlock) { builder.beginControlFlow("if (values.isNull(groupPosition + positionOffset))"); builder.addStatement("continue"); @@ -371,31 +404,35 @@ private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) { } private void combineRawInput(MethodSpec.Builder builder, String blockVariable, String offsetVariable) { - if (valuesIsBytesRef) { - combineRawInputForBytesRef(builder, blockVariable, offsetVariable); - return; - } - if (includeTimestampVector) { - combineRawInputWithTimestamp(builder, offsetVariable); - return; - } TypeName valueType = TypeName.get(combine.getParameters().get(combine.getParameters().size() - 1).asType()); - if (valueType.isPrimitive() == false) { - throw new IllegalArgumentException("second parameter to combine must be a primitive"); - } String secondParameterGetter = "get" + valueType.toString().substring(0, 1).toUpperCase(Locale.ROOT) + valueType.toString().substring(1); TypeName returnType = TypeName.get(combine.getReturnType()); - if (returnType.isPrimitive()) { - combineRawInputForPrimitive(builder, secondParameterGetter, blockVariable, offsetVariable); - return; + + if (warnExceptions.isEmpty() == false) { + builder.beginControlFlow("try"); } - if (returnType == TypeName.VOID) { + if (valuesIsBytesRef) { + combineRawInputForBytesRef(builder, blockVariable, offsetVariable); + } else if (includeTimestampVector) { + combineRawInputWithTimestamp(builder, offsetVariable); + } else if (valueType.isPrimitive() == false) { + throw new IllegalArgumentException("second parameter to combine must be a primitive"); + } else if (returnType.isPrimitive()) { + combineRawInputForPrimitive(builder, secondParameterGetter, blockVariable, offsetVariable); + } else if (returnType == TypeName.VOID) { combineRawInputForVoid(builder, secondParameterGetter, blockVariable, offsetVariable); - return; + } else { + throw new IllegalArgumentException("combine must return void or a primitive"); + } + if (warnExceptions.isEmpty() == false) { + String catchPattern = "catch (" + warnExceptions.stream().map(m -> "$T").collect(Collectors.joining(" | ")) + " e)"; + builder.nextControlFlow(catchPattern, warnExceptions.stream().map(TypeName::get).toArray()); + builder.addStatement("warnings.registerException(e)"); + builder.addStatement("state.setFailed(groupId)"); + builder.endControlFlow(); } - throw new IllegalArgumentException("combine must return void or a primitive"); } private void combineRawInputForPrimitive( @@ -481,19 +518,41 @@ private MethodSpec addIntermediateInput() { { builder.addStatement("int groupId = Math.toIntExact(groups.getInt(groupPosition))"); if (hasPrimitiveState()) { - assert intermediateState.size() == 2; - assert intermediateState.get(1).name().equals("seen"); - builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))"); - { - var name = intermediateState.get(0).name(); - var m = vectorAccessorName(intermediateState.get(0).elementType()); - builder.addStatement( - "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", - declarationType, - name, - m - ); - builder.endControlFlow(); + if (warnExceptions.isEmpty()) { + assert intermediateState.size() == 2; + assert intermediateState.get(1).name().equals("seen"); + builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))"); + { + var name = intermediateState.get(0).name(); + var m = vectorAccessorName(intermediateState.get(0).elementType()); + builder.addStatement( + "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", + declarationType, + name, + m + ); + builder.endControlFlow(); + } + } else { + assert intermediateState.size() == 3; + assert intermediateState.get(1).name().equals("seen"); + assert intermediateState.get(2).name().equals("failed"); + builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))"); + { + builder.addStatement("state.setFailed(groupId)"); + } + builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))"); + { + var name = intermediateState.get(0).name(); + var m = vectorAccessorName(intermediateState.get(0).elementType()); + builder.addStatement( + "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", + declarationType, + name, + m + ); + builder.endControlFlow(); + } } } else { builder.addStatement("$T.combineIntermediate(state, groupId, " + intermediateStateRowAccess() + ")", declarationType); @@ -582,12 +641,11 @@ private MethodSpec close() { return builder.build(); } + private static final Pattern PRIMITIVE_STATE_PATTERN = Pattern.compile( + "org.elasticsearch.compute.aggregation.(Boolean|Int|Long|Double|Float)(Fallible)?ArrayState" + ); + private boolean hasPrimitiveState() { - return switch (stateType.toString()) { - case "org.elasticsearch.compute.aggregation.BooleanArrayState", "org.elasticsearch.compute.aggregation.IntArrayState", - "org.elasticsearch.compute.aggregation.LongArrayState", "org.elasticsearch.compute.aggregation.DoubleArrayState", - "org.elasticsearch.compute.aggregation.FloatArrayState" -> true; - default -> false; - }; + return PRIMITIVE_STATE_PATTERN.matcher(stateType.toString()).matches(); } } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java index 649e478beb9bd..1b21fb9fc3f04 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java @@ -39,8 +39,8 @@ public final class SumLongAggregatorFunction implements AggregatorFunction { public SumLongAggregatorFunction(Warnings warnings, DriverContext driverContext, List channels, LongFallibleState state) { - this.driverContext = driverContext; this.warnings = warnings; + this.driverContext = driverContext; this.channels = channels; this.state = state; } @@ -87,6 +87,9 @@ private void addRawBlock(LongBlock block) { if (block.isNull(p)) { continue; } + if (state.failed()) { + continue; + } state.seen(true); int start = block.getFirstValueIndex(p); int end = start + block.getValueCount(p); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionSupplier.java index 8a979111627fd..5b4acdaca0c20 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionSupplier.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionSupplier.java @@ -39,7 +39,8 @@ public SumLongAggregatorFunction aggregator(DriverContext driverContext) { @Override public SumLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { - return SumLongGroupingAggregatorFunction.create(channels, driverContext); + var warnings = Warnings.createWarnings(driverContext.warningsMode(), warningsLineNumber, warningsColumnNumber, warningsSourceText); + return SumLongGroupingAggregatorFunction.create(warnings, channels, driverContext); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java index 774419e96666e..f3b6f7e75f722 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java @@ -4,6 +4,7 @@ // 2.0. package org.elasticsearch.compute.aggregation; +import java.lang.ArithmeticException; import java.lang.Integer; import java.lang.Override; import java.lang.String; @@ -30,22 +31,25 @@ public final class SumLongGroupingAggregatorFunction implements GroupingAggregat new IntermediateStateDesc("seen", ElementType.BOOLEAN), new IntermediateStateDesc("failed", ElementType.BOOLEAN) ); - private final LongArrayState state; + private final LongFallibleArrayState state; + + private final Warnings warnings; private final List channels; private final DriverContext driverContext; - public SumLongGroupingAggregatorFunction(List channels, LongArrayState state, - DriverContext driverContext) { + public SumLongGroupingAggregatorFunction(Warnings warnings, List channels, + LongFallibleArrayState state, DriverContext driverContext) { + this.warnings = warnings; this.channels = channels; this.state = state; this.driverContext = driverContext; } - public static SumLongGroupingAggregatorFunction create(List channels, + public static SumLongGroupingAggregatorFunction create(Warnings warnings, List channels, DriverContext driverContext) { - return new SumLongGroupingAggregatorFunction(channels, new LongArrayState(driverContext.bigArrays(), SumLongAggregator.init()), driverContext); + return new SumLongGroupingAggregatorFunction(warnings, channels, new LongFallibleArrayState(driverContext.bigArrays(), SumLongAggregator.init()), driverContext); } public static List intermediateStateDesc() { @@ -94,13 +98,21 @@ public void add(int positionOffset, IntVector groupIds) { private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = Math.toIntExact(groups.getInt(groupPosition)); + if (state.hasFailed(groupId)) { + continue; + } if (values.isNull(groupPosition + positionOffset)) { continue; } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); + try { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.setFailed(groupId); + } } } } @@ -108,7 +120,15 @@ private void addRawInput(int positionOffset, IntVector groups, LongBlock values) private void addRawInput(int positionOffset, IntVector groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = Math.toIntExact(groups.getInt(groupPosition)); - state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); + if (state.hasFailed(groupId)) { + continue; + } + try { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.setFailed(groupId); + } } } @@ -121,13 +141,21 @@ private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = Math.toIntExact(groups.getInt(g)); + if (state.hasFailed(groupId)) { + continue; + } if (values.isNull(groupPosition + positionOffset)) { continue; } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); + try { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.setFailed(groupId); + } } } } @@ -142,7 +170,15 @@ private void addRawInput(int positionOffset, IntBlock groups, LongVector values) int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = Math.toIntExact(groups.getInt(g)); - state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); + if (state.hasFailed(groupId)) { + continue; + } + try { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.setFailed(groupId); + } } } } @@ -169,7 +205,9 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page assert sum.getPositionCount() == seen.getPositionCount() && sum.getPositionCount() == failed.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = Math.toIntExact(groups.getInt(groupPosition)); - if (seen.getBoolean(groupPosition + positionOffset)) { + if (failed.getBoolean(groupPosition + positionOffset)) { + state.setFailed(groupId); + } else if (seen.getBoolean(groupPosition + positionOffset)) { state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); } } @@ -180,7 +218,7 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - LongArrayState inState = ((SumLongGroupingAggregatorFunction) input).state; + LongFallibleArrayState inState = ((SumLongGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); if (inState.hasValue(position)) { state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), inState.get(position))); From 523a5830361cc95f5fbcfe69c77f31f60817af76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Wed, 7 Aug 2024 13:08:47 +0200 Subject: [PATCH 10/31] Update docs/changelog/111639.yaml --- docs/changelog/111639.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/changelog/111639.yaml b/docs/changelog/111639.yaml index 135e12a8065c5..5352d84851a36 100644 --- a/docs/changelog/111639.yaml +++ b/docs/changelog/111639.yaml @@ -3,4 +3,5 @@ summary: "ESQL: Add warnings capabilities to aggregators, and prevent overflow o \ SUM aggregation" area: ES|QL type: bug -issues: [] +issues: + - 110443 From 1b09dd315b2a08a635ac434862de1286314d0f6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Wed, 7 Aug 2024 13:24:13 +0200 Subject: [PATCH 11/31] Early returns on non-grouping, and todos for missing catch --- .../compute/gen/AggregatorImplementer.java | 12 +++++++----- .../compute/gen/GroupingAggregatorImplementer.java | 1 + .../aggregation/SumLongAggregatorFunction.java | 8 +++++--- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java index 66036d54c7fa4..a3afb9dadb668 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java @@ -346,6 +346,11 @@ private MethodSpec intermediateBlockCount() { private MethodSpec addRawInput() { MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawInput"); builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC).addParameter(PAGE, "page"); + if (stateTypeHasFailed) { + builder.beginControlFlow("if (state.failed())"); + builder.addStatement("return"); + builder.endControlFlow(); + } builder.addStatement("$T block = page.getBlock(channels.get(0))", valueBlockType(init, combine)); builder.addStatement("$T vector = block.asVector()", valueVectorType(init, combine)); builder.beginControlFlow("if (vector != null)").addStatement("addRawVector(vector)"); @@ -389,11 +394,6 @@ private MethodSpec addRawBlock() { builder.beginControlFlow("if (block.isNull(p))"); builder.addStatement("continue"); builder.endControlFlow(); - if (stateTypeHasFailed) { - builder.beginControlFlow("if (state.failed())"); - builder.addStatement("continue"); - builder.endControlFlow(); - } if (stateTypeHasSeen) { builder.addStatement("state.seen(true)"); } @@ -429,6 +429,7 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable) { builder.nextControlFlow(catchPattern, warnExceptions.stream().map(TypeName::get).toArray()); builder.addStatement("warnings.registerException(e)"); builder.addStatement("state.failed(true)"); + builder.addStatement("return"); builder.endControlFlow(); } } @@ -497,6 +498,7 @@ private MethodSpec addIntermediateInput() { builder.nextControlFlow("else if (seen.getBoolean(0))"); { var state = intermediateState.get(0); + // TODO: Add try-catch of warnExceptions here! var s = "state.$L($T.combine(state.$L(), " + state.name() + "." + vectorAccessorName(state.elementType()) + "(0)))"; builder.addStatement(s, primitiveStateMethod(), declarationType, primitiveStateMethod()); builder.addStatement("state.seen(true)"); diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index b6af73377f32d..cbd76f78ff396 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -545,6 +545,7 @@ private MethodSpec addIntermediateInput() { { var name = intermediateState.get(0).name(); var m = vectorAccessorName(intermediateState.get(0).elementType()); + // TODO: Add try-catch of warnExceptions here! builder.addStatement( "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", declarationType, diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java index 1b21fb9fc3f04..f3a8410139bb0 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java @@ -61,6 +61,9 @@ public int intermediateBlockCount() { @Override public void addRawInput(Page page) { + if (state.failed()) { + return; + } LongBlock block = page.getBlock(channels.get(0)); LongVector vector = block.asVector(); if (vector != null) { @@ -78,6 +81,7 @@ private void addRawVector(LongVector vector) { } catch (ArithmeticException e) { warnings.registerException(e); state.failed(true); + return; } } } @@ -87,9 +91,6 @@ private void addRawBlock(LongBlock block) { if (block.isNull(p)) { continue; } - if (state.failed()) { - continue; - } state.seen(true); int start = block.getFirstValueIndex(p); int end = start + block.getValueCount(p); @@ -99,6 +100,7 @@ private void addRawBlock(LongBlock block) { } catch (ArithmeticException e) { warnings.registerException(e); state.failed(true); + return; } } } From b1fdc255890ca8c3a42f03788b3930ce716c6630 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Wed, 7 Aug 2024 14:03:56 +0200 Subject: [PATCH 12/31] Add try-catch in intermediate state --- .../compute/gen/AggregatorImplementer.java | 31 ++++++------- .../gen/GroupingAggregatorImplementer.java | 43 +++++++++---------- .../SumLongAggregatorFunction.java | 9 +++- .../SumLongGroupingAggregatorFunction.java | 7 ++- 4 files changed, 49 insertions(+), 41 deletions(-) diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java index a3afb9dadb668..aa183f2d43b48 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java @@ -479,13 +479,6 @@ private MethodSpec addIntermediateInput() { assert intermediateState.size() == 2; assert intermediateState.get(1).name().equals("seen"); builder.beginControlFlow("if (seen.getBoolean(0))"); - { - var state = intermediateState.get(0); - var s = "state.$L($T.combine(state.$L(), " + state.name() + "." + vectorAccessorName(state.elementType()) + "(0)))"; - builder.addStatement(s, primitiveStateMethod(), declarationType, primitiveStateMethod()); - builder.addStatement("state.seen(true)"); - builder.endControlFlow(); - } } else { assert intermediateState.size() == 3; assert intermediateState.get(1).name().equals("seen"); @@ -496,15 +489,23 @@ private MethodSpec addIntermediateInput() { builder.addStatement("state.seen(true)"); } builder.nextControlFlow("else if (seen.getBoolean(0))"); - { - var state = intermediateState.get(0); - // TODO: Add try-catch of warnExceptions here! - var s = "state.$L($T.combine(state.$L(), " + state.name() + "." + vectorAccessorName(state.elementType()) + "(0)))"; - builder.addStatement(s, primitiveStateMethod(), declarationType, primitiveStateMethod()); - builder.addStatement("state.seen(true)"); - builder.endControlFlow(); - } } + + if (warnExceptions.isEmpty() == false) { + builder.beginControlFlow("try"); + } + var state = intermediateState.get(0); + var s = "state.$L($T.combine(state.$L(), " + state.name() + "." + vectorAccessorName(state.elementType()) + "(0)))"; + builder.addStatement(s, primitiveStateMethod(), declarationType, primitiveStateMethod()); + builder.addStatement("state.seen(true)"); + if (warnExceptions.isEmpty() == false) { + String catchPattern = "catch (" + warnExceptions.stream().map(m -> "$T").collect(Collectors.joining(" | ")) + " e)"; + builder.nextControlFlow(catchPattern, warnExceptions.stream().map(TypeName::get).toArray()); + builder.addStatement("warnings.registerException(e)"); + builder.addStatement("state.failed(true)"); + builder.endControlFlow(); + } + builder.endControlFlow(); } else { throw new IllegalArgumentException("Don't know how to combine intermediate input. Define combineIntermediate"); } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index cbd76f78ff396..0c4aeca996a19 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -522,17 +522,6 @@ private MethodSpec addIntermediateInput() { assert intermediateState.size() == 2; assert intermediateState.get(1).name().equals("seen"); builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))"); - { - var name = intermediateState.get(0).name(); - var m = vectorAccessorName(intermediateState.get(0).elementType()); - builder.addStatement( - "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", - declarationType, - name, - m - ); - builder.endControlFlow(); - } } else { assert intermediateState.size() == 3; assert intermediateState.get(1).name().equals("seen"); @@ -542,19 +531,27 @@ private MethodSpec addIntermediateInput() { builder.addStatement("state.setFailed(groupId)"); } builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))"); - { - var name = intermediateState.get(0).name(); - var m = vectorAccessorName(intermediateState.get(0).elementType()); - // TODO: Add try-catch of warnExceptions here! - builder.addStatement( - "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", - declarationType, - name, - m - ); - builder.endControlFlow(); - } } + + if (warnExceptions.isEmpty() == false) { + builder.beginControlFlow("try"); + } + var name = intermediateState.get(0).name(); + var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType()); + builder.addStatement( + "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", + declarationType, + name, + vectorAccessor + ); + if (warnExceptions.isEmpty() == false) { + String catchPattern = "catch (" + warnExceptions.stream().map(m -> "$T").collect(Collectors.joining(" | ")) + " e)"; + builder.nextControlFlow(catchPattern, warnExceptions.stream().map(TypeName::get).toArray()); + builder.addStatement("warnings.registerException(e)"); + builder.addStatement("state.setFailed(groupId)"); + builder.endControlFlow(); + } + builder.endControlFlow(); } else { builder.addStatement("$T.combineIntermediate(state, groupId, " + intermediateStateRowAccess() + ")", declarationType); } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java index f3a8410139bb0..73677a21f8e39 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java @@ -132,8 +132,13 @@ public void addIntermediateInput(Page page) { state.failed(true); state.seen(true); } else if (seen.getBoolean(0)) { - state.longValue(SumLongAggregator.combine(state.longValue(), sum.getLong(0))); - state.seen(true); + try { + state.longValue(SumLongAggregator.combine(state.longValue(), sum.getLong(0))); + state.seen(true); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.failed(true); + } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java index f3b6f7e75f722..7e0e456766e07 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java @@ -208,7 +208,12 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page if (failed.getBoolean(groupPosition + positionOffset)) { state.setFailed(groupId); } else if (seen.getBoolean(groupPosition + positionOffset)) { - state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); + try { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.setFailed(groupId); + } } } } From 16b1cf86670e2be37dafde78f45e57dbf45ad97b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Wed, 7 Aug 2024 17:24:32 +0200 Subject: [PATCH 13/31] Release failed array, and fixed warnings on grouping tests --- .../compute/aggregation/AbstractFallibleArrayState.java | 2 +- .../expression/function/AbstractAggregationTestCase.java | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractFallibleArrayState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractFallibleArrayState.java index 8a5aa7580d927..22d907b1319aa 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractFallibleArrayState.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractFallibleArrayState.java @@ -64,6 +64,6 @@ protected final void setFailed(int groupId) { @Override public void close() { - Releasables.close(seen); + Releasables.close(seen, failed); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java index 65425486ea4e0..c2e3c992c9fc8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java @@ -195,9 +195,9 @@ private void aggregateGroupingSingleMode(Expression expression) { assert testCase.getMatcher().matches(Double.NEGATIVE_INFINITY) == false; assertThat(result, not(equalTo(Double.NEGATIVE_INFINITY))); assertThat(result, testCase.getMatcher()); - if (testCase.getExpectedWarnings() != null) { - assertWarnings(testCase.getExpectedWarnings()); - } + } + if (testCase.getExpectedWarnings() != null) { + assertWarnings(testCase.getExpectedWarnings()); } } From d6d44c7f1f51a90f07369110bd4b9cf4a77983b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Wed, 7 Aug 2024 17:39:38 +0200 Subject: [PATCH 14/31] Fixed benchmark compilation --- .../benchmark/compute/operator/AggregatorBenchmark.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java index 8b22e50e4e8c9..913aebe2d509a 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java @@ -168,7 +168,7 @@ private static AggregatorFunctionSupplier supplier(String op, String dataType, i default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]"); }; case SUM -> switch (dataType) { - case LONGS -> new SumLongAggregatorFunctionSupplier(List.of(dataChannel)); + case LONGS -> new SumLongAggregatorFunctionSupplier(-1, -2, "", List.of(dataChannel)); case DOUBLES -> new SumDoubleAggregatorFunctionSupplier(List.of(dataChannel)); default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]"); }; From 10ae7b2b1027aaa0f287e5fc8d88513984c17cc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 12 Aug 2024 16:04:45 +0200 Subject: [PATCH 15/31] Extend from AbstractArrayState, and extract annotations method --- .../compute/gen/AggregatorProcessor.java | 28 +++--------- .../compute/gen/Annotations.java | 45 +++++++++++++++++++ .../compute/gen/EvaluatorProcessor.java | 35 ++++----------- .../AbstractFallibleArrayState.java | 36 ++------------- 4 files changed, 62 insertions(+), 82 deletions(-) create mode 100644 x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Annotations.java diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java index 573c9676610d7..4b1f946a1d176 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java @@ -15,7 +15,6 @@ import java.io.IOException; import java.io.Writer; -import java.util.ArrayList; import java.util.Collections; import java.util.IdentityHashMap; import java.util.List; @@ -28,11 +27,9 @@ import javax.annotation.processing.RoundEnvironment; import javax.lang.model.SourceVersion; import javax.lang.model.element.AnnotationMirror; -import javax.lang.model.element.AnnotationValue; import javax.lang.model.element.Element; import javax.lang.model.element.ExecutableElement; import javax.lang.model.element.TypeElement; -import javax.lang.model.type.TypeMirror; import javax.tools.Diagnostic; import javax.tools.JavaFileObject; @@ -83,7 +80,11 @@ public boolean process(Set set, RoundEnvironment roundEnv } for (TypeElement aggClass : annotatedClasses) { AggregatorImplementer implementer = null; - var warnExceptionsTypes = warnExceptions(aggClass); + var warnExceptionsTypes = Annotations.listAttributeValues( + aggClass, + Set.of(Aggregator.class, GroupingAggregator.class), + "warnExceptions" + ); if (aggClass.getAnnotation(Aggregator.class) != null) { IntermediateState[] intermediateState = aggClass.getAnnotation(Aggregator.class).value(); implementer = new AggregatorImplementer(env.getElementUtils(), aggClass, intermediateState, warnExceptionsTypes); @@ -143,23 +144,4 @@ public static void write(Object origination, String what, JavaFile file, Process throw new RuntimeException(e); } } - - private static List warnExceptions(Element aggregatorMethod) { - List result = new ArrayList<>(); - for (var mirror : aggregatorMethod.getAnnotationMirrors()) { - String annotationType = mirror.getAnnotationType().toString(); - if (annotationType.equals(Aggregator.class.getName()) || annotationType.equals(GroupingAggregator.class.getName())) { - - for (var e : mirror.getElementValues().entrySet()) { - if (false == e.getKey().getSimpleName().toString().equals("warnExceptions")) { - continue; - } - for (var v : (List) e.getValue().getValue()) { - result.add((TypeMirror) ((AnnotationValue) v).getValue()); - } - } - } - } - return result; - } } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Annotations.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Annotations.java new file mode 100644 index 0000000000000..d3892f7d2a40b --- /dev/null +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Annotations.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.gen; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +import javax.lang.model.element.AnnotationValue; +import javax.lang.model.element.Element; +import javax.lang.model.type.TypeMirror; + +public class Annotations { + private Annotations() {} + + /** + * Returns the values of the requested attribute, from all the matching annotations on the given element. + * + * @param element the element to inspect + * @param annotations the annotations to look for + * @param attributeName the attribute to extract + */ + public static List listAttributeValues(Element element, Set> annotations, String attributeName) { + List result = new ArrayList<>(); + for (var mirror : element.getAnnotationMirrors()) { + String annotationType = mirror.getAnnotationType().toString(); + if (annotations.stream().anyMatch(a -> a.getName().equals(annotationType))) { + for (var e : mirror.getElementValues().entrySet()) { + if (false == e.getKey().getSimpleName().toString().equals(attributeName)) { + continue; + } + for (var v : (List) e.getValue().getValue()) { + result.add((TypeMirror) ((AnnotationValue) v).getValue()); + } + } + } + } + return result; + } +} diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/EvaluatorProcessor.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/EvaluatorProcessor.java index ea3ee938298de..09012c7b3a48a 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/EvaluatorProcessor.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/EvaluatorProcessor.java @@ -11,7 +11,6 @@ import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.MvEvaluator; -import java.util.ArrayList; import java.util.List; import java.util.Set; @@ -21,11 +20,9 @@ import javax.annotation.processing.RoundEnvironment; import javax.lang.model.SourceVersion; import javax.lang.model.element.AnnotationMirror; -import javax.lang.model.element.AnnotationValue; import javax.lang.model.element.Element; import javax.lang.model.element.ExecutableElement; import javax.lang.model.element.TypeElement; -import javax.lang.model.type.TypeMirror; import javax.tools.Diagnostic; /** @@ -69,6 +66,11 @@ public Iterable getCompletions( public boolean process(Set set, RoundEnvironment roundEnvironment) { for (TypeElement ann : set) { for (Element evaluatorMethod : roundEnvironment.getElementsAnnotatedWith(ann)) { + var warnExceptionsTypes = Annotations.listAttributeValues( + evaluatorMethod, + Set.of(Evaluator.class, MvEvaluator.class, ConvertEvaluator.class), + "warnExceptions" + ); Evaluator evaluatorAnn = evaluatorMethod.getAnnotation(Evaluator.class); if (evaluatorAnn != null) { try { @@ -80,7 +82,7 @@ public boolean process(Set set, RoundEnvironment roundEnv env.getTypeUtils(), (ExecutableElement) evaluatorMethod, evaluatorAnn.extraName(), - warnExceptions(evaluatorMethod) + warnExceptionsTypes ).sourceFile(), env ); @@ -102,7 +104,7 @@ public boolean process(Set set, RoundEnvironment roundEnv mvEvaluatorAnn.finish(), mvEvaluatorAnn.single(), mvEvaluatorAnn.ascending(), - warnExceptions(evaluatorMethod) + warnExceptionsTypes ).sourceFile(), env ); @@ -121,7 +123,7 @@ public boolean process(Set set, RoundEnvironment roundEnv env.getElementUtils(), (ExecutableElement) evaluatorMethod, convertEvaluatorAnn.extraName(), - warnExceptions(evaluatorMethod) + warnExceptionsTypes ).sourceFile(), env ); @@ -134,25 +136,4 @@ public boolean process(Set set, RoundEnvironment roundEnv } return true; } - - private static List warnExceptions(Element evaluatorMethod) { - List result = new ArrayList<>(); - for (var mirror : evaluatorMethod.getAnnotationMirrors()) { - String annotationType = mirror.getAnnotationType().toString(); - if (annotationType.equals(Evaluator.class.getName()) - || annotationType.equals(MvEvaluator.class.getName()) - || annotationType.equals(ConvertEvaluator.class.getName())) { - - for (var e : mirror.getElementValues().entrySet()) { - if (false == e.getKey().getSimpleName().toString().equals("warnExceptions")) { - continue; - } - for (var v : (List) e.getValue().getValue()) { - result.add((TypeMirror) ((AnnotationValue) v).getValue()); - } - } - } - } - return result; - } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractFallibleArrayState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractFallibleArrayState.java index 22d907b1319aa..6a6d3946dbd62 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractFallibleArrayState.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractFallibleArrayState.java @@ -9,48 +9,19 @@ import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; -public class AbstractFallibleArrayState implements Releasable { - protected final BigArrays bigArrays; - - private BitArray seen; +public class AbstractFallibleArrayState extends AbstractArrayState { private BitArray failed; public AbstractFallibleArrayState(BigArrays bigArrays) { - this.bigArrays = bigArrays; - } - - final boolean hasValue(int groupId) { - return seen == null || seen.get(groupId); + super(bigArrays); } final boolean hasFailed(int groupId) { return failed != null && failed.get(groupId); } - /** - * Switches this array state into tracking which group ids are set. This is - * idempotent and fast if already tracking so it's safe to, say, call it once - * for every block of values that arrives containing {@code null}. - */ - final void enableGroupIdTracking(SeenGroupIds seenGroupIds) { - if (seen == null) { - seen = seenGroupIds.seenGroupIds(bigArrays); - } - } - - protected final void trackGroupId(int groupId) { - if (trackingGroupIds()) { - seen.set(groupId); - } - } - - protected final boolean trackingGroupIds() { - return seen != null; - } - protected final boolean anyFailure() { return failed != null; } @@ -64,6 +35,7 @@ protected final void setFailed(int groupId) { @Override public void close() { - Releasables.close(seen, failed); + super.close(); + Releasables.close(failed); } } From e95cc8d5b29feded19cd959c7fd1578abe02c5f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 12 Aug 2024 16:44:03 +0200 Subject: [PATCH 16/31] Removed obsolete test --- .../SumLongAggregatorFunctionTests.java | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java index b69a5441ae61b..ad17f92942605 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java @@ -47,22 +47,6 @@ public void assertSimpleOutput(List input, Block result) { assertThat(((LongBlock) result).getLong(0), equalTo(sum)); } - public void testOverflowFails() { - DriverContext driverContext = driverContext(); - try ( - Driver d = new Driver( - driverContext, - new SequenceLongBlockSourceOperator(driverContext.blockFactory(), LongStream.of(Long.MAX_VALUE - 1, 2)), - List.of(simple().get(driverContext)), - new PageConsumerOperator(page -> fail("shouldn't have made it this far")), - () -> {} - ) - ) { - Exception e = expectThrows(ArithmeticException.class, () -> runDriver(d)); - assertThat(e.getMessage(), equalTo("long overflow")); - } - } - public void testRejectsDouble() { DriverContext driverContext = driverContext(); BlockFactory blockFactory = driverContext.blockFactory(); From 753af10d80cca0dfc6fc176340cbea5283085fc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 3 Sep 2024 18:10:52 +0200 Subject: [PATCH 17/31] Test warnings for long overflow --- .../org/elasticsearch/test/ESTestCase.java | 2 +- .../SumLongAggregatorFunctionTests.java | 38 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index 58487d6552bcd..0cc62ca28be6b 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -438,7 +438,7 @@ private static void setTestSysProps(Random random) { } protected final Logger logger = LogManager.getLogger(getClass()); - private ThreadContext threadContext; + protected ThreadContext threadContext; // ----------------------------------------------------------------- // Suite and test case setup/cleanup. diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java index ad17f92942605..071982fa5fe88 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java @@ -18,10 +18,13 @@ import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.operator.TestResultPageSinkOperator; +import java.util.ArrayList; import java.util.List; import java.util.stream.LongStream; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; public class SumLongAggregatorFunctionTests extends AggregatorFunctionTestCase { @@ -47,6 +50,41 @@ public void assertSimpleOutput(List input, Block result) { assertThat(((LongBlock) result).getLong(0), equalTo(sum)); } + public void testOverflowFails() { + List results = new ArrayList<>(); + DriverContext driverContext = driverContext(); + List warnings = new ArrayList<>(); + try ( + Driver d = new Driver( + driverContext, + new SequenceLongBlockSourceOperator(driverContext.blockFactory(), LongStream.of(Long.MAX_VALUE - 1, 2)), + List.of(simple().get(driverContext)), + new TestResultPageSinkOperator(results::add), + () -> { + warnings.addAll(threadContext.getResponseHeaders().getOrDefault("Warning", List.of())); + } + ) + ) { + runDriver(d); + } + + assertDriverContext(driverContext); + + assertThat(results.size(), equalTo(1)); + assertThat(results.get(0).getBlockCount(), equalTo(1)); + assertThat(results.get(0).getPositionCount(), equalTo(1)); + assertThat(results.get(0).getBlock(0).isNull(0), equalTo(true)); + + assertThat( + warnings, + contains( + "299 Elasticsearch-8.16.0-unknown \"Line -1:-2: evaluation of [] failed, treating result as null. " + + "Only first 20 failures recorded.\"", + "299 Elasticsearch-8.16.0-unknown \"Line -1:-2: java.lang.ArithmeticException: long overflow\"" + ) + ); + } + public void testRejectsDouble() { DriverContext driverContext = driverContext(); BlockFactory blockFactory = driverContext.blockFactory(); From d71eea2abb7a3f890af8792ba84f38051c564184 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 14 Oct 2024 15:05:58 +0200 Subject: [PATCH 18/31] Added CSV test failing in mixed cluster --- .../src/main/resources/stats_sum.csv-spec | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_sum.csv-spec diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_sum.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_sum.csv-spec new file mode 100644 index 0000000000000..1bdf63c119554 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_sum.csv-spec @@ -0,0 +1,14 @@ +sumWithOverflow +FROM employees +| LIMIT 2 +| EVAL x = languages - languages + 9223372036854775807 +| STATS overflow = SUM(x), underflow = SUM(-x) +; +warning:Line 4:20: evaluation of [SUM(x)] failed, treating result as null. Only first 20 failures recorded. +warning:Line 4:20: java.lang.ArithmeticException: long overflow +warning:Line 4:40: evaluation of [SUM(-x)] failed, treating result as null. Only first 20 failures recorded. +warning:Line 4:40: java.lang.ArithmeticException: long overflow + +overflow:long | underflow:long +null | null +; From 919403c20e8668864cc447df212cd548bc0031ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 31 Oct 2024 15:38:00 +0100 Subject: [PATCH 19/31] Added features to Configuration, and use it in aggregate functions --- .../org/elasticsearch/TransportVersions.java | 1 + .../function/EsqlFunctionRegistry.java | 12 ++- .../function/aggregate/AggregateFunction.java | 2 +- .../expression/function/aggregate/Avg.java | 19 ++--- .../ConfigurationAggregateFunction.java | 75 +++++++++++++++++++ .../expression/function/aggregate/Sum.java | 67 +++++++++++------ .../function/aggregate/WeightedAvg.java | 25 ++++--- .../xpack/esql/plugin/EsqlFeatures.java | 8 +- .../esql/plugin/TransportEsqlQueryAction.java | 7 +- .../xpack/esql/session/Configuration.java | 26 +++++++ 10 files changed, 195 insertions(+), 47 deletions(-) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ConfigurationAggregateFunction.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index ea3e649de9ef8..1bd0dd611b568 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -186,6 +186,7 @@ static TransportVersion def(int id) { public static final TransportVersion CPU_STAT_STRING_PARSING = def(8_781_00_0); public static final TransportVersion QUERY_RULES_RETRIEVER = def(8_782_00_0); public static final TransportVersion ESQL_CCS_EXEC_INFO_WITH_FAILURES = def(8_783_00_0); + public static final TransportVersion ESQL_CONFIGURATION_WITH_FEATURES = def(8_784_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 66151275fc2e8..a2a27c1eed07c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -262,7 +262,7 @@ private FunctionDefinition[][] functions() { // since they declare two public constructors - one with filter (for nested where) and one without // use casting to disambiguate between the two new FunctionDefinition[] { - def(Avg.class, uni(Avg::new), "avg"), + def(Avg.class, uniConfig(Avg::new), "avg"), def(Count.class, uni(Count::new), "count"), def(CountDistinct.class, bi(CountDistinct::new), "count_distinct"), def(Max.class, uni(Max::new), "max"), @@ -270,10 +270,10 @@ private FunctionDefinition[][] functions() { def(MedianAbsoluteDeviation.class, uni(MedianAbsoluteDeviation::new), "median_absolute_deviation"), def(Min.class, uni(Min::new), "min"), def(Percentile.class, bi(Percentile::new), "percentile"), - def(Sum.class, uni(Sum::new), "sum"), + def(Sum.class, uniConfig(Sum::new), "sum"), def(Top.class, tri(Top::new), "top"), def(Values.class, uni(Values::new), "values"), - def(WeightedAvg.class, bi(WeightedAvg::new), "weighted_avg") }, + def(WeightedAvg.class, biConfig(WeightedAvg::new), "weighted_avg") }, // math new FunctionDefinition[] { def(Abs.class, Abs::new, "abs"), @@ -935,10 +935,16 @@ protected interface TernaryConfigurationAwareBuilder { private static BiFunction uni(BiFunction function) { return function; } + private static UnaryConfigurationAwareBuilder uniConfig(UnaryConfigurationAwareBuilder function) { + return function; + } private static BinaryBuilder bi(BinaryBuilder function) { return function; } + private static BinaryConfigurationAwareBuilder biConfig(BinaryConfigurationAwareBuilder function) { + return function; + } private static TernaryBuilder tri(TernaryBuilder function) { return function; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java index f7a74cc2ae93f..5cea17ad6cb4e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java @@ -85,7 +85,7 @@ protected AggregateFunction(StreamInput in) throws IOException { } @Override - public final void writeTo(StreamOutput out) throws IOException { + public void writeTo(StreamOutput out) throws IOException { Source.EMPTY.writeTo(out); out.writeNamedWriteable(field); if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER)) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java index 82c0f9d24899e..2e4234ce49244 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; +import org.elasticsearch.xpack.esql.session.Configuration; import java.io.IOException; import java.util.List; @@ -28,7 +29,7 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; -public class Avg extends AggregateFunction implements SurrogateExpression { +public class Avg extends ConfigurationAggregateFunction implements SurrogateExpression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Avg", Avg::new); @FunctionInfo( @@ -45,12 +46,12 @@ public class Avg extends AggregateFunction implements SurrogateExpression { tag = "docsStatsAvgNestedExpression" ) } ) - public Avg(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { - this(source, field, Literal.TRUE); + public Avg(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field, Configuration configuration) { + this(source, field, Literal.TRUE, configuration); } - public Avg(Source source, Expression field, Expression filter) { - super(source, field, filter, emptyList()); + public Avg(Source source, Expression field, Expression filter, Configuration configuration) { + super(source, field, filter, emptyList(), configuration); } @Override @@ -80,17 +81,17 @@ public DataType dataType() { @Override protected NodeInfo info() { - return NodeInfo.create(this, Avg::new, field(), filter()); + return NodeInfo.create(this, Avg::new, field(), filter(), configuration()); } @Override public Avg replaceChildren(List newChildren) { - return new Avg(source(), newChildren.get(0), newChildren.get(1)); + return new Avg(source(), newChildren.get(0), newChildren.get(1), configuration()); } @Override public Avg withFilter(Expression filter) { - return new Avg(source(), field(), filter); + return new Avg(source(), field(), filter, configuration()); } @Override @@ -100,6 +101,6 @@ public Expression surrogate() { return field().foldable() ? new MvAvg(s, field) - : new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType()); + : new Div(s, new Sum(s, field, filter(), configuration()), new Count(s, field, filter()), dataType()); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ConfigurationAggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ConfigurationAggregateFunction.java new file mode 100644 index 0000000000000..85e5cc68719d7 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ConfigurationAggregateFunction.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.compute.data.BlockStreamInput; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import org.elasticsearch.xpack.esql.session.Configuration; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public abstract class ConfigurationAggregateFunction extends AggregateFunction { + + private final Configuration configuration; + + ConfigurationAggregateFunction(Source source, Expression field, List parameters, Configuration configuration) { + super(source, field, parameters); + this.configuration = configuration; + } + + ConfigurationAggregateFunction(Source source, Expression field, Expression filter, List parameters, Configuration configuration) { + super(source, field, filter, parameters); + this.configuration = configuration; + } + + ConfigurationAggregateFunction(Source source, Expression field, Configuration configuration) { + super(source, field); + this.configuration = configuration; + } + + ConfigurationAggregateFunction(StreamInput in) throws IOException { + super(in); + this.configuration = ((PlanStreamInput) in).configuration(); + } + + @Override + public final void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CONFIGURATION_WITH_FEATURES)) { + configuration.writeTo(out); + } + } + + public Configuration configuration() { + return configuration; + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), children(), configuration); + } + + @Override + public boolean equals(Object obj) { + if (super.equals(obj) == false) { + return false; + } + ConfigurationAggregateFunction other = (ConfigurationAggregateFunction) obj; + + return configuration.equals(other.configuration); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java index 7259bae209ca1..7f5fd7c6ba869 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java @@ -12,6 +12,7 @@ import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumIntAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; @@ -24,11 +25,16 @@ import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; +import org.elasticsearch.xpack.esql.planner.ToAggregator; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; +import org.elasticsearch.xpack.esql.session.Configuration; import java.io.IOException; import java.util.List; import static java.util.Collections.emptyList; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; import static org.elasticsearch.xpack.esql.core.type.DataType.LONG; import static org.elasticsearch.xpack.esql.core.type.DataType.UNSIGNED_LONG; @@ -36,7 +42,7 @@ /** * Sum all values of a field in matching documents. */ -public class Sum extends NumericAggregate implements SurrogateExpression { +public class Sum extends ConfigurationAggregateFunction implements ToAggregator, SurrogateExpression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Sum", Sum::new); @FunctionInfo( @@ -53,12 +59,16 @@ public class Sum extends NumericAggregate implements SurrogateExpression { tag = "docsStatsSumNestedExpression" ) } ) - public Sum(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { - this(source, field, Literal.TRUE); + public Sum( + Source source, + @Param(name = "number", type = { "double", "integer", "long" }) Expression field, + Configuration configuration + ) { + this(source, field, Literal.TRUE, configuration); } - public Sum(Source source, Expression field, Expression filter) { - super(source, field, filter, emptyList()); + public Sum(Source source, Expression field, Expression filter, Configuration configuration) { + super(source, field, filter, emptyList(), configuration); } private Sum(StreamInput in) throws IOException { @@ -72,39 +82,54 @@ public String getWriteableName() { @Override protected NodeInfo info() { - return NodeInfo.create(this, Sum::new, field(), filter()); + return NodeInfo.create(this, Sum::new, field(), filter(), configuration()); } @Override public Sum replaceChildren(List newChildren) { - return new Sum(source(), newChildren.get(0), newChildren.get(1)); + return new Sum(source(), newChildren.get(0), newChildren.get(1), configuration()); } @Override public Sum withFilter(Expression filter) { - return new Sum(source(), field(), filter); + return new Sum(source(), field(), filter, configuration()); } @Override - public DataType dataType() { - DataType dt = field().dataType(); - return dt.isWholeNumber() == false || dt == UNSIGNED_LONG ? DOUBLE : LONG; + protected TypeResolution resolveType() { + return isType( + field(), + dt -> dt.isNumeric() && dt != DataType.UNSIGNED_LONG, + sourceText(), + DEFAULT, + "numeric except unsigned_long or counter types" + ); } @Override - protected AggregatorFunctionSupplier longSupplier(List inputChannels) { - var location = source().source(); - return new SumLongAggregatorFunctionSupplier(location.getLineNumber(), location.getColumnNumber(), source().text(), inputChannels); - } - - @Override - protected AggregatorFunctionSupplier intSupplier(List inputChannels) { - return new SumIntAggregatorFunctionSupplier(inputChannels); + public DataType dataType() { + DataType dt = field().dataType(); + return dt.isWholeNumber() == false || dt == UNSIGNED_LONG ? DOUBLE : LONG; } @Override - protected AggregatorFunctionSupplier doubleSupplier(List inputChannels) { - return new SumDoubleAggregatorFunctionSupplier(inputChannels); + public final AggregatorFunctionSupplier supplier(List inputChannels) { + DataType type = field().dataType(); + if (type == DataType.LONG) { + // Old Sum without overflow handling + if (configuration().clusterHasFeature(EsqlFeatures.FN_SUM_OVERFLOW_HANDLING) == false) { + // return new SumLongAggregatorFunctionSupplier(inputChannels); + } + var location = source().source(); + return new SumLongAggregatorFunctionSupplier(location.getLineNumber(), location.getColumnNumber(), source().text(), inputChannels); + } + if (type == DataType.INTEGER) { + return new SumIntAggregatorFunctionSupplier(inputChannels); + } + if (type == DataType.DOUBLE) { + return new SumDoubleAggregatorFunctionSupplier(inputChannels); + } + throw EsqlIllegalArgumentException.illegalDataType(type); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java index dbcc50cea3b9b..584109f8719dd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.session.Configuration; import java.io.IOException; import java.util.List; @@ -34,7 +35,7 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; -public class WeightedAvg extends AggregateFunction implements SurrogateExpression, Validatable { +public class WeightedAvg extends ConfigurationAggregateFunction implements SurrogateExpression, Validatable { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, "WeightedAvg", @@ -54,13 +55,14 @@ public class WeightedAvg extends AggregateFunction implements SurrogateExpressio public WeightedAvg( Source source, @Param(name = "number", type = { "double", "integer", "long" }, description = "A numeric value.") Expression field, - @Param(name = "weight", type = { "double", "integer", "long" }, description = "A numeric weight.") Expression weight + @Param(name = "weight", type = { "double", "integer", "long" }, description = "A numeric weight.") Expression weight, + Configuration configuration ) { - this(source, field, Literal.TRUE, weight); + this(source, field, Literal.TRUE, weight, configuration); } - public WeightedAvg(Source source, Expression field, Expression filter, Expression weight) { - super(source, field, filter, List.of(weight)); + public WeightedAvg(Source source, Expression field, Expression filter, Expression weight, Configuration configuration) { + super(source, field, filter, List.of(weight), configuration); this.weight = weight; } @@ -73,7 +75,8 @@ private WeightedAvg(StreamInput in) throws IOException { : Literal.TRUE, in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) ? in.readNamedWriteableCollectionAsList(Expression.class).get(0) - : in.readNamedWriteable(Expression.class) + : in.readNamedWriteable(Expression.class), + ((PlanStreamInput) in).configuration() ); } @@ -132,17 +135,17 @@ public DataType dataType() { @Override protected NodeInfo info() { - return NodeInfo.create(this, WeightedAvg::new, field(), filter(), weight); + return NodeInfo.create(this, WeightedAvg::new, field(), filter(), weight, configuration()); } @Override public WeightedAvg replaceChildren(List newChildren) { - return new WeightedAvg(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2)); + return new WeightedAvg(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), configuration()); } @Override public WeightedAvg withFilter(Expression filter) { - return new WeightedAvg(source(), field(), filter, weight()); + return new WeightedAvg(source(), field(), filter, weight(), configuration()); } @Override @@ -155,9 +158,9 @@ public Expression surrogate() { return new MvAvg(s, field); } if (weight.foldable()) { - return new Div(s, new Sum(s, field), new Count(s, field), dataType()); + return new Div(s, new Sum(s, field, configuration()), new Count(s, field), dataType()); } else { - return new Div(s, new Sum(s, new Mul(s, field, weight)), new Sum(s, weight), dataType()); + return new Div(s, new Sum(s, new Mul(s, field, weight), configuration()), new Sum(s, weight, configuration()), dataType()); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java index 266f07d22eaf5..8db74635794bb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java @@ -178,6 +178,11 @@ public class EsqlFeatures implements FeatureSpecification { */ public static final NodeFeature RESOLVE_FIELDS_API = new NodeFeature("esql.resolve_fields_api"); + /** + * Overflow handling for Sum aggregation function + */ + public static final NodeFeature FN_SUM_OVERFLOW_HANDLING = new NodeFeature("esql.fn_sum_overflow_handling"); + private Set snapshotBuildFeatures() { assert Build.current().isSnapshot() : Build.current(); return Set.of(METRICS_SYNTAX); @@ -207,7 +212,8 @@ public Set getFeatures() { METADATA_FIELDS, TIMESPAN_ABBREVIATIONS, COUNTER_TYPES, - RESOLVE_FIELDS_API + RESOLVE_FIELDS_API, + FN_SUM_OVERFLOW_HANDLING ); if (Build.current().isSnapshot()) { return Collections.unmodifiableSet(Sets.union(features, snapshotBuildFeatures())); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java index 04e5fdc4b3bd2..8dae3d1022fb9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.operator.exchange.ExchangeService; +import org.elasticsearch.features.FeatureService; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.search.SearchService; import org.elasticsearch.tasks.CancellableTask; @@ -62,6 +63,7 @@ public class TransportEsqlQueryAction extends HandledTransportAction listener) { Configuration configuration = new Configuration( + featureService, + clusterService, ZoneOffset.UTC, request.locale() != null ? request.locale() : Locale.US, // TODO: plug-in security diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java index 4ec2746b24ee4..c707c73c9a0c7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java @@ -8,13 +8,17 @@ package org.elasticsearch.xpack.esql.session; import org.elasticsearch.TransportVersions; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.compress.CompressorFactory; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.compute.data.BlockStreamInput; +import org.elasticsearch.features.FeatureService; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.xpack.esql.Column; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import java.io.IOException; @@ -27,6 +31,8 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; import static org.elasticsearch.common.unit.ByteSizeUnit.KB; @@ -53,7 +59,11 @@ public class Configuration implements Writeable { private final Map> tables; private final long queryStartTimeNanos; + private final Set activeEsqlFeatures; + public Configuration( + FeatureService featureService, + ClusterService clusterService, ZoneId zi, Locale locale, String username, @@ -79,6 +89,10 @@ public Configuration( this.tables = tables; assert tables != null; this.queryStartTimeNanos = queryStartTimeNanos; + this.activeEsqlFeatures = new EsqlFeatures().getFeatures().stream() + .filter(f -> featureService.clusterHasFeature(clusterService.state(), f)) + .map(NodeFeature::id) + .collect(Collectors.toSet()); } public Configuration(BlockStreamInput in) throws IOException { @@ -106,6 +120,11 @@ public Configuration(BlockStreamInput in) throws IOException { } else { this.queryStartTimeNanos = -1; } + if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CONFIGURATION_WITH_FEATURES)) { + this.activeEsqlFeatures = in.readCollectionAsImmutableSet(StreamInput::readString); + } else { + this.activeEsqlFeatures = Set.of(); + } } @Override @@ -130,6 +149,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CCS_EXECUTION_INFO)) { out.writeLong(queryStartTimeNanos); } + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CONFIGURATION_WITH_FEATURES)) { + out.writeCollection(activeEsqlFeatures, StreamOutput::writeString); + } } public ZoneId zoneId() { @@ -198,6 +220,10 @@ public boolean profile() { return profile; } + public boolean clusterHasFeature(NodeFeature feature) { + return activeEsqlFeatures.contains(feature.id()); + } + private static void writeQuery(StreamOutput out, String query) throws IOException { if (query.length() > QUERY_COMPRESS_THRESHOLD_CHARS) { // compare on chars to avoid UTF-8 encoding unless actually required out.writeBoolean(true); From 5d026c2c72a96f973fe1caa17fa16699ba9d3f7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 31 Oct 2024 16:07:00 +0100 Subject: [PATCH 20/31] Updated tests with config, and fixed serialization and AggregateMapper for Sum. CsvTests passing --- .../xpack/esql/ConfigurationTestUtils.java | 8 ++++- .../xpack/esql/EsqlTestUtils.java | 11 ++++++- .../ConfigurationAggregateFunction.java | 8 ----- .../xpack/esql/planner/AggregateMapper.java | 2 +- .../esql/plugin/TransportEsqlQueryAction.java | 7 ++-- .../xpack/esql/session/Configuration.java | 17 ++++++---- .../function/AbstractFunctionTestCase.java | 13 ++++++++ .../aggregate/AvgSerializationTests.java | 4 +-- .../function/aggregate/AvgTests.java | 2 +- .../aggregate/SumSerializationTests.java | 4 +-- .../function/aggregate/SumTests.java | 2 +- .../WeightedAvgSerializationTests.java | 33 +++++++++++++++++++ .../function/aggregate/WeightedAvgTests.java | 2 +- ...AbstractConfigurationFunctionTestCase.java | 8 ++++- .../function/scalar/string/ToLowerTests.java | 7 +++- .../function/scalar/string/ToUpperTests.java | 4 ++- .../rules/logical/FoldNullTests.java | 9 ++--- .../logical/AggregateSerializationTests.java | 3 +- .../xpack/esql/planner/EvalMapperTests.java | 8 ++++- .../planner/LocalExecutionPlannerTests.java | 8 ++++- .../ConfigurationSerializationTests.java | 8 ++++- 21 files changed, 129 insertions(+), 39 deletions(-) create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgSerializationTests.java diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/ConfigurationTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/ConfigurationTestUtils.java index 39e79b33327a9..0d054e6100309 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/ConfigurationTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/ConfigurationTestUtils.java @@ -18,14 +18,17 @@ import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.lucene.DataPartitioning; import org.elasticsearch.core.Releasables; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.xpack.esql.action.ParseTables; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.planner.PlannerUtils; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.session.Configuration; import java.util.HashMap; import java.util.Map; +import java.util.stream.Collectors; import static org.apache.lucene.tests.util.LuceneTestCase.random; import static org.apache.lucene.tests.util.LuceneTestCase.randomLocale; @@ -71,7 +74,10 @@ public static Configuration randomConfiguration(String query, Map getNonNull(AggDef aggDef) { private static Stream, Tuple>> typeAndNames(Class clazz) { List types; List extraConfigs = List.of(""); - if (NumericAggregate.class.isAssignableFrom(clazz)) { + if (NumericAggregate.class.isAssignableFrom(clazz) || Sum.class.isAssignableFrom(clazz)) { types = NUMERIC; } else if (Max.class.isAssignableFrom(clazz) || Min.class.isAssignableFrom(clazz)) { types = List.of("Boolean", "Int", "Long", "Double", "Ip", "BytesRef"); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java index 8dae3d1022fb9..a28d70d5e2a1f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java @@ -20,6 +20,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.operator.exchange.ExchangeService; import org.elasticsearch.features.FeatureService; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.search.SearchService; import org.elasticsearch.tasks.CancellableTask; @@ -51,6 +52,7 @@ import java.util.Locale; import java.util.Map; import java.util.concurrent.Executor; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ClientHelper.ASYNC_SEARCH_ORIGIN; @@ -158,8 +160,6 @@ public void execute(EsqlQueryRequest request, EsqlQueryTask task, ActionListener private void innerExecute(Task task, EsqlQueryRequest request, ActionListener listener) { Configuration configuration = new Configuration( - featureService, - clusterService, ZoneOffset.UTC, request.locale() != null ? request.locale() : Locale.US, // TODO: plug-in security @@ -171,7 +171,8 @@ private void innerExecute(Task task, EsqlQueryRequest request, ActionListener activeEsqlFeatures; public Configuration( - FeatureService featureService, - ClusterService clusterService, ZoneId zi, Locale locale, String username, @@ -74,7 +72,8 @@ public Configuration( String query, boolean profile, Map> tables, - long queryStartTimeNanos + long queryStartTimeNanos, + Set activeEsqlFeatures ) { this.zoneId = zi.normalized(); this.now = ZonedDateTime.now(Clock.tick(Clock.system(zoneId), Duration.ofNanos(1))); @@ -89,10 +88,7 @@ public Configuration( this.tables = tables; assert tables != null; this.queryStartTimeNanos = queryStartTimeNanos; - this.activeEsqlFeatures = new EsqlFeatures().getFeatures().stream() - .filter(f -> featureService.clusterHasFeature(clusterService.state(), f)) - .map(NodeFeature::id) - .collect(Collectors.toSet()); + this.activeEsqlFeatures = activeEsqlFeatures; } public Configuration(BlockStreamInput in) throws IOException { @@ -127,6 +123,13 @@ public Configuration(BlockStreamInput in) throws IOException { } } + public static Set calculateActiveClusterFeatures(FeatureService featureService, ClusterService clusterService) { + return new EsqlFeatures().getFeatures().stream() + .filter(f -> featureService.clusterHasFeature(clusterService.state(), f)) + .map(NodeFeature::id) + .collect(Collectors.toSet()); + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeZoneId(zoneId); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java index c05f8e0990b3c..2b1c6a1c7d544 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java @@ -69,6 +69,7 @@ import org.hamcrest.Matcher; import org.junit.After; import org.junit.AfterClass; +import org.junit.Before; import java.io.IOException; import java.lang.reflect.Constructor; @@ -94,6 +95,7 @@ import static java.util.Map.entry; import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; +import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration; import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral; import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization; import static org.hamcrest.Matchers.either; @@ -134,6 +136,8 @@ public abstract class AbstractFunctionTestCase extends ESTestCase { private static EsqlFunctionRegistry functionRegistry = new EsqlFunctionRegistry().snapshotRegistry(); + private Configuration config; + protected TestCaseSupplier.TestCase testCase; /** @@ -1349,4 +1353,13 @@ private static boolean shouldHideSignature(List argTypes, DataType ret } return false; } + + @Before + public void initConfig() { + config = randomConfiguration("FROM test", Map.of()); + } + + protected Configuration configuration() { + return config; + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java index 52d3128af5c1c..2b8c34c258229 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java @@ -14,12 +14,12 @@ public class AvgSerializationTests extends AbstractExpressionSerializationTests { @Override protected Avg createTestInstance() { - return new Avg(randomSource(), randomChild()); + return new Avg(randomSource(), randomChild(), configuration()); } @Override protected Avg mutateInstance(Avg instance) throws IOException { - return new Avg(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + return new Avg(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild), instance.configuration()); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java index ac599c7ff05f8..679b492cd65f2 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java @@ -58,7 +58,7 @@ public static Iterable parameters() { @Override protected Expression build(Source source, List args) { - return new Avg(source, args.get(0)); + return new Avg(source, args.get(0), configuration()); } private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java index 863392f7eb451..5440985895936 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java @@ -14,12 +14,12 @@ public class SumSerializationTests extends AbstractExpressionSerializationTests { @Override protected Sum createTestInstance() { - return new Sum(randomSource(), randomChild()); + return new Sum(randomSource(), randomChild(), configuration()); } @Override protected Sum mutateInstance(Sum instance) throws IOException { - return new Sum(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + return new Sum(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild), instance.configuration()); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java index 88db21e69dc5f..79059062cdf1c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java @@ -82,7 +82,7 @@ public static Iterable parameters() { @Override protected Expression build(Source source, List args) { - return new Sum(source, args.get(0)); + return new Sum(source, args.get(0), configuration()); } private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgSerializationTests.java new file mode 100644 index 0000000000000..a96e8429e2da0 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgSerializationTests.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; + +public class WeightedAvgSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected WeightedAvg createTestInstance() { + return new WeightedAvg(randomSource(), randomChild(), randomChild(), configuration()); + } + + @Override + protected WeightedAvg mutateInstance(WeightedAvg instance) throws IOException { + return new WeightedAvg( + instance.source(), + randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild), + randomValueOtherThan(instance.weight(), AbstractExpressionSerializationTests::randomChild), + instance.configuration()); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java index 2c2ffc97f268c..2c967da7c251d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java @@ -95,7 +95,7 @@ public static Iterable parameters() { @Override protected Expression build(Source source, List args) { - return new WeightedAvg(source, args.get(0), args.get(1)); + return new WeightedAvg(source, args.get(0), args.get(1), configuration()); } private static TestCaseSupplier makeSupplier( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractConfigurationFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractConfigurationFunctionTestCase.java index a3a18d7a30b59..cecad24bb0980 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractConfigurationFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractConfigurationFunctionTestCase.java @@ -8,17 +8,20 @@ package org.elasticsearch.xpack.esql.expression.function.scalar; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.util.StringUtils; import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.session.Configuration; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization; @@ -43,7 +46,10 @@ static Configuration randomConfiguration() { StringUtils.EMPTY, randomBoolean(), Map.of(), - System.nanoTime() + System.nanoTime(), + new EsqlFeatures().getFeatures().stream() + .map(NodeFeature::id) + .collect(Collectors.toSet()) ); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java index 1f564ecb87f1e..69993a89baa49 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java @@ -13,6 +13,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; @@ -21,6 +22,7 @@ import org.elasticsearch.xpack.esql.core.util.DateUtils; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.scalar.AbstractConfigurationFunctionTestCase; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.session.Configuration; @@ -28,7 +30,9 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.Supplier; +import java.util.stream.Collectors; import static org.hamcrest.Matchers.equalTo; @@ -69,7 +73,8 @@ private Configuration randomLocaleConfig() { "", false, Map.of(), - System.nanoTime() + System.nanoTime(), + Set.of() ); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java index 7c136c3bb83c2..9d86c89514abd 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java @@ -28,6 +28,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.Supplier; import static org.hamcrest.Matchers.equalTo; @@ -69,7 +70,8 @@ private Configuration randomLocaleConfig() { "", false, Map.of(), - System.nanoTime() + System.nanoTime(), + Set.of() ); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java index 89117b5d4e729..5b43948a0bed1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java @@ -65,6 +65,7 @@ import java.util.List; import static org.elasticsearch.xpack.esql.EsqlTestUtils.L; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG; import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration; import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf; @@ -191,9 +192,9 @@ public void testNullFoldingDoesNotApplyOnAggregate() throws Exception { assertEquals(NULL, rule.rule(conditionalFunction)); } - Avg avg = new Avg(EMPTY, getFieldAttribute("a")); + Avg avg = new Avg(EMPTY, getFieldAttribute("a"), TEST_CFG); assertEquals(avg, rule.rule(avg)); - avg = new Avg(EMPTY, NULL); + avg = new Avg(EMPTY, NULL, TEST_CFG); assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(avg)); Count count = new Count(EMPTY, getFieldAttribute("a")); @@ -221,9 +222,9 @@ public void testNullFoldingDoesNotApplyOnAggregate() throws Exception { percentile = new Percentile(EMPTY, NULL, NULL); assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(percentile)); - Sum sum = new Sum(EMPTY, getFieldAttribute("a")); + Sum sum = new Sum(EMPTY, getFieldAttribute("a"), TEST_CFG); assertEquals(sum, rule.rule(sum)); - sum = new Sum(EMPTY, NULL); + sum = new Sum(EMPTY, NULL, TEST_CFG); assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(sum)); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java index 01f797491103c..9f998f1735c16 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.plan.logical; +import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; @@ -51,7 +52,7 @@ public static List randomAggregates() { new Literal(randomSource(), randomFrom("ASC", "DESC"), DataType.KEYWORD) ); case 4 -> new Values(randomSource(), FieldAttributeTests.createFieldAttribute(1, true)); - case 5 -> new Sum(randomSource(), FieldAttributeTests.createFieldAttribute(1, true)); + case 5 -> new Sum(randomSource(), FieldAttributeTests.createFieldAttribute(1, true), EsqlTestUtils.TEST_CFG); default -> throw new IllegalArgumentException(); }; result.add(new Alias(randomSource(), randomAlphaOfLength(5), agg)); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java index 0e09809d16902..03f9287d94a1a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.SerializationTestUtils; @@ -49,6 +50,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.session.Configuration; import java.time.Duration; @@ -58,6 +60,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.stream.Collectors; public class EvalMapperTests extends ESTestCase { private static final FieldAttribute DOUBLE1 = field("foo", DataType.DOUBLE); @@ -76,7 +79,10 @@ public class EvalMapperTests extends ESTestCase { StringUtils.EMPTY, false, Map.of(), - System.nanoTime() + System.nanoTime(), + new EsqlFeatures().getFeatures().stream() + .map(NodeFeature::id) + .collect(Collectors.toSet()) ); @ParametersFactory(argumentFormatting = "%1$s") diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java index f60e5384e1a6f..11ed43fdfa854 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.cache.query.TrivialQueryCachingPolicy; import org.elasticsearch.index.mapper.MapperServiceTestCase; @@ -40,6 +41,7 @@ import org.elasticsearch.xpack.esql.expression.Order; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.session.Configuration; @@ -50,6 +52,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.lessThanOrEqualTo; @@ -162,7 +165,10 @@ private Configuration config() { StringUtils.EMPTY, false, Map.of(), - System.nanoTime() + System.nanoTime(), + new EsqlFeatures().getFeatures().stream() + .map(NodeFeature::id) + .collect(Collectors.toSet()) ); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/ConfigurationSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/ConfigurationSerializationTests.java index 1f35bb5312b20..15c19defee4d0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/ConfigurationSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/ConfigurationSerializationTests.java @@ -18,13 +18,16 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockStreamInput; import org.elasticsearch.core.Releasables; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.esql.Column; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import java.time.ZoneId; import java.util.Locale; import java.util.Map; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration; import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables; @@ -103,7 +106,10 @@ protected Configuration mutateInstance(Configuration in) { query, profile, tables, - System.nanoTime() + System.nanoTime(), + new EsqlFeatures().getFeatures().stream() + .map(NodeFeature::id) + .collect(Collectors.toSet()) ); } From d207d6864cc4c1e4ac685afb934f82613d8eaad9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 31 Oct 2024 16:13:58 +0100 Subject: [PATCH 21/31] Added old aggregator, AggregatorMapper change remaining --- .../OverflowingSumLongAggregatorFunction.java | 181 ++++++++++++++ ...wingSumLongAggregatorFunctionSupplier.java | 39 ++++ ...wingSumLongGroupingAggregatorFunction.java | 221 ++++++++++++++++++ .../OverflowingSumLongAggregator.java | 38 +++ .../expression/function/aggregate/Sum.java | 5 +- 5 files changed, 482 insertions(+), 2 deletions(-) create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionSupplier.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunction.java new file mode 100644 index 0000000000000..87c9f073653b3 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunction.java @@ -0,0 +1,181 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link OverflowingSumLongAggregator}. + * This class is generated. Do not edit it. + */ +public final class OverflowingSumLongAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sum", ElementType.LONG), + new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + + private final DriverContext driverContext; + + private final LongState state; + + private final List channels; + + public OverflowingSumLongAggregatorFunction(DriverContext driverContext, List channels, + LongState state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static OverflowingSumLongAggregatorFunction create(DriverContext driverContext, + List channels) { + return new OverflowingSumLongAggregatorFunction(driverContext, channels, new LongState(OverflowingSumLongAggregator.init())); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + LongBlock block = page.getBlock(channels.get(0)); + LongVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + LongBlock block = page.getBlock(channels.get(0)); + LongVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(LongVector vector) { + state.seen(true); + for (int i = 0; i < vector.getPositionCount(); i++) { + state.longValue(OverflowingSumLongAggregator.combine(state.longValue(), vector.getLong(i))); + } + } + + private void addRawVector(LongVector vector, BooleanVector mask) { + state.seen(true); + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + state.longValue(OverflowingSumLongAggregator.combine(state.longValue(), vector.getLong(i))); + } + } + + private void addRawBlock(LongBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + state.seen(true); + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + state.longValue(OverflowingSumLongAggregator.combine(state.longValue(), block.getLong(i))); + } + } + } + + private void addRawBlock(LongBlock block, BooleanVector mask) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + state.seen(true); + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + state.longValue(OverflowingSumLongAggregator.combine(state.longValue(), block.getLong(i))); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block sumUncast = page.getBlock(channels.get(0)); + if (sumUncast.areAllValuesNull()) { + return; + } + LongVector sum = ((LongBlock) sumUncast).asVector(); + assert sum.getPositionCount() == 1; + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert seen.getPositionCount() == 1; + if (seen.getBoolean(0)) { + state.longValue(OverflowingSumLongAggregator.combine(state.longValue(), sum.getLong(0))); + state.seen(true); + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + if (state.seen() == false) { + blocks[offset] = driverContext.blockFactory().newConstantNullBlock(1); + return; + } + blocks[offset] = driverContext.blockFactory().newConstantLongBlockWith(state.longValue(), 1); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..8ba67ff08568c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionSupplier.java @@ -0,0 +1,39 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link OverflowingSumLongAggregator}. + * This class is generated. Do not edit it. + */ +public final class OverflowingSumLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + public OverflowingSumLongAggregatorFunctionSupplier(List channels) { + this.channels = channels; + } + + @Override + public OverflowingSumLongAggregatorFunction aggregator(DriverContext driverContext) { + return OverflowingSumLongAggregatorFunction.create(driverContext, channels); + } + + @Override + public OverflowingSumLongGroupingAggregatorFunction groupingAggregator( + DriverContext driverContext) { + return OverflowingSumLongGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "overflowing_sum of longs"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..fd2a2895a25c6 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunction.java @@ -0,0 +1,221 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link OverflowingSumLongAggregator}. + * This class is generated. Do not edit it. + */ +public final class OverflowingSumLongGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sum", ElementType.LONG), + new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + + private final LongArrayState state; + + private final List channels; + + private final DriverContext driverContext; + + public OverflowingSumLongGroupingAggregatorFunction(List channels, LongArrayState state, + DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static OverflowingSumLongGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new OverflowingSumLongGroupingAggregatorFunction(channels, new LongArrayState(driverContext.bigArrays(), OverflowingSumLongAggregator.init()), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + LongBlock valuesBlock = page.getBlock(channels.get(0)); + LongVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, OverflowingSumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + state.set(groupId, OverflowingSumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, OverflowingSumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, LongVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + state.set(groupId, OverflowingSumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); + } + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sumUncast = page.getBlock(channels.get(0)); + if (sumUncast.areAllValuesNull()) { + return; + } + LongVector sum = ((LongBlock) sumUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert sum.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, OverflowingSumLongAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); + } + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + LongArrayState inState = ((OverflowingSumLongGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + if (inState.hasValue(position)) { + state.set(groupId, OverflowingSumLongAggregator.combine(state.getOrDefault(groupId), inState.get(position))); + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = state.toValuesBlock(selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java new file mode 100644 index 0000000000000..6444be2cc2c52 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; + +/** + * Sum long aggregator for compatibility with old versions. + *

+ * Replaced by {@link org.elasticsearch.compute.aggregation.SumLongAggregator} since {@code EsqlFeatures#FN_SUM_OVERFLOW_HANDLING}. + *

+ *

+ * Can't be removed, as the new aggregator's layout is different. + *

+ */ +@Aggregator( + value = { + @IntermediateState(name = "sum", type = "LONG"), + @IntermediateState(name = "seen", type = "BOOLEAN") } +) +@GroupingAggregator +class OverflowingSumLongAggregator { + + public static long init() { + return 0; + } + + public static long combine(long current, long v) { + return Math.addExact(current, v); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java index 7f5fd7c6ba869..6be59c1c876e3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.OverflowingSumLongAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumIntAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier; @@ -116,9 +117,9 @@ public DataType dataType() { public final AggregatorFunctionSupplier supplier(List inputChannels) { DataType type = field().dataType(); if (type == DataType.LONG) { - // Old Sum without overflow handling + // Old aggregator without overflow handling if (configuration().clusterHasFeature(EsqlFeatures.FN_SUM_OVERFLOW_HANDLING) == false) { - // return new SumLongAggregatorFunctionSupplier(inputChannels); + return new OverflowingSumLongAggregatorFunctionSupplier(inputChannels); } var location = source().source(); return new SumLongAggregatorFunctionSupplier(location.getLineNumber(), location.getColumnNumber(), source().text(), inputChannels); From fd69d052575f590ce7fc8954dd5e598d57b58ab6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 31 Oct 2024 18:12:34 +0100 Subject: [PATCH 22/31] Added aggregator extra tests, and AggregateMapper updates to allow for any special aggregator --- ...flowingSumLongAggregatorFunctionTests.java | 89 +++++++++++++++++++ ...umLongGroupingAggregatorFunctionTests.java | 75 ++++++++++++++++ .../expression/function/aggregate/Sum.java | 17 +++- .../xpack/esql/planner/AggregateMapper.java | 11 ++- .../esql/planner/ToIntermediateState.java | 28 ++++++ 5 files changed, 216 insertions(+), 4 deletions(-) create mode 100644 x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionTests.java create mode 100644 x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunctionTests.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/ToIntermediateState.java diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionTests.java new file mode 100644 index 0000000000000..603bea194891a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionTests.java @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.CannedSourceOperator; +import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.PageConsumerOperator; +import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.operator.TestResultPageSinkOperator; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.LongStream; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; + +public class OverflowingSumLongAggregatorFunctionTests extends AggregatorFunctionTestCase { + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + long max = randomLongBetween(1, Long.MAX_VALUE / size); + return new SequenceLongBlockSourceOperator(blockFactory, LongStream.range(0, size).map(l -> randomLongBetween(-max, max))); + } + + @Override + protected AggregatorFunctionSupplier aggregatorFunction(List inputChannels) { + return new OverflowingSumLongAggregatorFunctionSupplier(inputChannels); + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "overflowing_sum of longs"; + } + + @Override + public void assertSimpleOutput(List input, Block result) { + long sum = input.stream().flatMapToLong(b -> allLongs(b)).sum(); + assertThat(((LongBlock) result).getLong(0), equalTo(sum)); + } + + public void testOverflowFails() { + DriverContext driverContext = driverContext(); + + assertThrows(ArithmeticException.class, () -> { + try ( + Driver d = new Driver( + driverContext, + new SequenceLongBlockSourceOperator(driverContext.blockFactory(), LongStream.of(Long.MAX_VALUE - 1, 2)), + List.of(simple().get(driverContext)), + new TestResultPageSinkOperator(r -> {}), + () -> {} + ) + ) { + runDriver(d); + } + }); + + assertDriverContext(driverContext); + } + + public void testRejectsDouble() { + DriverContext driverContext = driverContext(); + BlockFactory blockFactory = driverContext.blockFactory(); + try ( + Driver d = new Driver( + driverContext, + new CannedSourceOperator(Iterators.single(new Page(blockFactory.newDoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), + List.of(simple().get(driverContext)), + new PageConsumerOperator(page -> fail("shouldn't have made it this far")), + () -> {} + ) + ) { + expectThrows(Exception.class, () -> runDriver(d)); // ### find a more specific exception type + } + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunctionTests.java new file mode 100644 index 0000000000000..5b0c0f87c3f65 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunctionTests.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.ConstantBooleanExpressionEvaluator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockTestUtils; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.TestBlockFactory; +import org.elasticsearch.compute.operator.CannedSourceOperator; +import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.operator.TestResultPageSinkOperator; +import org.elasticsearch.compute.operator.TupleBlockSourceOperator; +import org.elasticsearch.core.Tuple; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.LongStream; + +import static org.hamcrest.Matchers.equalTo; + +public class OverflowingSumLongGroupingAggregatorFunctionTests extends GroupingAggregatorFunctionTestCase { + @Override + protected AggregatorFunctionSupplier aggregatorFunction(List inputChannels) { + return new OverflowingSumLongAggregatorFunctionSupplier(inputChannels); + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "overflowing_sum of longs"; + } + + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + long max = randomLongBetween(1, Long.MAX_VALUE / size / 5); + return new TupleBlockSourceOperator( + blockFactory, + LongStream.range(0, size).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), randomLongBetween(-max, max))) + ); + } + + @Override + public void assertSimpleGroup(List input, Block result, int position, Long group) { + long sum = input.stream().flatMapToLong(p -> allLongs(p, group)).sum(); + assertThat(((LongBlock) result).getLong(position), equalTo(sum)); + } + + public void testOverflowFails() { + DriverContext driverContext = driverContext(); + + assertThrows(ArithmeticException.class, () -> { + Operator.OperatorFactory factory = simpleWithMode(AggregatorMode.SINGLE); + List input = CannedSourceOperator.collectPages( + new TupleBlockSourceOperator( + driverContext.blockFactory(), + LongStream.range(0, 10).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), Long.MAX_VALUE - 1)) + ) + ); + drive(factory.get(driverContext), input.iterator(), driverContext); + }); + + assertDriverContext(driverContext); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java index 6be59c1c876e3..5ecba8565e6ea 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java @@ -9,7 +9,10 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.IntermediateStateDesc; +import org.elasticsearch.compute.aggregation.OverflowingSumLongAggregatorFunction; import org.elasticsearch.compute.aggregation.OverflowingSumLongAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.OverflowingSumLongGroupingAggregatorFunction; import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumIntAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier; @@ -27,6 +30,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; import org.elasticsearch.xpack.esql.planner.ToAggregator; +import org.elasticsearch.xpack.esql.planner.ToIntermediateState; import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.session.Configuration; @@ -43,7 +47,7 @@ /** * Sum all values of a field in matching documents. */ -public class Sum extends ConfigurationAggregateFunction implements ToAggregator, SurrogateExpression { +public class Sum extends ConfigurationAggregateFunction implements ToAggregator, ToIntermediateState, SurrogateExpression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Sum", Sum::new); @FunctionInfo( @@ -133,6 +137,17 @@ public final AggregatorFunctionSupplier supplier(List inputChannels) { throw EsqlIllegalArgumentException.illegalDataType(type); } + @Override + public List intermediateState(boolean grouping) { + DataType type = field().dataType(); + if (type == DataType.LONG && configuration().clusterHasFeature(EsqlFeatures.FN_SUM_OVERFLOW_HANDLING) == false) { + return grouping + ? OverflowingSumLongGroupingAggregatorFunction.intermediateStateDesc() + : OverflowingSumLongAggregatorFunction.intermediateStateDesc(); + } + return null; + } + @Override public Expression surrogate() { var s = source(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java index b09247d415f1b..fef9953ad2059 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java @@ -131,11 +131,16 @@ private Stream map(Expression aggregate, boolean grouping) { } private static List computeEntryForAgg(Expression aggregate, boolean grouping) { + if (aggregate instanceof ToIntermediateState intermediateStateGetter) { + var intermediateStates = intermediateStateGetter.intermediateState(grouping); + if (intermediateStates != null) { + return isToNE(intermediateStates).toList(); + } + } var aggDef = aggDefOrNull(aggregate, grouping); if (aggDef != null) { - var is = getNonNull(aggDef); - var exp = isToNE(is).toList(); - return exp; + var intermediateStates = getNonNull(aggDef); + return isToNE(intermediateStates).toList(); } if (aggregate instanceof FieldAttribute || aggregate instanceof MetadataAttribute || aggregate instanceof ReferenceAttribute) { // This condition is a little pedantic, but do we expected other expressions here? if so, then add them diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/ToIntermediateState.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/ToIntermediateState.java new file mode 100644 index 0000000000000..fe0a1d8932f80 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/ToIntermediateState.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.planner; + +import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.IntermediateStateDesc; + +import java.util.List; + +/** + * Expressions that have a mapping to {@link org.elasticsearch.compute.aggregation.IntermediateStateDesc}s. + */ +public interface ToIntermediateState { + /** + * Returns the intermediate state descriptions for this expression. + *

+ * If null, the default method of {@link AggregateMapper} will be used to get them. + *

+ */ + default List intermediateState(boolean grouping) { + return null; + } +} From 5b525f41b2b23f709abe81dcb5d64301c6a22b0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 31 Oct 2024 18:15:07 +0100 Subject: [PATCH 23/31] Format --- .../aggregation/OverflowingSumLongAggregator.java | 6 +----- .../OverflowingSumLongAggregatorFunctionTests.java | 3 --- ...owingSumLongGroupingAggregatorFunctionTests.java | 7 ------- .../xpack/esql/ConfigurationTestUtils.java | 4 +--- .../org/elasticsearch/xpack/esql/EsqlTestUtils.java | 7 +------ .../expression/function/EsqlFunctionRegistry.java | 2 ++ .../esql/expression/function/aggregate/Avg.java | 6 +++++- .../aggregate/ConfigurationAggregateFunction.java | 13 +++++++------ .../esql/expression/function/aggregate/Sum.java | 7 ++++++- .../xpack/esql/planner/ToIntermediateState.java | 1 - .../xpack/esql/plugin/TransportEsqlQueryAction.java | 2 -- .../xpack/esql/session/Configuration.java | 3 ++- .../function/aggregate/AvgSerializationTests.java | 6 +++++- .../function/aggregate/SumSerializationTests.java | 6 +++++- .../aggregate/WeightedAvgSerializationTests.java | 3 ++- .../AbstractConfigurationFunctionTestCase.java | 4 +--- .../function/scalar/string/ToLowerTests.java | 3 --- .../xpack/esql/planner/EvalMapperTests.java | 4 +--- .../esql/planner/LocalExecutionPlannerTests.java | 4 +--- .../session/ConfigurationSerializationTests.java | 4 +--- 20 files changed, 41 insertions(+), 54 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java index 6444be2cc2c52..8253f463694d3 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java @@ -20,11 +20,7 @@ * Can't be removed, as the new aggregator's layout is different. *

*/ -@Aggregator( - value = { - @IntermediateState(name = "sum", type = "LONG"), - @IntermediateState(name = "seen", type = "BOOLEAN") } -) +@Aggregator(value = { @IntermediateState(name = "sum", type = "LONG"), @IntermediateState(name = "seen", type = "BOOLEAN") }) @GroupingAggregator class OverflowingSumLongAggregator { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionTests.java index 603bea194891a..79458fc915ce7 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionTests.java @@ -20,12 +20,9 @@ import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.TestResultPageSinkOperator; -import java.util.ArrayList; import java.util.List; import java.util.stream.LongStream; -import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; public class OverflowingSumLongAggregatorFunctionTests extends AggregatorFunctionTestCase { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunctionTests.java index 5b0c0f87c3f65..24345b35266ff 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunctionTests.java @@ -7,24 +7,17 @@ package org.elasticsearch.compute.aggregation; -import org.elasticsearch.compute.ConstantBooleanExpressionEvaluator; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.BlockTestUtils; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.data.TestBlockFactory; import org.elasticsearch.compute.operator.CannedSourceOperator; -import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; -import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; -import org.elasticsearch.compute.operator.TestResultPageSinkOperator; import org.elasticsearch.compute.operator.TupleBlockSourceOperator; import org.elasticsearch.core.Tuple; -import java.util.ArrayList; import java.util.List; import java.util.stream.LongStream; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/ConfigurationTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/ConfigurationTestUtils.java index 0d054e6100309..8d916ebfee01a 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/ConfigurationTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/ConfigurationTestUtils.java @@ -75,9 +75,7 @@ public static Configuration randomConfiguration(String query, Map { private static BiFunction uni(BiFunction function) { return function; } + private static UnaryConfigurationAwareBuilder uniConfig(UnaryConfigurationAwareBuilder function) { return function; } @@ -942,6 +943,7 @@ private static UnaryConfigurationAwareBuilder uniConfig( private static BinaryBuilder bi(BinaryBuilder function) { return function; } + private static BinaryConfigurationAwareBuilder biConfig(BinaryConfigurationAwareBuilder function) { return function; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java index 2e4234ce49244..7c64431c2a8f4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java @@ -46,7 +46,11 @@ public class Avg extends ConfigurationAggregateFunction implements SurrogateExpr tag = "docsStatsAvgNestedExpression" ) } ) - public Avg(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field, Configuration configuration) { + public Avg( + Source source, + @Param(name = "number", type = { "double", "integer", "long" }) Expression field, + Configuration configuration + ) { this(source, field, Literal.TRUE, configuration); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ConfigurationAggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ConfigurationAggregateFunction.java index 117e2dc4e20fc..1693cb61ffc6b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ConfigurationAggregateFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ConfigurationAggregateFunction.java @@ -7,15 +7,10 @@ package org.elasticsearch.xpack.esql.expression.function.aggregate; -import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.compute.data.BlockStreamInput; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; -import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.session.Configuration; import java.io.IOException; @@ -31,7 +26,13 @@ public abstract class ConfigurationAggregateFunction extends AggregateFunction { this.configuration = configuration; } - ConfigurationAggregateFunction(Source source, Expression field, Expression filter, List parameters, Configuration configuration) { + ConfigurationAggregateFunction( + Source source, + Expression field, + Expression filter, + List parameters, + Configuration configuration + ) { super(source, field, filter, parameters); this.configuration = configuration; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java index 5ecba8565e6ea..73e919373ef47 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java @@ -126,7 +126,12 @@ public final AggregatorFunctionSupplier supplier(List inputChannels) { return new OverflowingSumLongAggregatorFunctionSupplier(inputChannels); } var location = source().source(); - return new SumLongAggregatorFunctionSupplier(location.getLineNumber(), location.getColumnNumber(), source().text(), inputChannels); + return new SumLongAggregatorFunctionSupplier( + location.getLineNumber(), + location.getColumnNumber(), + source().text(), + inputChannels + ); } if (type == DataType.INTEGER) { return new SumIntAggregatorFunctionSupplier(inputChannels); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/ToIntermediateState.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/ToIntermediateState.java index fe0a1d8932f80..5591ed439f386 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/ToIntermediateState.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/ToIntermediateState.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.planner; -import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.IntermediateStateDesc; import java.util.List; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java index a28d70d5e2a1f..85727935a348d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java @@ -20,7 +20,6 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.operator.exchange.ExchangeService; import org.elasticsearch.features.FeatureService; -import org.elasticsearch.features.NodeFeature; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.search.SearchService; import org.elasticsearch.tasks.CancellableTask; @@ -52,7 +51,6 @@ import java.util.Locale; import java.util.Map; import java.util.concurrent.Executor; -import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ClientHelper.ASYNC_SEARCH_ORIGIN; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java index 40b5a5f49785c..4f83794257510 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java @@ -124,7 +124,8 @@ public Configuration(BlockStreamInput in) throws IOException { } public static Set calculateActiveClusterFeatures(FeatureService featureService, ClusterService clusterService) { - return new EsqlFeatures().getFeatures().stream() + return new EsqlFeatures().getFeatures() + .stream() .filter(f -> featureService.clusterHasFeature(clusterService.state(), f)) .map(NodeFeature::id) .collect(Collectors.toSet()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java index 2b8c34c258229..d010930177881 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java @@ -19,7 +19,11 @@ protected Avg createTestInstance() { @Override protected Avg mutateInstance(Avg instance) throws IOException { - return new Avg(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild), instance.configuration()); + return new Avg( + instance.source(), + randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild), + instance.configuration() + ); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java index 5440985895936..5b50c748f9a4a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java @@ -19,7 +19,11 @@ protected Sum createTestInstance() { @Override protected Sum mutateInstance(Sum instance) throws IOException { - return new Sum(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild), instance.configuration()); + return new Sum( + instance.source(), + randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild), + instance.configuration() + ); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgSerializationTests.java index a96e8429e2da0..598bf697c6f33 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgSerializationTests.java @@ -23,7 +23,8 @@ protected WeightedAvg mutateInstance(WeightedAvg instance) throws IOException { instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild), randomValueOtherThan(instance.weight(), AbstractExpressionSerializationTests::randomChild), - instance.configuration()); + instance.configuration() + ); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractConfigurationFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractConfigurationFunctionTestCase.java index cecad24bb0980..852fe0daae111 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractConfigurationFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractConfigurationFunctionTestCase.java @@ -47,9 +47,7 @@ static Configuration randomConfiguration() { randomBoolean(), Map.of(), System.nanoTime(), - new EsqlFeatures().getFeatures().stream() - .map(NodeFeature::id) - .collect(Collectors.toSet()) + new EsqlFeatures().getFeatures().stream().map(NodeFeature::id).collect(Collectors.toSet()) ); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java index 69993a89baa49..400564134d280 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java @@ -13,7 +13,6 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.features.NodeFeature; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; @@ -22,7 +21,6 @@ import org.elasticsearch.xpack.esql.core.util.DateUtils; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.scalar.AbstractConfigurationFunctionTestCase; -import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.session.Configuration; @@ -32,7 +30,6 @@ import java.util.Map; import java.util.Set; import java.util.function.Supplier; -import java.util.stream.Collectors; import static org.hamcrest.Matchers.equalTo; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java index 03f9287d94a1a..896ee3cad08c9 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java @@ -80,9 +80,7 @@ public class EvalMapperTests extends ESTestCase { false, Map.of(), System.nanoTime(), - new EsqlFeatures().getFeatures().stream() - .map(NodeFeature::id) - .collect(Collectors.toSet()) + new EsqlFeatures().getFeatures().stream().map(NodeFeature::id).collect(Collectors.toSet()) ); @ParametersFactory(argumentFormatting = "%1$s") diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java index 11ed43fdfa854..65c31f462afb0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java @@ -166,9 +166,7 @@ private Configuration config() { false, Map.of(), System.nanoTime(), - new EsqlFeatures().getFeatures().stream() - .map(NodeFeature::id) - .collect(Collectors.toSet()) + new EsqlFeatures().getFeatures().stream().map(NodeFeature::id).collect(Collectors.toSet()) ); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/ConfigurationSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/ConfigurationSerializationTests.java index 15c19defee4d0..a72b15faa80cd 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/ConfigurationSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/ConfigurationSerializationTests.java @@ -107,9 +107,7 @@ protected Configuration mutateInstance(Configuration in) { profile, tables, System.nanoTime(), - new EsqlFeatures().getFeatures().stream() - .map(NodeFeature::id) - .collect(Collectors.toSet()) + new EsqlFeatures().getFeatures().stream().map(NodeFeature::id).collect(Collectors.toSet()) ); } From bde36ddc961ae44d6e0b625127a1c14584ef137b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 4 Nov 2024 13:18:00 +0100 Subject: [PATCH 24/31] Added required capability to CSV test on overflows --- .../src/main/resources/stats_sum.csv-spec | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_sum.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_sum.csv-spec index 1bdf63c119554..db68733883608 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_sum.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_sum.csv-spec @@ -1,4 +1,5 @@ sumWithOverflow +required_capability: fn_sum_overflow_handling FROM employees | LIMIT 2 | EVAL x = languages - languages + 9223372036854775807 @@ -12,3 +13,23 @@ warning:Line 4:40: java.lang.ArithmeticException: long overflow overflow:long | underflow:long null | null ; + +sumWithOverflowGrouping +required_capability: fn_sum_overflow_handling +FROM employees +| SORT emp_no +| LIMIT 4 +| EVAL group = emp_no % 2, x = languages - languages + 9223372036854775807 +| STATS overflow = SUM(x), underflow = SUM(-x) by group +| SORT group +; +warning:Line 5:20: evaluation of [SUM(x)] failed, treating result as null. Only first 20 failures recorded. +warning:Line 5:20: java.lang.ArithmeticException: long overflow +warning:Line 5:40: evaluation of [SUM(-x)] failed, treating result as null. Only first 20 failures recorded. +warning:Line 5:40: java.lang.ArithmeticException: long overflow + +overflow:long | underflow:long | group:integer +null | null | 0 +null | null | 1 +; + From 0978eeb9f96b6f95ca5652308b7b6f1566b3b15f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 4 Nov 2024 13:41:54 +0100 Subject: [PATCH 25/31] Update docs/changelog/116170.yaml --- docs/changelog/116170.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 docs/changelog/116170.yaml diff --git a/docs/changelog/116170.yaml b/docs/changelog/116170.yaml new file mode 100644 index 0000000000000..7190fb70e9522 --- /dev/null +++ b/docs/changelog/116170.yaml @@ -0,0 +1,6 @@ +pr: 116170 +summary: "ESQL: Prevent overflow on SUM using multiple aggregators" +area: ES|QL +type: bug +issues: + - 110443 From 423549de6e517a964b06d885649632695eac458a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 4 Nov 2024 13:45:53 +0100 Subject: [PATCH 26/31] Remove old branch changelog entry --- docs/changelog/111639.yaml | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 docs/changelog/111639.yaml diff --git a/docs/changelog/111639.yaml b/docs/changelog/111639.yaml deleted file mode 100644 index 5352d84851a36..0000000000000 --- a/docs/changelog/111639.yaml +++ /dev/null @@ -1,7 +0,0 @@ -pr: 111639 -summary: "ESQL: Add warnings capabilities to aggregators, and prevent overflow on\ - \ SUM aggregation" -area: ES|QL -type: bug -issues: - - 110443 From fb98bec0557a5b13f7011fe9b1882910aa4bcede Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 4 Nov 2024 14:18:05 +0100 Subject: [PATCH 27/31] Add "final" again to AggregateFunction#writeTo --- .../esql/expression/function/aggregate/AggregateFunction.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java index 5cea17ad6cb4e..f7a74cc2ae93f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java @@ -85,7 +85,7 @@ protected AggregateFunction(StreamInput in) throws IOException { } @Override - public void writeTo(StreamOutput out) throws IOException { + public final void writeTo(StreamOutput out) throws IOException { Source.EMPTY.writeTo(out); out.writeNamedWriteable(field); if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER)) { From abcd246b88ae04f963ca2f1b89904a5ef002736e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 4 Nov 2024 14:52:01 +0100 Subject: [PATCH 28/31] Fixed serialization of functions with configuration --- .../xpack/esql/session/Configuration.java | 8 ++- ...tractConfigurationAggregationTestCase.java | 67 +++++++++++++++++++ .../function/AbstractFunctionTestCase.java | 11 --- .../function/aggregate/AvgTests.java | 8 ++- .../function/aggregate/SumTests.java | 8 ++- .../function/aggregate/WeightedAvgTests.java | 8 ++- 6 files changed, 88 insertions(+), 22 deletions(-) create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractConfigurationAggregationTestCase.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java index 4f83794257510..eb145b42a9ebd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java @@ -266,7 +266,8 @@ public boolean equals(Object o) { && Objects.equals(locale, that.locale) && Objects.equals(that.query, query) && profile == that.profile - && tables.equals(that.tables); + && tables.equals(that.tables) + && activeEsqlFeatures.equals(that.activeEsqlFeatures); } @Override @@ -282,7 +283,8 @@ public int hashCode() { locale, query, profile, - tables + tables, + activeEsqlFeatures ); } @@ -304,6 +306,8 @@ public String toString() { + profile + ", tables=" + tables + + ", activeEsqlFeatures=" + + activeEsqlFeatures + '}'; } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractConfigurationAggregationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractConfigurationAggregationTestCase.java new file mode 100644 index 0000000000000..4074cbd993c59 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractConfigurationAggregationTestCase.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.util.StringUtils; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; +import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; +import org.elasticsearch.xpack.esql.plugin.QueryPragmas; +import org.elasticsearch.xpack.esql.session.Configuration; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization; + +public abstract class AbstractConfigurationAggregationTestCase extends AbstractAggregationTestCase { + protected abstract Expression buildWithConfiguration(Source source, List args, Configuration configuration); + + @Override + protected Expression build(Source source, List args) { + return buildWithConfiguration(source, args, EsqlTestUtils.TEST_CFG); + } + + static Configuration randomConfiguration() { + // TODO: Randomize the query and maybe the pragmas. + return new Configuration( + randomZone(), + randomLocale(random()), + randomBoolean() ? null : randomAlphaOfLength(randomInt(64)), + randomBoolean() ? null : randomAlphaOfLength(randomInt(64)), + QueryPragmas.EMPTY, + EsqlPlugin.QUERY_RESULT_TRUNCATION_MAX_SIZE.getDefault(Settings.EMPTY), + EsqlPlugin.QUERY_RESULT_TRUNCATION_DEFAULT_SIZE.getDefault(Settings.EMPTY), + StringUtils.EMPTY, + randomBoolean(), + Map.of(), + System.nanoTime(), + new EsqlFeatures().getFeatures().stream().map(NodeFeature::id).collect(Collectors.toSet()) + ); + } + + public void testSerializationWithConfiguration() { + Configuration config = randomConfiguration(); + Expression expr = buildWithConfiguration(testCase.getSource(), testCase.getDataAsFields(), config); + + assertSerialization(expr, config); + + Configuration differentConfig; + do { + differentConfig = randomConfiguration(); + } while (config.equals(differentConfig)); + + Expression differentExpr = buildWithConfiguration(testCase.getSource(), testCase.getDataAsFields(), differentConfig); + assertFalse(expr.equals(differentExpr)); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java index 2b1c6a1c7d544..b6bde582115df 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java @@ -136,8 +136,6 @@ public abstract class AbstractFunctionTestCase extends ESTestCase { private static EsqlFunctionRegistry functionRegistry = new EsqlFunctionRegistry().snapshotRegistry(); - private Configuration config; - protected TestCaseSupplier.TestCase testCase; /** @@ -1353,13 +1351,4 @@ private static boolean shouldHideSignature(List argTypes, DataType ret } return false; } - - @Before - public void initConfig() { - config = randomConfiguration("FROM test", Map.of()); - } - - protected Configuration configuration() { - return config; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java index 679b492cd65f2..a13401c351ddb 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java @@ -14,8 +14,10 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.AbstractConfigurationAggregationTestCase; import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.elasticsearch.xpack.esql.session.Configuration; import java.util.ArrayList; import java.util.List; @@ -25,7 +27,7 @@ import static org.hamcrest.Matchers.equalTo; -public class AvgTests extends AbstractAggregationTestCase { +public class AvgTests extends AbstractConfigurationAggregationTestCase { public AvgTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @@ -57,8 +59,8 @@ public static Iterable parameters() { } @Override - protected Expression build(Source source, List args) { - return new Avg(source, args.get(0), configuration()); + protected Expression buildWithConfiguration(Source source, List args, Configuration configuration) { + return new Avg(source, args.get(0), configuration); } private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java index 79059062cdf1c..b28e697b0256c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java @@ -15,8 +15,10 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.AbstractConfigurationAggregationTestCase; import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.elasticsearch.xpack.esql.session.Configuration; import java.util.ArrayList; import java.util.List; @@ -27,7 +29,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.UNSIGNED_LONG; import static org.hamcrest.Matchers.equalTo; -public class SumTests extends AbstractAggregationTestCase { +public class SumTests extends AbstractConfigurationAggregationTestCase { public SumTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @@ -81,8 +83,8 @@ public static Iterable parameters() { } @Override - protected Expression build(Source source, List args) { - return new Sum(source, args.get(0), configuration()); + protected Expression buildWithConfiguration(Source source, List args, Configuration configuration) { + return new Sum(source, args.get(0), configuration); } private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java index 2c967da7c251d..4cb81c2ddf61f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java @@ -14,8 +14,10 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.AbstractConfigurationAggregationTestCase; import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.elasticsearch.xpack.esql.session.Configuration; import java.util.ArrayList; import java.util.List; @@ -25,7 +27,7 @@ import static org.hamcrest.Matchers.equalTo; -public class WeightedAvgTests extends AbstractAggregationTestCase { +public class WeightedAvgTests extends AbstractConfigurationAggregationTestCase { public WeightedAvgTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @@ -94,8 +96,8 @@ public static Iterable parameters() { } @Override - protected Expression build(Source source, List args) { - return new WeightedAvg(source, args.get(0), args.get(1), configuration()); + protected Expression buildWithConfiguration(Source source, List args, Configuration configuration) { + return new WeightedAvg(source, args.get(0), args.get(1), configuration); } private static TestCaseSupplier makeSupplier( From b9d66b4faba499fe52557a13025fbaa1e9d9ff00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 4 Nov 2024 14:57:46 +0100 Subject: [PATCH 29/31] Format --- .../esql/expression/function/AbstractFunctionTestCase.java | 2 -- .../xpack/esql/expression/function/aggregate/AvgTests.java | 1 - .../xpack/esql/expression/function/aggregate/SumTests.java | 1 - .../esql/expression/function/aggregate/WeightedAvgTests.java | 1 - 4 files changed, 5 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java index b6bde582115df..c05f8e0990b3c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java @@ -69,7 +69,6 @@ import org.hamcrest.Matcher; import org.junit.After; import org.junit.AfterClass; -import org.junit.Before; import java.io.IOException; import java.lang.reflect.Constructor; @@ -95,7 +94,6 @@ import static java.util.Map.entry; import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; -import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration; import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral; import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization; import static org.hamcrest.Matchers.either; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java index a13401c351ddb..b618066a45726 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; import org.elasticsearch.xpack.esql.expression.function.AbstractConfigurationAggregationTestCase; import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java index b28e697b0256c..8e84f12b1d14c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; import org.elasticsearch.xpack.esql.expression.function.AbstractConfigurationAggregationTestCase; import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java index 4cb81c2ddf61f..7a30d6f499ee1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; import org.elasticsearch.xpack.esql.expression.function.AbstractConfigurationAggregationTestCase; import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; From 134a0d4f1829c0e9eebcf3794fa8a0f2507111a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 4 Nov 2024 17:43:03 +0100 Subject: [PATCH 30/31] Fixed PhysicalOptimizerTests not using same config for everything --- .../xpack/esql/session/Configuration.java | 4 ++ .../optimizer/PhysicalPlanOptimizerTests.java | 4 +- .../plan/AbstractNodeSerializationTests.java | 4 +- .../logical/AggregateSerializationTests.java | 2 +- .../esql/plugin/DataNodeRequestTests.java | 46 +++++++++++++++---- 5 files changed, 46 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java index eb145b42a9ebd..94160c08c4ff6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java @@ -224,6 +224,10 @@ public boolean profile() { return profile; } + public Set activeEsqlFeatures() { + return activeEsqlFeatures; + } + public boolean clusterHasFeature(NodeFeature feature) { return activeEsqlFeatures.contains(feature.id()); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java index 3b59a1d176a98..25e592ae89908 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java @@ -217,7 +217,7 @@ public PhysicalPlanOptimizerTests(String name, Configuration config) { @Before public void init() { parser = new EsqlParser(); - logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG)); + logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(config)); physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(config)); EsqlFunctionRegistry functionRegistry = new EsqlFunctionRegistry(); mapper = new Mapper(); @@ -6588,7 +6588,7 @@ private PhysicalPlan physicalPlan(String query, TestDataSource dataSource) { // System.out.println("Logical\n" + logical); var physical = mapper.map(logical); // System.out.println(physical); - assertSerialization(physical); + assertSerialization(physical, config); return physical; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/AbstractNodeSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/AbstractNodeSerializationTests.java index e6faa9a253d76..f83df6440c5cf 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/AbstractNodeSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/AbstractNodeSerializationTests.java @@ -35,7 +35,7 @@ public abstract class AbstractNodeSerializationTests> * We use a single random config for all serialization because it's pretty * heavy to build, especially in {@link #testConcurrentSerialization()}. */ - private Configuration config; + private static Configuration config; public static Source randomSource() { int lineNumber = between(0, EXAMPLE_QUERY.length - 1); @@ -77,7 +77,7 @@ protected boolean alwaysEmptySource() { return false; } - public final Configuration configuration() { + public static final Configuration configuration() { return config; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java index 9f998f1735c16..8d2d3f5699256 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java @@ -52,7 +52,7 @@ public static List randomAggregates() { new Literal(randomSource(), randomFrom("ASC", "DESC"), DataType.KEYWORD) ); case 4 -> new Values(randomSource(), FieldAttributeTests.createFieldAttribute(1, true)); - case 5 -> new Sum(randomSource(), FieldAttributeTests.createFieldAttribute(1, true), EsqlTestUtils.TEST_CFG); + case 5 -> new Sum(randomSource(), FieldAttributeTests.createFieldAttribute(1, true), configuration()); default -> throw new IllegalArgumentException(); }; result.add(new Alias(randomSource(), randomAlphaOfLength(5), agg)); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java index 4553551c40cd3..95b066136a838 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.query.TermQueryBuilder; @@ -33,15 +34,16 @@ import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.planner.mapper.Mapper; +import org.elasticsearch.xpack.esql.session.Configuration; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration; import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG; import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER; import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution; import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping; @@ -79,14 +81,15 @@ protected DataNodeRequest createTestInstance() { | stats x = avg(salary) """); List shardIds = randomList(1, 10, () -> new ShardId("index-" + between(1, 10), "n/a", between(1, 10))); - PhysicalPlan physicalPlan = mapAndMaybeOptimize(parse(query)); + Configuration configuration = randomLimitedConfiguration(query); + PhysicalPlan physicalPlan = mapAndMaybeOptimize(parse(query, configuration), configuration); Map aliasFilters = Map.of( new Index("concrete-index", "n/a"), AliasFilter.of(new TermQueryBuilder("id", "1"), "alias-1") ); DataNodeRequest request = new DataNodeRequest( sessionId, - randomConfiguration(query, randomTables()), + configuration, randomAlphaOfLength(10), shardIds, aliasFilters, @@ -164,7 +167,7 @@ protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException in.clusterAlias(), in.shardIds(), in.aliasFilters(), - mapAndMaybeOptimize(parse(newQuery)), + mapAndMaybeOptimize(parse(newQuery, in.configuration()), in.configuration()), in.indices(), in.indicesOptions() ); @@ -260,20 +263,20 @@ protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException }; } - static LogicalPlan parse(String query) { + static LogicalPlan parse(String query, Configuration configuration) { Map mapping = loadMapping("mapping-basic.json"); EsIndex test = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD)); IndexResolution getIndexResult = IndexResolution.valid(test); - var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(TEST_CFG)); + var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration)); var analyzer = new Analyzer( - new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()), + new AnalyzerContext(configuration, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()), TEST_VERIFIER ); return logicalOptimizer.optimize(analyzer.analyze(new EsqlParser().createStatement(query))); } - static PhysicalPlan mapAndMaybeOptimize(LogicalPlan logicalPlan) { - var physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(TEST_CFG)); + static PhysicalPlan mapAndMaybeOptimize(LogicalPlan logicalPlan, Configuration configuration) { + var physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(configuration)); var mapper = new Mapper(); var physical = mapper.map(logicalPlan); if (randomBoolean()) { @@ -286,4 +289,29 @@ static PhysicalPlan mapAndMaybeOptimize(LogicalPlan logicalPlan) { protected List filteredWarnings() { return withDefaultLimitWarning(super.filteredWarnings()); } + + /** + * Builds a configuration with a fixed limit of 10000 rows and 1000 bytes. + *

+ * Without this, warnings would be randomized, and hard to filter from the warnings. + *

+ */ + private Configuration randomLimitedConfiguration(String query) { + Configuration configuration = randomConfiguration(query, randomTables()); + + return new Configuration( + configuration.zoneId(), + configuration.locale(), + configuration.username(), + configuration.clusterName(), + configuration.pragmas(), + 10000, + 1000, + configuration.query(), + configuration.profile(), + configuration.tables(), + configuration.getQueryStartTimeNanos(), + configuration.activeEsqlFeatures() + ); + } } From 20d9adf85edfa4138bcbb5d6167ad0f5cc50bfbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 4 Nov 2024 17:50:42 +0100 Subject: [PATCH 31/31] Format --- .../compute/aggregation/OverflowingSumLongAggregator.java | 2 +- .../xpack/esql/plan/logical/AggregateSerializationTests.java | 1 - .../elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java | 3 --- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java index 8253f463694d3..b3956508d813b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java @@ -17,7 +17,7 @@ * Replaced by {@link org.elasticsearch.compute.aggregation.SumLongAggregator} since {@code EsqlFeatures#FN_SUM_OVERFLOW_HANDLING}. *

*

- * Can't be removed, as the new aggregator's layout is different. + * Should be kept for as long as we need compatibility with the version this was added on, as the new aggregator's layout is different. *

*/ @Aggregator(value = { @IntermediateState(name = "sum", type = "LONG"), @IntermediateState(name = "seen", type = "BOOLEAN") }) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java index 8d2d3f5699256..34d32477ab610 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.plan.logical; -import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java index 95b066136a838..19be4581ded73 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.query.TermQueryBuilder; @@ -19,7 +18,6 @@ import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.analysis.Analyzer; import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.core.type.EsField; @@ -40,7 +38,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration; import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables;