From 7594d5a18b7a0ef7f8f97f9ef24a474d931db9c9 Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Thu, 17 Oct 2024 08:59:35 -1000 Subject: [PATCH 01/20] Add a standard deviation function Uses Welford's online algorithm, as well as the parallel version, to calculate standard deviation. --- .../functions/aggregation-functions.asciidoc | 2 + .../description/std_deviation.asciidoc | 5 + .../functions/examples/std_deviation.asciidoc | 22 ++ .../kibana/definition/std_deviation.json | 50 ++++ .../functions/kibana/docs/std_deviation.md | 11 + .../functions/layout/std_deviation.asciidoc | 15 ++ .../parameters/std_deviation.asciidoc | 6 + .../functions/signature/std_deviation.svg | 1 + .../functions/types/std_deviation.asciidoc | 11 + x-pack/plugin/esql/compute/build.gradle | 21 ++ .../StdDeviationDoubleAggregator.java | 234 +++++++++++++++++ .../StdDeviationFloatAggregator.java | 234 +++++++++++++++++ .../StdDeviationIntAggregator.java | 234 +++++++++++++++++ .../StdDeviationLongAggregator.java | 234 +++++++++++++++++ .../StdDeviationDoubleAggregatorFunction.java | 178 +++++++++++++ ...ationDoubleAggregatorFunctionSupplier.java | 39 +++ ...ationDoubleGroupingAggregatorFunction.java | 224 ++++++++++++++++ .../StdDeviationFloatAggregatorFunction.java | 180 +++++++++++++ ...iationFloatAggregatorFunctionSupplier.java | 39 +++ ...iationFloatGroupingAggregatorFunction.java | 226 ++++++++++++++++ .../StdDeviationIntAggregatorFunction.java | 180 +++++++++++++ ...eviationIntAggregatorFunctionSupplier.java | 38 +++ ...eviationIntGroupingAggregatorFunction.java | 223 ++++++++++++++++ .../StdDeviationLongAggregatorFunction.java | 178 +++++++++++++ ...viationLongAggregatorFunctionSupplier.java | 39 +++ ...viationLongGroupingAggregatorFunction.java | 223 ++++++++++++++++ .../compute/aggregation/WelfordAlgorithm.java | 79 ++++++ .../X-StdDeviationAggregator.java.st | 241 ++++++++++++++++++ .../src/main/resources/stats.csv-spec | 130 ++++++++++ .../xpack/esql/action/EsqlCapabilities.java | 5 + .../function/EsqlFunctionRegistry.java | 2 + .../function/aggregate/AggregateFunction.java | 1 + .../function/aggregate/StdDeviation.java | 103 ++++++++ .../xpack/esql/planner/AggregateMapper.java | 4 +- .../function/aggregate/StdDeviationTests.java | 89 +++++++ 35 files changed, 3500 insertions(+), 1 deletion(-) create mode 100644 docs/reference/esql/functions/description/std_deviation.asciidoc create mode 100644 docs/reference/esql/functions/examples/std_deviation.asciidoc create mode 100644 docs/reference/esql/functions/kibana/definition/std_deviation.json create mode 100644 docs/reference/esql/functions/kibana/docs/std_deviation.md create mode 100644 docs/reference/esql/functions/layout/std_deviation.asciidoc create mode 100644 docs/reference/esql/functions/parameters/std_deviation.asciidoc create mode 100644 docs/reference/esql/functions/signature/std_deviation.svg create mode 100644 docs/reference/esql/functions/types/std_deviation.asciidoc create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunctionSupplier.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunctionSupplier.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunctionSupplier.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunctionSupplier.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDeviation.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDeviationTests.java diff --git a/docs/reference/esql/functions/aggregation-functions.asciidoc b/docs/reference/esql/functions/aggregation-functions.asciidoc index 7cdc42ea6cbf9..7777859e898ce 100644 --- a/docs/reference/esql/functions/aggregation-functions.asciidoc +++ b/docs/reference/esql/functions/aggregation-functions.asciidoc @@ -17,6 +17,7 @@ The <> command supports these aggregate functions: * <> * <> * experimental:[] <> +* <> * <> * <> * <> @@ -32,6 +33,7 @@ include::layout/median_absolute_deviation.asciidoc[] include::layout/min.asciidoc[] include::layout/percentile.asciidoc[] include::layout/st_centroid_agg.asciidoc[] +include::layout/std_deviation.asciidoc[] include::layout/sum.asciidoc[] include::layout/top.asciidoc[] include::layout/values.asciidoc[] diff --git a/docs/reference/esql/functions/description/std_deviation.asciidoc b/docs/reference/esql/functions/description/std_deviation.asciidoc new file mode 100644 index 0000000000000..b78ddd7dbba13 --- /dev/null +++ b/docs/reference/esql/functions/description/std_deviation.asciidoc @@ -0,0 +1,5 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Description* + +The standard deviation of a numeric field. diff --git a/docs/reference/esql/functions/examples/std_deviation.asciidoc b/docs/reference/esql/functions/examples/std_deviation.asciidoc new file mode 100644 index 0000000000000..741f5e886b945 --- /dev/null +++ b/docs/reference/esql/functions/examples/std_deviation.asciidoc @@ -0,0 +1,22 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Examples* + +[source.merge.styled,esql] +---- +include::{esql-specs}/stats.csv-spec[tag=stdev] +---- +[%header.monospaced.styled,format=dsv,separator=|] +|=== +include::{esql-specs}/stats.csv-spec[tag=stdev-result] +|=== +The expression can use inline functions. For example, to calculate the standard deviation of each employee's maximum salary changes, first use `MV_MAX` on each row, and then use `StdDeviation` on the result +[source.merge.styled,esql] +---- +include::{esql-specs}/stats.csv-spec[tag=docsStatsStdDeviationNestedExpression] +---- +[%header.monospaced.styled,format=dsv,separator=|] +|=== +include::{esql-specs}/stats.csv-spec[tag=docsStatsStdDeviationNestedExpression-result] +|=== + diff --git a/docs/reference/esql/functions/kibana/definition/std_deviation.json b/docs/reference/esql/functions/kibana/definition/std_deviation.json new file mode 100644 index 0000000000000..0beb5c8b75ec9 --- /dev/null +++ b/docs/reference/esql/functions/kibana/definition/std_deviation.json @@ -0,0 +1,50 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", + "type" : "agg", + "name" : "std_deviation", + "description" : "The standard deviation of a numeric field.", + "signatures" : [ + { + "params" : [ + { + "name" : "number", + "type" : "double", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "number", + "type" : "integer", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "number", + "type" : "long", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + } + ], + "examples" : [ + "FROM employees\n| STATS STD_DEVIATION(height)", + "FROM employees\n| STATS stdev_salary_change = STD_DEVIATION(MV_MAX(salary_change))" + ], + "preview" : false, + "snapshot_only" : false +} diff --git a/docs/reference/esql/functions/kibana/docs/std_deviation.md b/docs/reference/esql/functions/kibana/docs/std_deviation.md new file mode 100644 index 0000000000000..d3dad54b3c5b4 --- /dev/null +++ b/docs/reference/esql/functions/kibana/docs/std_deviation.md @@ -0,0 +1,11 @@ + + +### STD_DEVIATION +The standard deviation of a numeric field. + +``` +FROM employees +| STATS STD_DEVIATION(height) +``` diff --git a/docs/reference/esql/functions/layout/std_deviation.asciidoc b/docs/reference/esql/functions/layout/std_deviation.asciidoc new file mode 100644 index 0000000000000..93fbcbb87aba6 --- /dev/null +++ b/docs/reference/esql/functions/layout/std_deviation.asciidoc @@ -0,0 +1,15 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +[discrete] +[[esql-std_deviation]] +=== `STD_DEVIATION` + +*Syntax* + +[.text-center] +image::esql/functions/signature/std_deviation.svg[Embedded,opts=inline] + +include::../parameters/std_deviation.asciidoc[] +include::../description/std_deviation.asciidoc[] +include::../types/std_deviation.asciidoc[] +include::../examples/std_deviation.asciidoc[] diff --git a/docs/reference/esql/functions/parameters/std_deviation.asciidoc b/docs/reference/esql/functions/parameters/std_deviation.asciidoc new file mode 100644 index 0000000000000..91c56709d182a --- /dev/null +++ b/docs/reference/esql/functions/parameters/std_deviation.asciidoc @@ -0,0 +1,6 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Parameters* + +`number`:: + diff --git a/docs/reference/esql/functions/signature/std_deviation.svg b/docs/reference/esql/functions/signature/std_deviation.svg new file mode 100644 index 0000000000000..af83594d04871 --- /dev/null +++ b/docs/reference/esql/functions/signature/std_deviation.svg @@ -0,0 +1 @@ +STD_DEVIATION(number) \ No newline at end of file diff --git a/docs/reference/esql/functions/types/std_deviation.asciidoc b/docs/reference/esql/functions/types/std_deviation.asciidoc new file mode 100644 index 0000000000000..273dae4af76c2 --- /dev/null +++ b/docs/reference/esql/functions/types/std_deviation.asciidoc @@ -0,0 +1,11 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Supported types* + +[%header.monospaced.styled,format=dsv,separator=|] +|=== +number | result +double | double +integer | double +long | double +|=== diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index 49e819b7cdc88..cb36a6e65c77f 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -608,6 +608,27 @@ tasks.named('stringTemplates').configure { it.outputFile = "org/elasticsearch/compute/aggregation/RateDoubleAggregator.java" } + File stdDeviationAggregatorInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st") + template { + it.properties = intProperties + it.inputFile = stdDeviationAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java" + } + template { + it.properties = longProperties + it.inputFile = stdDeviationAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java" + } + template { + it.properties = floatProperties + it.inputFile = stdDeviationAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java" + } + template { + it.properties = doubleProperties + it.inputFile = stdDeviationAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java" + } File topAggregatorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st") template { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java new file mode 100644 index 0000000000000..4b0d1e7d79881 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java @@ -0,0 +1,234 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +/** + * A standard deviation aggregation definition for double. + * This class is generated. Edit `X-StdDeviationAggregator.java.st` instead. + */ +@Aggregator( + { + @IntermediateState(name = "mean", type = "DOUBLE"), + @IntermediateState(name = "m2", type = "DOUBLE"), + @IntermediateState(name = "count", type = "LONG") } +) +@GroupingAggregator +public class StdDeviationDoubleAggregator { + + public static StdDeviationDoubleState initSingle() { + return new StdDeviationDoubleState(); + } + + public static void combine(StdDeviationDoubleState state, double value) { + state.add(value); + } + + public static void combineIntermediate(StdDeviationDoubleState state, double mean, double m2, long count) { + state.combine(mean, m2, count); + } + + public static void evaluateIntermediate(StdDeviationDoubleState state, DriverContext driverContext, Block[] blocks, int offset) { + assert blocks.length >= offset + 3; + BlockFactory blockFactory = driverContext.blockFactory(); + blocks[offset + 0] = blockFactory.newConstantDoubleBlockWith(state.mean(), 1); + blocks[offset + 1] = blockFactory.newConstantDoubleBlockWith(state.m2(), 1); + blocks[offset + 2] = blockFactory.newConstantLongBlockWith(state.count(), 1); + } + + public static Block evaluateFinal(StdDeviationDoubleState state, DriverContext driverContext) { + final long count = state.count(); + final double m2 = state.m2(); + if (count == 0 || Double.isFinite(m2) == false) { + return driverContext.blockFactory().newConstantNullBlock(1); + } + return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); + } + + public static GroupingStdDeviationDoubleState initGrouping(BigArrays bigArrays) { + return new GroupingStdDeviationDoubleState(bigArrays); + } + + public static void combine(GroupingStdDeviationDoubleState current, int groupId, double value) { + current.add(groupId, value); + } + + public static void combineStates( + GroupingStdDeviationDoubleState current, + int groupId, + GroupingStdDeviationDoubleState state, + int statePosition + ) { + var st = state.states.get(statePosition); + if (st != null) { + current.combine(groupId, st.mean(), st.m2(), st.count()); + } + } + + public static void combineIntermediate(GroupingStdDeviationDoubleState state, int groupId, double mean, double m2, long count) { + state.combine(groupId, mean, m2, count); + } + + public static void evaluateIntermediate( + GroupingStdDeviationDoubleState state, + Block[] blocks, + int offset, + IntVector selected, + DriverContext driverContext + ) { + assert blocks.length >= offset + 3 : "blocks=" + blocks.length + ",offset=" + offset; + try ( + var meanBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var m2Builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var countBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + final var groupId = selected.getInt(i); + final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + if (st != null) { + meanBuilder.appendDouble(st.mean()); + m2Builder.appendDouble(st.m2()); + countBuilder.appendLong(st.count()); + } else { + meanBuilder.appendNull(); + m2Builder.appendNull(); + countBuilder.appendNull(); + } + } + blocks[offset + 0] = meanBuilder.build(); + blocks[offset + 1] = m2Builder.build(); + blocks[offset + 2] = countBuilder.build(); + } + } + + public static Block evaluateFinal(GroupingStdDeviationDoubleState state, IntVector selected, DriverContext driverContext) { + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + final var groupId = selected.getInt(i); + final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + if (st != null) { + final var m2 = st.m2(); + if (Double.isFinite(m2) == false) { + builder.appendNull(); + } else { + builder.appendDouble(st.evaluateFinal()); + } + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + static final class StdDeviationDoubleState implements AggregatorState { + + private WelfordAlgorithm welfordAlgorithm; + + StdDeviationDoubleState() { + this(0, 0, 0); + } + + StdDeviationDoubleState(double mean, double m2, long count) { + this.welfordAlgorithm = new WelfordAlgorithm(mean, m2, count); + } + + public void add(double value) { + welfordAlgorithm.add(value); + } + + public void combine(double mean, double m2, long count) { + welfordAlgorithm.add(mean, m2, count); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + StdDeviationDoubleAggregator.evaluateIntermediate(this, driverContext, blocks, offset); + } + + @Override + public void close() {} + + public double mean() { + return welfordAlgorithm.mean(); + } + + public double m2() { + return welfordAlgorithm.m2(); + } + + public long count() { + return welfordAlgorithm.count(); + } + + public double evaluateFinal() { + return welfordAlgorithm.evaluate(); + } + } + + static final class GroupingStdDeviationDoubleState implements GroupingAggregatorState { + + private ObjectArray states; + private final BigArrays bigArrays; + + GroupingStdDeviationDoubleState(BigArrays bigArrays) { + this.states = bigArrays.newObjectArray(1); + this.bigArrays = bigArrays; + } + + public void combine(int groupId, double meanValue, double m2Value, long countValue) { + ensureCapacity(groupId); + var state = states.get(groupId); + if (state == null) { + state = new StdDeviationDoubleState(meanValue, m2Value, countValue); + states.set(groupId, state); + } else { + state.combine(meanValue, m2Value, countValue); + } + } + + public void add(int groupId, double value) { + ensureCapacity(groupId); + var state = states.get(groupId); + if (state == null) { + state = new StdDeviationDoubleState(); + states.set(groupId, state); + } + state.add(value); + } + + private void ensureCapacity(int groupId) { + states = bigArrays.grow(states, groupId + 1); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + StdDeviationDoubleAggregator.evaluateIntermediate(this, blocks, offset, selected, driverContext); + } + + @Override + public void close() { + Releasables.close(states); + } + + void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java new file mode 100644 index 0000000000000..de96ce9d5c1ea --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java @@ -0,0 +1,234 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +/** + * A standard deviation aggregation definition for float. + * This class is generated. Edit `X-StdDeviationAggregator.java.st` instead. + */ +@Aggregator( + { + @IntermediateState(name = "mean", type = "DOUBLE"), + @IntermediateState(name = "m2", type = "DOUBLE"), + @IntermediateState(name = "count", type = "LONG") } +) +@GroupingAggregator +public class StdDeviationFloatAggregator { + + public static StdDeviationFloatState initSingle() { + return new StdDeviationFloatState(); + } + + public static void combine(StdDeviationFloatState state, float value) { + state.add(value); + } + + public static void combineIntermediate(StdDeviationFloatState state, double mean, double m2, long count) { + state.combine(mean, m2, count); + } + + public static void evaluateIntermediate(StdDeviationFloatState state, DriverContext driverContext, Block[] blocks, int offset) { + assert blocks.length >= offset + 3; + BlockFactory blockFactory = driverContext.blockFactory(); + blocks[offset + 0] = blockFactory.newConstantDoubleBlockWith(state.mean(), 1); + blocks[offset + 1] = blockFactory.newConstantDoubleBlockWith(state.m2(), 1); + blocks[offset + 2] = blockFactory.newConstantLongBlockWith(state.count(), 1); + } + + public static Block evaluateFinal(StdDeviationFloatState state, DriverContext driverContext) { + final long count = state.count(); + final double m2 = state.m2(); + if (count == 0 || Double.isFinite(m2) == false) { + return driverContext.blockFactory().newConstantNullBlock(1); + } + return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); + } + + public static GroupingStdDeviationFloatState initGrouping(BigArrays bigArrays) { + return new GroupingStdDeviationFloatState(bigArrays); + } + + public static void combine(GroupingStdDeviationFloatState current, int groupId, float value) { + current.add(groupId, value); + } + + public static void combineStates( + GroupingStdDeviationFloatState current, + int groupId, + GroupingStdDeviationFloatState state, + int statePosition + ) { + var st = state.states.get(statePosition); + if (st != null) { + current.combine(groupId, st.mean(), st.m2(), st.count()); + } + } + + public static void combineIntermediate(GroupingStdDeviationFloatState state, int groupId, double mean, double m2, long count) { + state.combine(groupId, mean, m2, count); + } + + public static void evaluateIntermediate( + GroupingStdDeviationFloatState state, + Block[] blocks, + int offset, + IntVector selected, + DriverContext driverContext + ) { + assert blocks.length >= offset + 3 : "blocks=" + blocks.length + ",offset=" + offset; + try ( + var meanBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var m2Builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var countBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + final var groupId = selected.getInt(i); + final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + if (st != null) { + meanBuilder.appendDouble(st.mean()); + m2Builder.appendDouble(st.m2()); + countBuilder.appendLong(st.count()); + } else { + meanBuilder.appendNull(); + m2Builder.appendNull(); + countBuilder.appendNull(); + } + } + blocks[offset + 0] = meanBuilder.build(); + blocks[offset + 1] = m2Builder.build(); + blocks[offset + 2] = countBuilder.build(); + } + } + + public static Block evaluateFinal(GroupingStdDeviationFloatState state, IntVector selected, DriverContext driverContext) { + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + final var groupId = selected.getInt(i); + final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + if (st != null) { + final var m2 = st.m2(); + if (Double.isFinite(m2) == false) { + builder.appendNull(); + } else { + builder.appendDouble(st.evaluateFinal()); + } + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + static final class StdDeviationFloatState implements AggregatorState { + + private WelfordAlgorithm welfordAlgorithm; + + StdDeviationFloatState() { + this(0, 0, 0); + } + + StdDeviationFloatState(double mean, double m2, long count) { + this.welfordAlgorithm = new WelfordAlgorithm(mean, m2, count); + } + + public void add(float value) { + welfordAlgorithm.add(value); + } + + public void combine(double mean, double m2, long count) { + welfordAlgorithm.add(mean, m2, count); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + StdDeviationFloatAggregator.evaluateIntermediate(this, driverContext, blocks, offset); + } + + @Override + public void close() {} + + public double mean() { + return welfordAlgorithm.mean(); + } + + public double m2() { + return welfordAlgorithm.m2(); + } + + public long count() { + return welfordAlgorithm.count(); + } + + public double evaluateFinal() { + return welfordAlgorithm.evaluate(); + } + } + + static final class GroupingStdDeviationFloatState implements GroupingAggregatorState { + + private ObjectArray states; + private final BigArrays bigArrays; + + GroupingStdDeviationFloatState(BigArrays bigArrays) { + this.states = bigArrays.newObjectArray(1); + this.bigArrays = bigArrays; + } + + public void combine(int groupId, double meanValue, double m2Value, long countValue) { + ensureCapacity(groupId); + var state = states.get(groupId); + if (state == null) { + state = new StdDeviationFloatState(meanValue, m2Value, countValue); + states.set(groupId, state); + } else { + state.combine(meanValue, m2Value, countValue); + } + } + + public void add(int groupId, float value) { + ensureCapacity(groupId); + var state = states.get(groupId); + if (state == null) { + state = new StdDeviationFloatState(); + states.set(groupId, state); + } + state.add(value); + } + + private void ensureCapacity(int groupId) { + states = bigArrays.grow(states, groupId + 1); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + StdDeviationFloatAggregator.evaluateIntermediate(this, blocks, offset, selected, driverContext); + } + + @Override + public void close() { + Releasables.close(states); + } + + void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java new file mode 100644 index 0000000000000..0532d89675c9d --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java @@ -0,0 +1,234 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +/** + * A standard deviation aggregation definition for int. + * This class is generated. Edit `X-StdDeviationAggregator.java.st` instead. + */ +@Aggregator( + { + @IntermediateState(name = "mean", type = "DOUBLE"), + @IntermediateState(name = "m2", type = "DOUBLE"), + @IntermediateState(name = "count", type = "LONG") } +) +@GroupingAggregator +public class StdDeviationIntAggregator { + + public static StdDeviationIntState initSingle() { + return new StdDeviationIntState(); + } + + public static void combine(StdDeviationIntState state, int value) { + state.add(value); + } + + public static void combineIntermediate(StdDeviationIntState state, double mean, double m2, long count) { + state.combine(mean, m2, count); + } + + public static void evaluateIntermediate(StdDeviationIntState state, DriverContext driverContext, Block[] blocks, int offset) { + assert blocks.length >= offset + 3; + BlockFactory blockFactory = driverContext.blockFactory(); + blocks[offset + 0] = blockFactory.newConstantDoubleBlockWith(state.mean(), 1); + blocks[offset + 1] = blockFactory.newConstantDoubleBlockWith(state.m2(), 1); + blocks[offset + 2] = blockFactory.newConstantLongBlockWith(state.count(), 1); + } + + public static Block evaluateFinal(StdDeviationIntState state, DriverContext driverContext) { + final long count = state.count(); + final double m2 = state.m2(); + if (count == 0 || Double.isFinite(m2) == false) { + return driverContext.blockFactory().newConstantNullBlock(1); + } + return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); + } + + public static GroupingStdDeviationIntState initGrouping(BigArrays bigArrays) { + return new GroupingStdDeviationIntState(bigArrays); + } + + public static void combine(GroupingStdDeviationIntState current, int groupId, int value) { + current.add(groupId, value); + } + + public static void combineStates( + GroupingStdDeviationIntState current, + int groupId, + GroupingStdDeviationIntState state, + int statePosition + ) { + var st = state.states.get(statePosition); + if (st != null) { + current.combine(groupId, st.mean(), st.m2(), st.count()); + } + } + + public static void combineIntermediate(GroupingStdDeviationIntState state, int groupId, double mean, double m2, long count) { + state.combine(groupId, mean, m2, count); + } + + public static void evaluateIntermediate( + GroupingStdDeviationIntState state, + Block[] blocks, + int offset, + IntVector selected, + DriverContext driverContext + ) { + assert blocks.length >= offset + 3 : "blocks=" + blocks.length + ",offset=" + offset; + try ( + var meanBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var m2Builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var countBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + final var groupId = selected.getInt(i); + final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + if (st != null) { + meanBuilder.appendDouble(st.mean()); + m2Builder.appendDouble(st.m2()); + countBuilder.appendLong(st.count()); + } else { + meanBuilder.appendNull(); + m2Builder.appendNull(); + countBuilder.appendNull(); + } + } + blocks[offset + 0] = meanBuilder.build(); + blocks[offset + 1] = m2Builder.build(); + blocks[offset + 2] = countBuilder.build(); + } + } + + public static Block evaluateFinal(GroupingStdDeviationIntState state, IntVector selected, DriverContext driverContext) { + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + final var groupId = selected.getInt(i); + final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + if (st != null) { + final var m2 = st.m2(); + if (Double.isFinite(m2) == false) { + builder.appendNull(); + } else { + builder.appendDouble(st.evaluateFinal()); + } + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + static final class StdDeviationIntState implements AggregatorState { + + private WelfordAlgorithm welfordAlgorithm; + + StdDeviationIntState() { + this(0, 0, 0); + } + + StdDeviationIntState(double mean, double m2, long count) { + this.welfordAlgorithm = new WelfordAlgorithm(mean, m2, count); + } + + public void add(int value) { + welfordAlgorithm.add(value); + } + + public void combine(double mean, double m2, long count) { + welfordAlgorithm.add(mean, m2, count); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + StdDeviationIntAggregator.evaluateIntermediate(this, driverContext, blocks, offset); + } + + @Override + public void close() {} + + public double mean() { + return welfordAlgorithm.mean(); + } + + public double m2() { + return welfordAlgorithm.m2(); + } + + public long count() { + return welfordAlgorithm.count(); + } + + public double evaluateFinal() { + return welfordAlgorithm.evaluate(); + } + } + + static final class GroupingStdDeviationIntState implements GroupingAggregatorState { + + private ObjectArray states; + private final BigArrays bigArrays; + + GroupingStdDeviationIntState(BigArrays bigArrays) { + this.states = bigArrays.newObjectArray(1); + this.bigArrays = bigArrays; + } + + public void combine(int groupId, double meanValue, double m2Value, long countValue) { + ensureCapacity(groupId); + var state = states.get(groupId); + if (state == null) { + state = new StdDeviationIntState(meanValue, m2Value, countValue); + states.set(groupId, state); + } else { + state.combine(meanValue, m2Value, countValue); + } + } + + public void add(int groupId, int value) { + ensureCapacity(groupId); + var state = states.get(groupId); + if (state == null) { + state = new StdDeviationIntState(); + states.set(groupId, state); + } + state.add(value); + } + + private void ensureCapacity(int groupId) { + states = bigArrays.grow(states, groupId + 1); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + StdDeviationIntAggregator.evaluateIntermediate(this, blocks, offset, selected, driverContext); + } + + @Override + public void close() { + Releasables.close(states); + } + + void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java new file mode 100644 index 0000000000000..e13131104d82b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java @@ -0,0 +1,234 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +/** + * A standard deviation aggregation definition for long. + * This class is generated. Edit `X-StdDeviationAggregator.java.st` instead. + */ +@Aggregator( + { + @IntermediateState(name = "mean", type = "DOUBLE"), + @IntermediateState(name = "m2", type = "DOUBLE"), + @IntermediateState(name = "count", type = "LONG") } +) +@GroupingAggregator +public class StdDeviationLongAggregator { + + public static StdDeviationLongState initSingle() { + return new StdDeviationLongState(); + } + + public static void combine(StdDeviationLongState state, long value) { + state.add(value); + } + + public static void combineIntermediate(StdDeviationLongState state, double mean, double m2, long count) { + state.combine(mean, m2, count); + } + + public static void evaluateIntermediate(StdDeviationLongState state, DriverContext driverContext, Block[] blocks, int offset) { + assert blocks.length >= offset + 3; + BlockFactory blockFactory = driverContext.blockFactory(); + blocks[offset + 0] = blockFactory.newConstantDoubleBlockWith(state.mean(), 1); + blocks[offset + 1] = blockFactory.newConstantDoubleBlockWith(state.m2(), 1); + blocks[offset + 2] = blockFactory.newConstantLongBlockWith(state.count(), 1); + } + + public static Block evaluateFinal(StdDeviationLongState state, DriverContext driverContext) { + final long count = state.count(); + final double m2 = state.m2(); + if (count == 0 || Double.isFinite(m2) == false) { + return driverContext.blockFactory().newConstantNullBlock(1); + } + return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); + } + + public static GroupingStdDeviationLongState initGrouping(BigArrays bigArrays) { + return new GroupingStdDeviationLongState(bigArrays); + } + + public static void combine(GroupingStdDeviationLongState current, int groupId, long value) { + current.add(groupId, value); + } + + public static void combineStates( + GroupingStdDeviationLongState current, + int groupId, + GroupingStdDeviationLongState state, + int statePosition + ) { + var st = state.states.get(statePosition); + if (st != null) { + current.combine(groupId, st.mean(), st.m2(), st.count()); + } + } + + public static void combineIntermediate(GroupingStdDeviationLongState state, int groupId, double mean, double m2, long count) { + state.combine(groupId, mean, m2, count); + } + + public static void evaluateIntermediate( + GroupingStdDeviationLongState state, + Block[] blocks, + int offset, + IntVector selected, + DriverContext driverContext + ) { + assert blocks.length >= offset + 3 : "blocks=" + blocks.length + ",offset=" + offset; + try ( + var meanBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var m2Builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var countBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + final var groupId = selected.getInt(i); + final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + if (st != null) { + meanBuilder.appendDouble(st.mean()); + m2Builder.appendDouble(st.m2()); + countBuilder.appendLong(st.count()); + } else { + meanBuilder.appendNull(); + m2Builder.appendNull(); + countBuilder.appendNull(); + } + } + blocks[offset + 0] = meanBuilder.build(); + blocks[offset + 1] = m2Builder.build(); + blocks[offset + 2] = countBuilder.build(); + } + } + + public static Block evaluateFinal(GroupingStdDeviationLongState state, IntVector selected, DriverContext driverContext) { + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + final var groupId = selected.getInt(i); + final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + if (st != null) { + final var m2 = st.m2(); + if (Double.isFinite(m2) == false) { + builder.appendNull(); + } else { + builder.appendDouble(st.evaluateFinal()); + } + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + static final class StdDeviationLongState implements AggregatorState { + + private WelfordAlgorithm welfordAlgorithm; + + StdDeviationLongState() { + this(0, 0, 0); + } + + StdDeviationLongState(double mean, double m2, long count) { + this.welfordAlgorithm = new WelfordAlgorithm(mean, m2, count); + } + + public void add(long value) { + welfordAlgorithm.add(value); + } + + public void combine(double mean, double m2, long count) { + welfordAlgorithm.add(mean, m2, count); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + StdDeviationLongAggregator.evaluateIntermediate(this, driverContext, blocks, offset); + } + + @Override + public void close() {} + + public double mean() { + return welfordAlgorithm.mean(); + } + + public double m2() { + return welfordAlgorithm.m2(); + } + + public long count() { + return welfordAlgorithm.count(); + } + + public double evaluateFinal() { + return welfordAlgorithm.evaluate(); + } + } + + static final class GroupingStdDeviationLongState implements GroupingAggregatorState { + + private ObjectArray states; + private final BigArrays bigArrays; + + GroupingStdDeviationLongState(BigArrays bigArrays) { + this.states = bigArrays.newObjectArray(1); + this.bigArrays = bigArrays; + } + + public void combine(int groupId, double meanValue, double m2Value, long countValue) { + ensureCapacity(groupId); + var state = states.get(groupId); + if (state == null) { + state = new StdDeviationLongState(meanValue, m2Value, countValue); + states.set(groupId, state); + } else { + state.combine(meanValue, m2Value, countValue); + } + } + + public void add(int groupId, long value) { + ensureCapacity(groupId); + var state = states.get(groupId); + if (state == null) { + state = new StdDeviationLongState(); + states.set(groupId, state); + } + state.add(value); + } + + private void ensureCapacity(int groupId) { + states = bigArrays.grow(states, groupId + 1); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + StdDeviationLongAggregator.evaluateIntermediate(this, blocks, offset, selected, driverContext); + } + + @Override + public void close() { + Releasables.close(states); + } + + void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java new file mode 100644 index 0000000000000..c479ac9234b8b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java @@ -0,0 +1,178 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link StdDeviationDoubleAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDeviationDoubleAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final StdDeviationDoubleAggregator.StdDeviationDoubleState state; + + private final List channels; + + public StdDeviationDoubleAggregatorFunction(DriverContext driverContext, List channels, + StdDeviationDoubleAggregator.StdDeviationDoubleState state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static StdDeviationDoubleAggregatorFunction create(DriverContext driverContext, + List channels) { + return new StdDeviationDoubleAggregatorFunction(driverContext, channels, StdDeviationDoubleAggregator.initSingle()); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + DoubleBlock block = page.getBlock(channels.get(0)); + DoubleVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + DoubleBlock block = page.getBlock(channels.get(0)); + DoubleVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(DoubleVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + StdDeviationDoubleAggregator.combine(state, vector.getDouble(i)); + } + } + + private void addRawVector(DoubleVector vector, BooleanVector mask) { + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + StdDeviationDoubleAggregator.combine(state, vector.getDouble(i)); + } + } + + private void addRawBlock(DoubleBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDeviationDoubleAggregator.combine(state, block.getDouble(i)); + } + } + } + + private void addRawBlock(DoubleBlock block, BooleanVector mask) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDeviationDoubleAggregator.combine(state, block.getDouble(i)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + assert mean.getPositionCount() == 1; + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + assert m2.getPositionCount() == 1; + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert count.getPositionCount() == 1; + StdDeviationDoubleAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = StdDeviationDoubleAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..28915079c4b7a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunctionSupplier.java @@ -0,0 +1,39 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link StdDeviationDoubleAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDeviationDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + public StdDeviationDoubleAggregatorFunctionSupplier(List channels) { + this.channels = channels; + } + + @Override + public StdDeviationDoubleAggregatorFunction aggregator(DriverContext driverContext) { + return StdDeviationDoubleAggregatorFunction.create(driverContext, channels); + } + + @Override + public StdDeviationDoubleGroupingAggregatorFunction groupingAggregator( + DriverContext driverContext) { + return StdDeviationDoubleGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "std_deviation of doubles"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..f9ad9ff3eb5db --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java @@ -0,0 +1,224 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link StdDeviationDoubleAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDeviationDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final StdDeviationDoubleAggregator.GroupingStdDeviationDoubleState state; + + private final List channels; + + private final DriverContext driverContext; + + public StdDeviationDoubleGroupingAggregatorFunction(List channels, + StdDeviationDoubleAggregator.GroupingStdDeviationDoubleState state, + DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static StdDeviationDoubleGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new StdDeviationDoubleGroupingAggregatorFunction(channels, StdDeviationDoubleAggregator.initGrouping(driverContext.bigArrays()), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + DoubleBlock valuesBlock = page.getBlock(channels.get(0)); + DoubleVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDeviationDoubleAggregator.combine(state, groupId, values.getDouble(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDeviationDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDeviationDoubleAggregator.combine(state, groupId, values.getDouble(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDeviationDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); + } + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDeviationDoubleAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + StdDeviationDoubleAggregator.GroupingStdDeviationDoubleState inState = ((StdDeviationDoubleGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + StdDeviationDoubleAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = StdDeviationDoubleAggregator.evaluateFinal(state, selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java new file mode 100644 index 0000000000000..82f2edcf5b6da --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java @@ -0,0 +1,180 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link StdDeviationFloatAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDeviationFloatAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final StdDeviationFloatAggregator.StdDeviationFloatState state; + + private final List channels; + + public StdDeviationFloatAggregatorFunction(DriverContext driverContext, List channels, + StdDeviationFloatAggregator.StdDeviationFloatState state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static StdDeviationFloatAggregatorFunction create(DriverContext driverContext, + List channels) { + return new StdDeviationFloatAggregatorFunction(driverContext, channels, StdDeviationFloatAggregator.initSingle()); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + FloatBlock block = page.getBlock(channels.get(0)); + FloatVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + FloatBlock block = page.getBlock(channels.get(0)); + FloatVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(FloatVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + StdDeviationFloatAggregator.combine(state, vector.getFloat(i)); + } + } + + private void addRawVector(FloatVector vector, BooleanVector mask) { + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + StdDeviationFloatAggregator.combine(state, vector.getFloat(i)); + } + } + + private void addRawBlock(FloatBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDeviationFloatAggregator.combine(state, block.getFloat(i)); + } + } + } + + private void addRawBlock(FloatBlock block, BooleanVector mask) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDeviationFloatAggregator.combine(state, block.getFloat(i)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + assert mean.getPositionCount() == 1; + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + assert m2.getPositionCount() == 1; + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert count.getPositionCount() == 1; + StdDeviationFloatAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = StdDeviationFloatAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..e761b32fc8f0c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunctionSupplier.java @@ -0,0 +1,39 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link StdDeviationFloatAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDeviationFloatAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + public StdDeviationFloatAggregatorFunctionSupplier(List channels) { + this.channels = channels; + } + + @Override + public StdDeviationFloatAggregatorFunction aggregator(DriverContext driverContext) { + return StdDeviationFloatAggregatorFunction.create(driverContext, channels); + } + + @Override + public StdDeviationFloatGroupingAggregatorFunction groupingAggregator( + DriverContext driverContext) { + return StdDeviationFloatGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "std_deviation of floats"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..82f568283378c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java @@ -0,0 +1,226 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link StdDeviationFloatAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDeviationFloatGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final StdDeviationFloatAggregator.GroupingStdDeviationFloatState state; + + private final List channels; + + private final DriverContext driverContext; + + public StdDeviationFloatGroupingAggregatorFunction(List channels, + StdDeviationFloatAggregator.GroupingStdDeviationFloatState state, + DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static StdDeviationFloatGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new StdDeviationFloatGroupingAggregatorFunction(channels, StdDeviationFloatAggregator.initGrouping(driverContext.bigArrays()), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + FloatBlock valuesBlock = page.getBlock(channels.get(0)); + FloatVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDeviationFloatAggregator.combine(state, groupId, values.getFloat(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, FloatVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDeviationFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDeviationFloatAggregator.combine(state, groupId, values.getFloat(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, FloatVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDeviationFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); + } + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDeviationFloatAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + StdDeviationFloatAggregator.GroupingStdDeviationFloatState inState = ((StdDeviationFloatGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + StdDeviationFloatAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = StdDeviationFloatAggregator.evaluateFinal(state, selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java new file mode 100644 index 0000000000000..90b1e163f4668 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java @@ -0,0 +1,180 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link StdDeviationIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDeviationIntAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final StdDeviationIntAggregator.StdDeviationIntState state; + + private final List channels; + + public StdDeviationIntAggregatorFunction(DriverContext driverContext, List channels, + StdDeviationIntAggregator.StdDeviationIntState state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static StdDeviationIntAggregatorFunction create(DriverContext driverContext, + List channels) { + return new StdDeviationIntAggregatorFunction(driverContext, channels, StdDeviationIntAggregator.initSingle()); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + IntBlock block = page.getBlock(channels.get(0)); + IntVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + IntBlock block = page.getBlock(channels.get(0)); + IntVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(IntVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + StdDeviationIntAggregator.combine(state, vector.getInt(i)); + } + } + + private void addRawVector(IntVector vector, BooleanVector mask) { + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + StdDeviationIntAggregator.combine(state, vector.getInt(i)); + } + } + + private void addRawBlock(IntBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDeviationIntAggregator.combine(state, block.getInt(i)); + } + } + } + + private void addRawBlock(IntBlock block, BooleanVector mask) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDeviationIntAggregator.combine(state, block.getInt(i)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + assert mean.getPositionCount() == 1; + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + assert m2.getPositionCount() == 1; + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert count.getPositionCount() == 1; + StdDeviationIntAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = StdDeviationIntAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..0d1c8d5c2415b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunctionSupplier.java @@ -0,0 +1,38 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link StdDeviationIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDeviationIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + public StdDeviationIntAggregatorFunctionSupplier(List channels) { + this.channels = channels; + } + + @Override + public StdDeviationIntAggregatorFunction aggregator(DriverContext driverContext) { + return StdDeviationIntAggregatorFunction.create(driverContext, channels); + } + + @Override + public StdDeviationIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return StdDeviationIntGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "std_deviation of ints"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..efac6e27f7ea5 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java @@ -0,0 +1,223 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link StdDeviationIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDeviationIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final StdDeviationIntAggregator.GroupingStdDeviationIntState state; + + private final List channels; + + private final DriverContext driverContext; + + public StdDeviationIntGroupingAggregatorFunction(List channels, + StdDeviationIntAggregator.GroupingStdDeviationIntState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static StdDeviationIntGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new StdDeviationIntGroupingAggregatorFunction(channels, StdDeviationIntAggregator.initGrouping(driverContext.bigArrays()), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + IntBlock valuesBlock = page.getBlock(channels.get(0)); + IntVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDeviationIntAggregator.combine(state, groupId, values.getInt(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDeviationIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDeviationIntAggregator.combine(state, groupId, values.getInt(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, IntVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDeviationIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); + } + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDeviationIntAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + StdDeviationIntAggregator.GroupingStdDeviationIntState inState = ((StdDeviationIntGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + StdDeviationIntAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = StdDeviationIntAggregator.evaluateFinal(state, selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java new file mode 100644 index 0000000000000..fc801e35403b3 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java @@ -0,0 +1,178 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link StdDeviationLongAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDeviationLongAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final StdDeviationLongAggregator.StdDeviationLongState state; + + private final List channels; + + public StdDeviationLongAggregatorFunction(DriverContext driverContext, List channels, + StdDeviationLongAggregator.StdDeviationLongState state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static StdDeviationLongAggregatorFunction create(DriverContext driverContext, + List channels) { + return new StdDeviationLongAggregatorFunction(driverContext, channels, StdDeviationLongAggregator.initSingle()); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + LongBlock block = page.getBlock(channels.get(0)); + LongVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + LongBlock block = page.getBlock(channels.get(0)); + LongVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(LongVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + StdDeviationLongAggregator.combine(state, vector.getLong(i)); + } + } + + private void addRawVector(LongVector vector, BooleanVector mask) { + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + StdDeviationLongAggregator.combine(state, vector.getLong(i)); + } + } + + private void addRawBlock(LongBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDeviationLongAggregator.combine(state, block.getLong(i)); + } + } + } + + private void addRawBlock(LongBlock block, BooleanVector mask) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + StdDeviationLongAggregator.combine(state, block.getLong(i)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + assert mean.getPositionCount() == 1; + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + assert m2.getPositionCount() == 1; + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert count.getPositionCount() == 1; + StdDeviationLongAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = StdDeviationLongAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..b2c141647572a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunctionSupplier.java @@ -0,0 +1,39 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link StdDeviationLongAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDeviationLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + public StdDeviationLongAggregatorFunctionSupplier(List channels) { + this.channels = channels; + } + + @Override + public StdDeviationLongAggregatorFunction aggregator(DriverContext driverContext) { + return StdDeviationLongAggregatorFunction.create(driverContext, channels); + } + + @Override + public StdDeviationLongGroupingAggregatorFunction groupingAggregator( + DriverContext driverContext) { + return StdDeviationLongGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "std_deviation of longs"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..4154b70f43277 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java @@ -0,0 +1,223 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link StdDeviationLongAggregator}. + * This class is generated. Do not edit it. + */ +public final class StdDeviationLongGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("mean", ElementType.DOUBLE), + new IntermediateStateDesc("m2", ElementType.DOUBLE), + new IntermediateStateDesc("count", ElementType.LONG) ); + + private final StdDeviationLongAggregator.GroupingStdDeviationLongState state; + + private final List channels; + + private final DriverContext driverContext; + + public StdDeviationLongGroupingAggregatorFunction(List channels, + StdDeviationLongAggregator.GroupingStdDeviationLongState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static StdDeviationLongGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new StdDeviationLongGroupingAggregatorFunction(channels, StdDeviationLongAggregator.initGrouping(driverContext.bigArrays()), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + LongBlock valuesBlock = page.getBlock(channels.get(0)); + LongVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDeviationLongAggregator.combine(state, groupId, values.getLong(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDeviationLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDeviationLongAggregator.combine(state, groupId, values.getLong(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, LongVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDeviationLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); + } + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + StdDeviationLongAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + StdDeviationLongAggregator.GroupingStdDeviationLongState inState = ((StdDeviationLongGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + StdDeviationLongAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = StdDeviationLongAggregator.evaluateFinal(state, selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java new file mode 100644 index 0000000000000..6d8d0beb51c22 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +/** + * Algorithm for calculating standard deviation, one value at a time + * + * @see + * Welford's_online_algorithm and + * + * Parallel algorithm + */ +public final class WelfordAlgorithm { + private double mean; + private double m2; + private long count; + + public double mean() { + return mean; + } + + public double m2() { + return m2; + } + + public long count() { + return count; + } + + public WelfordAlgorithm() { + this(0, 0, 0); + } + + public WelfordAlgorithm(double mean, double m2, long count) { + this.mean = mean; + this.m2 = m2; + this.count = count; + } + + public void add(int value) { + add((double) value); + } + + public void add(long value) { + add((double) value); + } + + public void add(double value) { + final double delta = value - mean; + count += 1; + mean += delta / count; + m2 += delta * (value - mean); + } + + public void add(double meanValue, double m2Value, long countValue) { + if (countValue == 0) { + return; + } + if (count == 0) { + mean = meanValue; + m2 = m2Value; + count = countValue; + return; + } + double delta = mean - meanValue; + m2 += m2Value + delta * delta * count * countValue / (count + countValue); + count += countValue; + mean += delta * countValue / (count); + } + + public double evaluate() { + return count < 2 ? 0 : Math.sqrt(m2 / count); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st new file mode 100644 index 0000000000000..a195659ccd6d3 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st @@ -0,0 +1,241 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +/** + * A standard deviation aggregation definition for $type$. + * This class is generated. Edit `X-StdDeviationAggregator.java.st` instead. + */ +@Aggregator( + { + @IntermediateState(name = "mean", type = "DOUBLE"), + @IntermediateState(name = "m2", type = "DOUBLE"), + @IntermediateState(name = "count", type = "LONG") } +) +@GroupingAggregator +public class StdDeviation$Type$Aggregator { + + public static StdDeviation$Type$State initSingle() { + return new StdDeviation$Type$State(); + } + + public static void combine(StdDeviation$Type$State state, $type$ value) { + state.add(value); + } + + public static void combineIntermediate(StdDeviation$Type$State state, double mean, double m2, long count) { + state.combine(mean, m2, count); + } + + + public static void evaluateIntermediate(StdDeviation$Type$State state, DriverContext driverContext, Block[] blocks, int offset) { + assert blocks.length >= offset + 3; + BlockFactory blockFactory = driverContext.blockFactory(); + blocks[offset + 0] = blockFactory.newConstantDoubleBlockWith(state.mean(), 1); + blocks[offset + 1] = blockFactory.newConstantDoubleBlockWith(state.m2(), 1); + blocks[offset + 2] = blockFactory.newConstantLongBlockWith(state.count(), 1); + } + + public static Block evaluateFinal(StdDeviation$Type$State state, DriverContext driverContext) { + final long count = state.count(); + final double m2 = state.m2(); + if (count == 0 || Double.isFinite(m2) == false) { + return driverContext.blockFactory().newConstantNullBlock(1); + } + return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); + } + + public static GroupingStdDeviation$Type$State initGrouping(BigArrays bigArrays) { + return new GroupingStdDeviation$Type$State(bigArrays); + } + + public static void combine(GroupingStdDeviation$Type$State current, int groupId, $type$ value) { + current.add(groupId, value); + } + + public static void combineStates( + GroupingStdDeviation$Type$State current, + int groupId, + GroupingStdDeviation$Type$State state, + int statePosition + ) { + var st = state.states.get(statePosition); + if (st != null) { + current.combine(groupId, st.mean(), st.m2(), st.count()); + } + } + + public static void combineIntermediate( + GroupingStdDeviation$Type$State state, + int groupId, + double mean, + double m2, + long count + ) { + state.combine(groupId, mean, m2, count); + } + + public static void evaluateIntermediate( + GroupingStdDeviation$Type$State state, + Block[] blocks, + int offset, + IntVector selected, + DriverContext driverContext + ) { + assert blocks.length >= offset + 3 : "blocks=" + blocks.length + ",offset=" + offset; + try ( + var meanBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var m2Builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var countBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + final var groupId = selected.getInt(i); + final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + if (st != null) { + meanBuilder.appendDouble(st.mean()); + m2Builder.appendDouble(st.m2()); + countBuilder.appendLong(st.count()); + } else { + meanBuilder.appendNull(); + m2Builder.appendNull(); + countBuilder.appendNull(); + } + } + blocks[offset + 0] = meanBuilder.build(); + blocks[offset + 1] = m2Builder.build(); + blocks[offset + 2] = countBuilder.build(); + } + } + + public static Block evaluateFinal(GroupingStdDeviation$Type$State state, IntVector selected, DriverContext driverContext) { + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + final var groupId = selected.getInt(i); + final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + if (st != null) { + final var m2 = st.m2(); + if (Double.isFinite(m2) == false) { + builder.appendNull(); + } else { + builder.appendDouble(st.evaluateFinal()); + } + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + static final class StdDeviation$Type$State implements AggregatorState { + + private WelfordAlgorithm welfordAlgorithm; + + StdDeviation$Type$State() { + this(0, 0, 0); + } + + StdDeviation$Type$State(double mean, double m2, long count) { + this.welfordAlgorithm = new WelfordAlgorithm(mean, m2, count); + } + + public void add($type$ value) { + welfordAlgorithm.add(value); + } + + public void combine(double mean, double m2, long count) { + welfordAlgorithm.add(mean, m2, count); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + StdDeviation$Type$Aggregator.evaluateIntermediate(this, driverContext, blocks, offset); + } + + @Override + public void close() {} + + public double mean() { + return welfordAlgorithm.mean(); + } + + public double m2() { + return welfordAlgorithm.m2(); + } + + public long count() { + return welfordAlgorithm.count(); + } + + public double evaluateFinal() { + return welfordAlgorithm.evaluate(); + } + } + + static final class GroupingStdDeviation$Type$State implements GroupingAggregatorState { + + private ObjectArray states; + private final BigArrays bigArrays; + + GroupingStdDeviation$Type$State(BigArrays bigArrays) { + this.states = bigArrays.newObjectArray(1); + this.bigArrays = bigArrays; + } + + public void combine(int groupId, double meanValue, double m2Value, long countValue) { + ensureCapacity(groupId); + var state = states.get(groupId); + if (state == null) { + state = new StdDeviation$Type$State(meanValue, m2Value, countValue); + states.set(groupId, state); + } else { + state.combine(meanValue, m2Value, countValue); + } + } + + public void add(int groupId, $type$ value) { + ensureCapacity(groupId); + var state = states.get(groupId); + if (state == null) { + state = new StdDeviation$Type$State(); + states.set(groupId, state); + } + state.add(value); + } + + private void ensureCapacity(int groupId) { + states = bigArrays.grow(states, groupId + 1); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + StdDeviation$Type$Aggregator.evaluateIntermediate(this, blocks, offset, selected, driverContext); + } + + @Override + public void close() { + Releasables.close(states); + } + + void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` + } + } +} diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index 448ee57b34c58..854d63431a296 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -2685,3 +2685,133 @@ max:integer | job_positions:keyword 39878 | Business Analyst 67492 | Data Scientist ; + +stdDeviation +// tag::stdev[] +FROM employees +| STATS STD_DEVIATION(height) +// end::stdev[] +; + +// tag::stdev-result[] +STD_DEVIATION(height):double +0.20637044362020449 +// end::stdev-result[] +; + +stdDeviationNested +// tag::docsStatsStdDeviationNestedExpression[] +FROM employees +| STATS stdev_salary_change = STD_DEVIATION(MV_MAX(salary_change)) +// end::docsStatsStdDeviationNestedExpression[] +; + +// tag::docsStatsStdDeviationNestedExpression-result[] +stdev_salary_change:double +6.875829592924112 +// end::docsStatsStdDeviationNestedExpression-result[] +; + + +stdDeviationWithLongs +FROM employees +| STATS STD_DEVIATION(avg_worked_seconds) +; + +STD_DEVIATION(avg_worked_seconds):double +5.76010425971634E7 +; + +stdDeviationWithInts +FROM employees +| STATS STD_DEVIATION(salary) +; + +STD_DEVIATION(salary):double +13765.12550278783 +; + +stdDeviationConstantValue +FROM employees +| WHERE languages == 2 +| STATS STD_DEVIATION(languages) +; + +STD_DEVIATION(languages):double +0.0 +; + +stdDeviationGrouped +FROM employees +| STATS STD_DEVIATION(height) BY languages +| SORT languages asc +; + +STD_DEVIATION(height):double | languages:integer +0.22106409327010415 | 1 +0.22797190865484734 | 2 +0.18893070075713295 | 3 +0.14656141004227627 | 4 +0.17733860152780256 | 5 +0.2486543786061287 | null +; + +stdDeviationGrouped1 +FROM employees +| WHERE languages == 1 +| STATS STD_DEVIATION(height) +; + +STD_DEVIATION(height):double +0.22106409327010415 +; + +stdDeviationGrouped2 +FROM employees +| WHERE languages == 2 +| STATS STD_DEVIATION(height) +; + +STD_DEVIATION(height):double +0.22797190865484734 +; + +stdDeviationGrouped3 +FROM employees +| WHERE languages == 3 +| STATS STD_DEVIATION(height) +; + +STD_DEVIATION(height):double +0.18893070075713295 +; + +stdDeviationGrouped4 +FROM employees +| WHERE languages == 4 +| STATS STD_DEVIATION(height) +; + +STD_DEVIATION(height):double +0.14656141004227627 +; + +stdDeviationGrouped5 +FROM employees +| WHERE languages == 5 +| STATS STD_DEVIATION(height) +; + +STD_DEVIATION(height):double +0.17733860152780256 +; + +stdDeviationNoRows +FROM employees +| WHERE languages IS null +| STATS STD_DEVIATION(languages) +; + +STD_DEVIATION(languages):double +null +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 837acfc5083f7..c391e5ba07c40 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -409,6 +409,11 @@ public enum Cap { */ PER_AGG_FILTERING_ORDS, + /** + * Support for {@code STD_DEVIATION} aggregation. + */ + STD_DEVIATION, + /** * Fix for https://github.com/elastic/elasticsearch/issues/114714 */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 66151275fc2e8..79b07c4f67898 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -28,6 +28,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; +import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDeviation; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; import org.elasticsearch.xpack.esql.expression.function.aggregate.Top; import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; @@ -270,6 +271,7 @@ private FunctionDefinition[][] functions() { def(MedianAbsoluteDeviation.class, uni(MedianAbsoluteDeviation::new), "median_absolute_deviation"), def(Min.class, uni(Min::new), "min"), def(Percentile.class, bi(Percentile::new), "percentile"), + def(StdDeviation.class, uni(StdDeviation::new), "std_deviation"), def(Sum.class, uni(Sum::new), "sum"), def(Top.class, tri(Top::new), "top"), def(Values.class, uni(Values::new), "values"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java index f7a74cc2ae93f..5e661979ba3b9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java @@ -42,6 +42,7 @@ public static List getNamedWriteables() { Percentile.ENTRY, Rate.ENTRY, SpatialCentroid.ENTRY, + StdDeviation.ENTRY, Sum.ENTRY, Top.ENTRY, Values.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDeviation.java new file mode 100644 index 0000000000000..301bb5b519b52 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDeviation.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.StdDeviationDoubleAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.StdDeviationIntAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.StdDeviationLongAggregatorFunctionSupplier; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.planner.ToAggregator; + +import java.io.IOException; +import java.util.List; + +import static java.util.Collections.emptyList; + +public class StdDeviation extends AggregateFunction implements ToAggregator { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "StdDeviation", + StdDeviation::new + ); + + @FunctionInfo( + returnType = "double", + description = "The standard deviation of a numeric field.", + isAggregation = true, + examples = { + @Example(file = "stats", tag = "stdev"), + @Example( + description = "The expression can use inline functions. For example, to calculate the standard " + + "deviation of each employee's maximum salary changes, first use `MV_MAX` on each row, " + + "and then use `StdDeviation` on the result", + file = "stats", + tag = "docsStatsStdDeviationNestedExpression" + ) } + ) + public StdDeviation(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { + this(source, field, Literal.TRUE); + } + + public StdDeviation(Source source, Expression field, Expression filter) { + super(source, field, filter, emptyList()); + } + + private StdDeviation(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + public DataType dataType() { + return DataType.DOUBLE; + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, StdDeviation::new, field(), filter()); + } + + @Override + public StdDeviation replaceChildren(List newChildren) { + return new StdDeviation(source(), newChildren.get(0), newChildren.get(1)); + } + + public StdDeviation withFilter(Expression filter) { + return new StdDeviation(source(), field(), filter); + } + + @Override + public final AggregatorFunctionSupplier supplier(List inputChannels) { + DataType type = field().dataType(); + if (type == DataType.LONG) { + return new StdDeviationLongAggregatorFunctionSupplier(inputChannels); + } + if (type == DataType.INTEGER) { + return new StdDeviationIntAggregatorFunctionSupplier(inputChannels); + } + if (type == DataType.DOUBLE) { + return new StdDeviationDoubleAggregatorFunctionSupplier(inputChannels); + } + throw EsqlIllegalArgumentException.illegalDataType(type); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java index 3e81c2a2c1101..5bcbf69ae63db 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; +import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDeviation; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; import org.elasticsearch.xpack.esql.expression.function.aggregate.ToPartial; import org.elasticsearch.xpack.esql.expression.function.aggregate.Top; @@ -78,6 +79,7 @@ final class AggregateMapper { Min.class, Percentile.class, SpatialCentroid.class, + StdDeviation.class, Sum.class, Values.class, Top.class, @@ -171,7 +173,7 @@ private static Stream, Tuple>> typeAndNames(Class types = List.of("Int", "Long", "Double", "Boolean", "BytesRef"); } else if (Top.class.isAssignableFrom(clazz)) { types = List.of("Boolean", "Int", "Long", "Double", "Ip", "BytesRef"); - } else if (Rate.class.isAssignableFrom(clazz)) { + } else if (Rate.class.isAssignableFrom(clazz) || StdDeviation.class.isAssignableFrom(clazz)) { types = List.of("Int", "Long", "Double"); } else if (FromPartial.class.isAssignableFrom(clazz) || ToPartial.class.isAssignableFrom(clazz)) { types = List.of(""); // no type diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDeviationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDeviationTests.java new file mode 100644 index 0000000000000..f107c5623a35f --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDeviationTests.java @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.compute.aggregation.WelfordAlgorithm; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; + +public class StdDeviationTests extends AbstractAggregationTestCase { + public StdDeviationTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + var suppliers = new ArrayList(); + + Stream.of( + MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), + MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), + MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true) + ).flatMap(List::stream).map(StdDeviationTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers)); + + // No rows + for (var dataType : List.of(DataType.INTEGER, DataType.LONG, DataType.DOUBLE)) { + suppliers.add( + new TestCaseSupplier( + "No rows (" + dataType + ")", + List.of(dataType), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(), dataType, "field")), + "StdDeviation[field=Attribute[channel=0]]", + DataType.DOUBLE, + nullValue() + ) + ) + ); + } + return parameterSuppliersFromTypedData(randomizeBytesRefsOffset(suppliers)); + } + + @Override + protected Expression build(Source source, List args) { + return new StdDeviation(source, args.get(0)); + } + + private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { + return new TestCaseSupplier(List.of(fieldSupplier.type()), () -> { + var fieldTypedData = fieldSupplier.get(); + var fieldValues = fieldTypedData.multiRowData(); + + WelfordAlgorithm welfordAlgorithm = new WelfordAlgorithm(); + + for (var fieldValue : fieldValues) { + var value = ((Number) fieldValue).doubleValue(); + welfordAlgorithm.add(value); + } + var result = welfordAlgorithm.evaluate(); + var expected = Double.isInfinite(result) ? null : result; + return new TestCaseSupplier.TestCase( + List.of(fieldTypedData), + "StdDeviation[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(expected) + ); + }); + } +} From fa54d7384793241beb24a89ab1b93fb9750b568e Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Fri, 8 Nov 2024 13:34:43 -1000 Subject: [PATCH 02/20] lint --- .../compute/aggregation/X-StdDeviationAggregator.java.st | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st index a195659ccd6d3..ff9d081358496 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st @@ -44,7 +44,6 @@ public class StdDeviation$Type$Aggregator { state.combine(mean, m2, count); } - public static void evaluateIntermediate(StdDeviation$Type$State state, DriverContext driverContext, Block[] blocks, int offset) { assert blocks.length >= offset + 3; BlockFactory blockFactory = driverContext.blockFactory(); @@ -82,13 +81,7 @@ public class StdDeviation$Type$Aggregator { } } - public static void combineIntermediate( - GroupingStdDeviation$Type$State state, - int groupId, - double mean, - double m2, - long count - ) { + public static void combineIntermediate(GroupingStdDeviation$Type$State state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } From b4c9c1b0cf6e5372a34ac13d1b3f2cfb82d0fefe Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Fri, 8 Nov 2024 16:39:07 -1000 Subject: [PATCH 03/20] rest test --- .../resources/rest-api-spec/test/esql/60_usage.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml index 6e7098da33805..db89ffc3dfcca 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml @@ -91,7 +91,7 @@ setup: - match: {esql.functions.cos: $functions_cos} - gt: {esql.functions.to_long: $functions_to_long} - match: {esql.functions.coalesce: $functions_coalesce} - - length: {esql.functions: 119} # check the "sister" test below for a likely update to the same esql.functions length check + - length: {esql.functions: 120} # check the "sister" test below for a likely update to the same esql.functions length check --- "Basic ESQL usage output (telemetry) non-snapshot version": @@ -162,4 +162,4 @@ setup: - match: {esql.functions.cos: $functions_cos} - gt: {esql.functions.to_long: $functions_to_long} - match: {esql.functions.coalesce: $functions_coalesce} - - length: {esql.functions: 116} # check the "sister" test above for a likely update to the same esql.functions length check + - length: {esql.functions: 117} # check the "sister" test above for a likely update to the same esql.functions length check From 0a74c9a90b99b217ef908abeae407229cc83ba17 Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Fri, 8 Nov 2024 18:00:31 -1000 Subject: [PATCH 04/20] bwc --- .../testFixtures/src/main/resources/stats.csv-spec | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index 854d63431a296..5435d13233db5 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -2687,6 +2687,7 @@ max:integer | job_positions:keyword ; stdDeviation +required_capability: std_deviation // tag::stdev[] FROM employees | STATS STD_DEVIATION(height) @@ -2700,6 +2701,7 @@ STD_DEVIATION(height):double ; stdDeviationNested +required_capability: std_deviation // tag::docsStatsStdDeviationNestedExpression[] FROM employees | STATS stdev_salary_change = STD_DEVIATION(MV_MAX(salary_change)) @@ -2714,6 +2716,7 @@ stdev_salary_change:double stdDeviationWithLongs +required_capability: std_deviation FROM employees | STATS STD_DEVIATION(avg_worked_seconds) ; @@ -2723,6 +2726,7 @@ STD_DEVIATION(avg_worked_seconds):double ; stdDeviationWithInts +required_capability: std_deviation FROM employees | STATS STD_DEVIATION(salary) ; @@ -2732,6 +2736,7 @@ STD_DEVIATION(salary):double ; stdDeviationConstantValue +required_capability: std_deviation FROM employees | WHERE languages == 2 | STATS STD_DEVIATION(languages) @@ -2742,6 +2747,7 @@ STD_DEVIATION(languages):double ; stdDeviationGrouped +required_capability: std_deviation FROM employees | STATS STD_DEVIATION(height) BY languages | SORT languages asc @@ -2757,6 +2763,7 @@ STD_DEVIATION(height):double | languages:integer ; stdDeviationGrouped1 +required_capability: std_deviation FROM employees | WHERE languages == 1 | STATS STD_DEVIATION(height) @@ -2767,6 +2774,7 @@ STD_DEVIATION(height):double ; stdDeviationGrouped2 +required_capability: std_deviation FROM employees | WHERE languages == 2 | STATS STD_DEVIATION(height) @@ -2777,6 +2785,7 @@ STD_DEVIATION(height):double ; stdDeviationGrouped3 +required_capability: std_deviation FROM employees | WHERE languages == 3 | STATS STD_DEVIATION(height) @@ -2787,6 +2796,7 @@ STD_DEVIATION(height):double ; stdDeviationGrouped4 +required_capability: std_deviation FROM employees | WHERE languages == 4 | STATS STD_DEVIATION(height) @@ -2797,6 +2807,7 @@ STD_DEVIATION(height):double ; stdDeviationGrouped5 +required_capability: std_deviation FROM employees | WHERE languages == 5 | STATS STD_DEVIATION(height) @@ -2807,6 +2818,7 @@ STD_DEVIATION(height):double ; stdDeviationNoRows +required_capability: std_deviation FROM employees | WHERE languages IS null | STATS STD_DEVIATION(languages) From 55cdef107f7634d0e5e498d0bdb18a8e5b11114f Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Tue, 12 Nov 2024 13:32:13 -1000 Subject: [PATCH 05/20] move states from template to individual classes --- .../StdDeviationDoubleAggregator.java | 175 ++--------------- .../StdDeviationFloatAggregator.java | 175 ++--------------- .../StdDeviationIntAggregator.java | 175 ++--------------- .../StdDeviationLongAggregator.java | 175 ++--------------- .../StdDeviationDoubleAggregatorFunction.java | 4 +- ...ationDoubleGroupingAggregatorFunction.java | 7 +- .../StdDeviationFloatAggregatorFunction.java | 4 +- ...iationFloatGroupingAggregatorFunction.java | 7 +- .../StdDeviationIntAggregatorFunction.java | 4 +- ...eviationIntGroupingAggregatorFunction.java | 6 +- .../StdDeviationLongAggregatorFunction.java | 4 +- ...viationLongGroupingAggregatorFunction.java | 6 +- .../aggregation/StdDeviationStates.java | 180 ++++++++++++++++++ .../X-StdDeviationAggregator.java.st | 175 ++--------------- 14 files changed, 300 insertions(+), 797 deletions(-) create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDeviationStates.java diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java index 4b0d1e7d79881..be1408915195d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java @@ -8,16 +8,13 @@ package org.elasticsearch.compute.aggregation; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasables; /** * A standard deviation aggregation definition for double. @@ -32,27 +29,19 @@ @GroupingAggregator public class StdDeviationDoubleAggregator { - public static StdDeviationDoubleState initSingle() { - return new StdDeviationDoubleState(); + public static StdDeviationStates.StdDeviationState initSingle() { + return new StdDeviationStates.StdDeviationState(); } - public static void combine(StdDeviationDoubleState state, double value) { + public static void combine(StdDeviationStates.StdDeviationState state, double value) { state.add(value); } - public static void combineIntermediate(StdDeviationDoubleState state, double mean, double m2, long count) { + public static void combineIntermediate(StdDeviationStates.StdDeviationState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static void evaluateIntermediate(StdDeviationDoubleState state, DriverContext driverContext, Block[] blocks, int offset) { - assert blocks.length >= offset + 3; - BlockFactory blockFactory = driverContext.blockFactory(); - blocks[offset + 0] = blockFactory.newConstantDoubleBlockWith(state.mean(), 1); - blocks[offset + 1] = blockFactory.newConstantDoubleBlockWith(state.m2(), 1); - blocks[offset + 2] = blockFactory.newConstantLongBlockWith(state.count(), 1); - } - - public static Block evaluateFinal(StdDeviationDoubleState state, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.StdDeviationState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -61,67 +50,38 @@ public static Block evaluateFinal(StdDeviationDoubleState state, DriverContext d return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static GroupingStdDeviationDoubleState initGrouping(BigArrays bigArrays) { - return new GroupingStdDeviationDoubleState(bigArrays); + public static StdDeviationStates.GroupingStdDeviationState initGrouping(BigArrays bigArrays) { + return new StdDeviationStates.GroupingStdDeviationState(bigArrays); } - public static void combine(GroupingStdDeviationDoubleState current, int groupId, double value) { + public static void combine(StdDeviationStates.GroupingStdDeviationState current, int groupId, double value) { current.add(groupId, value); } public static void combineStates( - GroupingStdDeviationDoubleState current, + StdDeviationStates.GroupingStdDeviationState current, int groupId, - GroupingStdDeviationDoubleState state, + StdDeviationStates.GroupingStdDeviationState state, int statePosition ) { - var st = state.states.get(statePosition); - if (st != null) { - current.combine(groupId, st.mean(), st.m2(), st.count()); - } + current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate(GroupingStdDeviationDoubleState state, int groupId, double mean, double m2, long count) { - state.combine(groupId, mean, m2, count); - } - - public static void evaluateIntermediate( - GroupingStdDeviationDoubleState state, - Block[] blocks, - int offset, - IntVector selected, - DriverContext driverContext + public static void combineIntermediate( + StdDeviationStates.GroupingStdDeviationState state, + int groupId, + double mean, + double m2, + long count ) { - assert blocks.length >= offset + 3 : "blocks=" + blocks.length + ",offset=" + offset; - try ( - var meanBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); - var m2Builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); - var countBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); - ) { - for (int i = 0; i < selected.getPositionCount(); i++) { - final var groupId = selected.getInt(i); - final var st = groupId < state.states.size() ? state.states.get(groupId) : null; - if (st != null) { - meanBuilder.appendDouble(st.mean()); - m2Builder.appendDouble(st.m2()); - countBuilder.appendLong(st.count()); - } else { - meanBuilder.appendNull(); - m2Builder.appendNull(); - countBuilder.appendNull(); - } - } - blocks[offset + 0] = meanBuilder.build(); - blocks[offset + 1] = m2Builder.build(); - blocks[offset + 2] = countBuilder.build(); - } + state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(GroupingStdDeviationDoubleState state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.GroupingStdDeviationState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); - final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + final var st = state.getOrNull(groupId); if (st != null) { final var m2 = st.m2(); if (Double.isFinite(m2) == false) { @@ -136,99 +96,4 @@ public static Block evaluateFinal(GroupingStdDeviationDoubleState state, IntVect return builder.build(); } } - - static final class StdDeviationDoubleState implements AggregatorState { - - private WelfordAlgorithm welfordAlgorithm; - - StdDeviationDoubleState() { - this(0, 0, 0); - } - - StdDeviationDoubleState(double mean, double m2, long count) { - this.welfordAlgorithm = new WelfordAlgorithm(mean, m2, count); - } - - public void add(double value) { - welfordAlgorithm.add(value); - } - - public void combine(double mean, double m2, long count) { - welfordAlgorithm.add(mean, m2, count); - } - - @Override - public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { - StdDeviationDoubleAggregator.evaluateIntermediate(this, driverContext, blocks, offset); - } - - @Override - public void close() {} - - public double mean() { - return welfordAlgorithm.mean(); - } - - public double m2() { - return welfordAlgorithm.m2(); - } - - public long count() { - return welfordAlgorithm.count(); - } - - public double evaluateFinal() { - return welfordAlgorithm.evaluate(); - } - } - - static final class GroupingStdDeviationDoubleState implements GroupingAggregatorState { - - private ObjectArray states; - private final BigArrays bigArrays; - - GroupingStdDeviationDoubleState(BigArrays bigArrays) { - this.states = bigArrays.newObjectArray(1); - this.bigArrays = bigArrays; - } - - public void combine(int groupId, double meanValue, double m2Value, long countValue) { - ensureCapacity(groupId); - var state = states.get(groupId); - if (state == null) { - state = new StdDeviationDoubleState(meanValue, m2Value, countValue); - states.set(groupId, state); - } else { - state.combine(meanValue, m2Value, countValue); - } - } - - public void add(int groupId, double value) { - ensureCapacity(groupId); - var state = states.get(groupId); - if (state == null) { - state = new StdDeviationDoubleState(); - states.set(groupId, state); - } - state.add(value); - } - - private void ensureCapacity(int groupId) { - states = bigArrays.grow(states, groupId + 1); - } - - @Override - public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { - StdDeviationDoubleAggregator.evaluateIntermediate(this, blocks, offset, selected, driverContext); - } - - @Override - public void close() { - Releasables.close(states); - } - - void enableGroupIdTracking(SeenGroupIds seenGroupIds) { - // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` - } - } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java index de96ce9d5c1ea..84d58a61c5da7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java @@ -8,16 +8,13 @@ package org.elasticsearch.compute.aggregation; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasables; /** * A standard deviation aggregation definition for float. @@ -32,27 +29,19 @@ @GroupingAggregator public class StdDeviationFloatAggregator { - public static StdDeviationFloatState initSingle() { - return new StdDeviationFloatState(); + public static StdDeviationStates.StdDeviationState initSingle() { + return new StdDeviationStates.StdDeviationState(); } - public static void combine(StdDeviationFloatState state, float value) { + public static void combine(StdDeviationStates.StdDeviationState state, float value) { state.add(value); } - public static void combineIntermediate(StdDeviationFloatState state, double mean, double m2, long count) { + public static void combineIntermediate(StdDeviationStates.StdDeviationState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static void evaluateIntermediate(StdDeviationFloatState state, DriverContext driverContext, Block[] blocks, int offset) { - assert blocks.length >= offset + 3; - BlockFactory blockFactory = driverContext.blockFactory(); - blocks[offset + 0] = blockFactory.newConstantDoubleBlockWith(state.mean(), 1); - blocks[offset + 1] = blockFactory.newConstantDoubleBlockWith(state.m2(), 1); - blocks[offset + 2] = blockFactory.newConstantLongBlockWith(state.count(), 1); - } - - public static Block evaluateFinal(StdDeviationFloatState state, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.StdDeviationState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -61,67 +50,38 @@ public static Block evaluateFinal(StdDeviationFloatState state, DriverContext dr return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static GroupingStdDeviationFloatState initGrouping(BigArrays bigArrays) { - return new GroupingStdDeviationFloatState(bigArrays); + public static StdDeviationStates.GroupingStdDeviationState initGrouping(BigArrays bigArrays) { + return new StdDeviationStates.GroupingStdDeviationState(bigArrays); } - public static void combine(GroupingStdDeviationFloatState current, int groupId, float value) { + public static void combine(StdDeviationStates.GroupingStdDeviationState current, int groupId, float value) { current.add(groupId, value); } public static void combineStates( - GroupingStdDeviationFloatState current, + StdDeviationStates.GroupingStdDeviationState current, int groupId, - GroupingStdDeviationFloatState state, + StdDeviationStates.GroupingStdDeviationState state, int statePosition ) { - var st = state.states.get(statePosition); - if (st != null) { - current.combine(groupId, st.mean(), st.m2(), st.count()); - } + current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate(GroupingStdDeviationFloatState state, int groupId, double mean, double m2, long count) { - state.combine(groupId, mean, m2, count); - } - - public static void evaluateIntermediate( - GroupingStdDeviationFloatState state, - Block[] blocks, - int offset, - IntVector selected, - DriverContext driverContext + public static void combineIntermediate( + StdDeviationStates.GroupingStdDeviationState state, + int groupId, + double mean, + double m2, + long count ) { - assert blocks.length >= offset + 3 : "blocks=" + blocks.length + ",offset=" + offset; - try ( - var meanBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); - var m2Builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); - var countBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); - ) { - for (int i = 0; i < selected.getPositionCount(); i++) { - final var groupId = selected.getInt(i); - final var st = groupId < state.states.size() ? state.states.get(groupId) : null; - if (st != null) { - meanBuilder.appendDouble(st.mean()); - m2Builder.appendDouble(st.m2()); - countBuilder.appendLong(st.count()); - } else { - meanBuilder.appendNull(); - m2Builder.appendNull(); - countBuilder.appendNull(); - } - } - blocks[offset + 0] = meanBuilder.build(); - blocks[offset + 1] = m2Builder.build(); - blocks[offset + 2] = countBuilder.build(); - } + state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(GroupingStdDeviationFloatState state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.GroupingStdDeviationState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); - final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + final var st = state.getOrNull(groupId); if (st != null) { final var m2 = st.m2(); if (Double.isFinite(m2) == false) { @@ -136,99 +96,4 @@ public static Block evaluateFinal(GroupingStdDeviationFloatState state, IntVecto return builder.build(); } } - - static final class StdDeviationFloatState implements AggregatorState { - - private WelfordAlgorithm welfordAlgorithm; - - StdDeviationFloatState() { - this(0, 0, 0); - } - - StdDeviationFloatState(double mean, double m2, long count) { - this.welfordAlgorithm = new WelfordAlgorithm(mean, m2, count); - } - - public void add(float value) { - welfordAlgorithm.add(value); - } - - public void combine(double mean, double m2, long count) { - welfordAlgorithm.add(mean, m2, count); - } - - @Override - public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { - StdDeviationFloatAggregator.evaluateIntermediate(this, driverContext, blocks, offset); - } - - @Override - public void close() {} - - public double mean() { - return welfordAlgorithm.mean(); - } - - public double m2() { - return welfordAlgorithm.m2(); - } - - public long count() { - return welfordAlgorithm.count(); - } - - public double evaluateFinal() { - return welfordAlgorithm.evaluate(); - } - } - - static final class GroupingStdDeviationFloatState implements GroupingAggregatorState { - - private ObjectArray states; - private final BigArrays bigArrays; - - GroupingStdDeviationFloatState(BigArrays bigArrays) { - this.states = bigArrays.newObjectArray(1); - this.bigArrays = bigArrays; - } - - public void combine(int groupId, double meanValue, double m2Value, long countValue) { - ensureCapacity(groupId); - var state = states.get(groupId); - if (state == null) { - state = new StdDeviationFloatState(meanValue, m2Value, countValue); - states.set(groupId, state); - } else { - state.combine(meanValue, m2Value, countValue); - } - } - - public void add(int groupId, float value) { - ensureCapacity(groupId); - var state = states.get(groupId); - if (state == null) { - state = new StdDeviationFloatState(); - states.set(groupId, state); - } - state.add(value); - } - - private void ensureCapacity(int groupId) { - states = bigArrays.grow(states, groupId + 1); - } - - @Override - public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { - StdDeviationFloatAggregator.evaluateIntermediate(this, blocks, offset, selected, driverContext); - } - - @Override - public void close() { - Releasables.close(states); - } - - void enableGroupIdTracking(SeenGroupIds seenGroupIds) { - // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` - } - } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java index 0532d89675c9d..dc1dc0c716684 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java @@ -8,16 +8,13 @@ package org.elasticsearch.compute.aggregation; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasables; /** * A standard deviation aggregation definition for int. @@ -32,27 +29,19 @@ @GroupingAggregator public class StdDeviationIntAggregator { - public static StdDeviationIntState initSingle() { - return new StdDeviationIntState(); + public static StdDeviationStates.StdDeviationState initSingle() { + return new StdDeviationStates.StdDeviationState(); } - public static void combine(StdDeviationIntState state, int value) { + public static void combine(StdDeviationStates.StdDeviationState state, int value) { state.add(value); } - public static void combineIntermediate(StdDeviationIntState state, double mean, double m2, long count) { + public static void combineIntermediate(StdDeviationStates.StdDeviationState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static void evaluateIntermediate(StdDeviationIntState state, DriverContext driverContext, Block[] blocks, int offset) { - assert blocks.length >= offset + 3; - BlockFactory blockFactory = driverContext.blockFactory(); - blocks[offset + 0] = blockFactory.newConstantDoubleBlockWith(state.mean(), 1); - blocks[offset + 1] = blockFactory.newConstantDoubleBlockWith(state.m2(), 1); - blocks[offset + 2] = blockFactory.newConstantLongBlockWith(state.count(), 1); - } - - public static Block evaluateFinal(StdDeviationIntState state, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.StdDeviationState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -61,67 +50,38 @@ public static Block evaluateFinal(StdDeviationIntState state, DriverContext driv return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static GroupingStdDeviationIntState initGrouping(BigArrays bigArrays) { - return new GroupingStdDeviationIntState(bigArrays); + public static StdDeviationStates.GroupingStdDeviationState initGrouping(BigArrays bigArrays) { + return new StdDeviationStates.GroupingStdDeviationState(bigArrays); } - public static void combine(GroupingStdDeviationIntState current, int groupId, int value) { + public static void combine(StdDeviationStates.GroupingStdDeviationState current, int groupId, int value) { current.add(groupId, value); } public static void combineStates( - GroupingStdDeviationIntState current, + StdDeviationStates.GroupingStdDeviationState current, int groupId, - GroupingStdDeviationIntState state, + StdDeviationStates.GroupingStdDeviationState state, int statePosition ) { - var st = state.states.get(statePosition); - if (st != null) { - current.combine(groupId, st.mean(), st.m2(), st.count()); - } + current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate(GroupingStdDeviationIntState state, int groupId, double mean, double m2, long count) { - state.combine(groupId, mean, m2, count); - } - - public static void evaluateIntermediate( - GroupingStdDeviationIntState state, - Block[] blocks, - int offset, - IntVector selected, - DriverContext driverContext + public static void combineIntermediate( + StdDeviationStates.GroupingStdDeviationState state, + int groupId, + double mean, + double m2, + long count ) { - assert blocks.length >= offset + 3 : "blocks=" + blocks.length + ",offset=" + offset; - try ( - var meanBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); - var m2Builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); - var countBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); - ) { - for (int i = 0; i < selected.getPositionCount(); i++) { - final var groupId = selected.getInt(i); - final var st = groupId < state.states.size() ? state.states.get(groupId) : null; - if (st != null) { - meanBuilder.appendDouble(st.mean()); - m2Builder.appendDouble(st.m2()); - countBuilder.appendLong(st.count()); - } else { - meanBuilder.appendNull(); - m2Builder.appendNull(); - countBuilder.appendNull(); - } - } - blocks[offset + 0] = meanBuilder.build(); - blocks[offset + 1] = m2Builder.build(); - blocks[offset + 2] = countBuilder.build(); - } + state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(GroupingStdDeviationIntState state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.GroupingStdDeviationState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); - final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + final var st = state.getOrNull(groupId); if (st != null) { final var m2 = st.m2(); if (Double.isFinite(m2) == false) { @@ -136,99 +96,4 @@ public static Block evaluateFinal(GroupingStdDeviationIntState state, IntVector return builder.build(); } } - - static final class StdDeviationIntState implements AggregatorState { - - private WelfordAlgorithm welfordAlgorithm; - - StdDeviationIntState() { - this(0, 0, 0); - } - - StdDeviationIntState(double mean, double m2, long count) { - this.welfordAlgorithm = new WelfordAlgorithm(mean, m2, count); - } - - public void add(int value) { - welfordAlgorithm.add(value); - } - - public void combine(double mean, double m2, long count) { - welfordAlgorithm.add(mean, m2, count); - } - - @Override - public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { - StdDeviationIntAggregator.evaluateIntermediate(this, driverContext, blocks, offset); - } - - @Override - public void close() {} - - public double mean() { - return welfordAlgorithm.mean(); - } - - public double m2() { - return welfordAlgorithm.m2(); - } - - public long count() { - return welfordAlgorithm.count(); - } - - public double evaluateFinal() { - return welfordAlgorithm.evaluate(); - } - } - - static final class GroupingStdDeviationIntState implements GroupingAggregatorState { - - private ObjectArray states; - private final BigArrays bigArrays; - - GroupingStdDeviationIntState(BigArrays bigArrays) { - this.states = bigArrays.newObjectArray(1); - this.bigArrays = bigArrays; - } - - public void combine(int groupId, double meanValue, double m2Value, long countValue) { - ensureCapacity(groupId); - var state = states.get(groupId); - if (state == null) { - state = new StdDeviationIntState(meanValue, m2Value, countValue); - states.set(groupId, state); - } else { - state.combine(meanValue, m2Value, countValue); - } - } - - public void add(int groupId, int value) { - ensureCapacity(groupId); - var state = states.get(groupId); - if (state == null) { - state = new StdDeviationIntState(); - states.set(groupId, state); - } - state.add(value); - } - - private void ensureCapacity(int groupId) { - states = bigArrays.grow(states, groupId + 1); - } - - @Override - public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { - StdDeviationIntAggregator.evaluateIntermediate(this, blocks, offset, selected, driverContext); - } - - @Override - public void close() { - Releasables.close(states); - } - - void enableGroupIdTracking(SeenGroupIds seenGroupIds) { - // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` - } - } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java index e13131104d82b..aa9d4dfd43983 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java @@ -8,16 +8,13 @@ package org.elasticsearch.compute.aggregation; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasables; /** * A standard deviation aggregation definition for long. @@ -32,27 +29,19 @@ @GroupingAggregator public class StdDeviationLongAggregator { - public static StdDeviationLongState initSingle() { - return new StdDeviationLongState(); + public static StdDeviationStates.StdDeviationState initSingle() { + return new StdDeviationStates.StdDeviationState(); } - public static void combine(StdDeviationLongState state, long value) { + public static void combine(StdDeviationStates.StdDeviationState state, long value) { state.add(value); } - public static void combineIntermediate(StdDeviationLongState state, double mean, double m2, long count) { + public static void combineIntermediate(StdDeviationStates.StdDeviationState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static void evaluateIntermediate(StdDeviationLongState state, DriverContext driverContext, Block[] blocks, int offset) { - assert blocks.length >= offset + 3; - BlockFactory blockFactory = driverContext.blockFactory(); - blocks[offset + 0] = blockFactory.newConstantDoubleBlockWith(state.mean(), 1); - blocks[offset + 1] = blockFactory.newConstantDoubleBlockWith(state.m2(), 1); - blocks[offset + 2] = blockFactory.newConstantLongBlockWith(state.count(), 1); - } - - public static Block evaluateFinal(StdDeviationLongState state, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.StdDeviationState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -61,67 +50,38 @@ public static Block evaluateFinal(StdDeviationLongState state, DriverContext dri return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static GroupingStdDeviationLongState initGrouping(BigArrays bigArrays) { - return new GroupingStdDeviationLongState(bigArrays); + public static StdDeviationStates.GroupingStdDeviationState initGrouping(BigArrays bigArrays) { + return new StdDeviationStates.GroupingStdDeviationState(bigArrays); } - public static void combine(GroupingStdDeviationLongState current, int groupId, long value) { + public static void combine(StdDeviationStates.GroupingStdDeviationState current, int groupId, long value) { current.add(groupId, value); } public static void combineStates( - GroupingStdDeviationLongState current, + StdDeviationStates.GroupingStdDeviationState current, int groupId, - GroupingStdDeviationLongState state, + StdDeviationStates.GroupingStdDeviationState state, int statePosition ) { - var st = state.states.get(statePosition); - if (st != null) { - current.combine(groupId, st.mean(), st.m2(), st.count()); - } + current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate(GroupingStdDeviationLongState state, int groupId, double mean, double m2, long count) { - state.combine(groupId, mean, m2, count); - } - - public static void evaluateIntermediate( - GroupingStdDeviationLongState state, - Block[] blocks, - int offset, - IntVector selected, - DriverContext driverContext + public static void combineIntermediate( + StdDeviationStates.GroupingStdDeviationState state, + int groupId, + double mean, + double m2, + long count ) { - assert blocks.length >= offset + 3 : "blocks=" + blocks.length + ",offset=" + offset; - try ( - var meanBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); - var m2Builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); - var countBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); - ) { - for (int i = 0; i < selected.getPositionCount(); i++) { - final var groupId = selected.getInt(i); - final var st = groupId < state.states.size() ? state.states.get(groupId) : null; - if (st != null) { - meanBuilder.appendDouble(st.mean()); - m2Builder.appendDouble(st.m2()); - countBuilder.appendLong(st.count()); - } else { - meanBuilder.appendNull(); - m2Builder.appendNull(); - countBuilder.appendNull(); - } - } - blocks[offset + 0] = meanBuilder.build(); - blocks[offset + 1] = m2Builder.build(); - blocks[offset + 2] = countBuilder.build(); - } + state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(GroupingStdDeviationLongState state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.GroupingStdDeviationState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); - final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + final var st = state.getOrNull(groupId); if (st != null) { final var m2 = st.m2(); if (Double.isFinite(m2) == false) { @@ -136,99 +96,4 @@ public static Block evaluateFinal(GroupingStdDeviationLongState state, IntVector return builder.build(); } } - - static final class StdDeviationLongState implements AggregatorState { - - private WelfordAlgorithm welfordAlgorithm; - - StdDeviationLongState() { - this(0, 0, 0); - } - - StdDeviationLongState(double mean, double m2, long count) { - this.welfordAlgorithm = new WelfordAlgorithm(mean, m2, count); - } - - public void add(long value) { - welfordAlgorithm.add(value); - } - - public void combine(double mean, double m2, long count) { - welfordAlgorithm.add(mean, m2, count); - } - - @Override - public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { - StdDeviationLongAggregator.evaluateIntermediate(this, driverContext, blocks, offset); - } - - @Override - public void close() {} - - public double mean() { - return welfordAlgorithm.mean(); - } - - public double m2() { - return welfordAlgorithm.m2(); - } - - public long count() { - return welfordAlgorithm.count(); - } - - public double evaluateFinal() { - return welfordAlgorithm.evaluate(); - } - } - - static final class GroupingStdDeviationLongState implements GroupingAggregatorState { - - private ObjectArray states; - private final BigArrays bigArrays; - - GroupingStdDeviationLongState(BigArrays bigArrays) { - this.states = bigArrays.newObjectArray(1); - this.bigArrays = bigArrays; - } - - public void combine(int groupId, double meanValue, double m2Value, long countValue) { - ensureCapacity(groupId); - var state = states.get(groupId); - if (state == null) { - state = new StdDeviationLongState(meanValue, m2Value, countValue); - states.set(groupId, state); - } else { - state.combine(meanValue, m2Value, countValue); - } - } - - public void add(int groupId, long value) { - ensureCapacity(groupId); - var state = states.get(groupId); - if (state == null) { - state = new StdDeviationLongState(); - states.set(groupId, state); - } - state.add(value); - } - - private void ensureCapacity(int groupId) { - states = bigArrays.grow(states, groupId + 1); - } - - @Override - public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { - StdDeviationLongAggregator.evaluateIntermediate(this, blocks, offset, selected, driverContext); - } - - @Override - public void close() { - Releasables.close(states); - } - - void enableGroupIdTracking(SeenGroupIds seenGroupIds) { - // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` - } - } } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java index c479ac9234b8b..324dc47d2be8e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java @@ -31,12 +31,12 @@ public final class StdDeviationDoubleAggregatorFunction implements AggregatorFun private final DriverContext driverContext; - private final StdDeviationDoubleAggregator.StdDeviationDoubleState state; + private final StdDeviationStates.StdDeviationState state; private final List channels; public StdDeviationDoubleAggregatorFunction(DriverContext driverContext, List channels, - StdDeviationDoubleAggregator.StdDeviationDoubleState state) { + StdDeviationStates.StdDeviationState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java index f9ad9ff3eb5db..6c77b6a3b19bb 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java @@ -30,15 +30,14 @@ public final class StdDeviationDoubleGroupingAggregatorFunction implements Group new IntermediateStateDesc("m2", ElementType.DOUBLE), new IntermediateStateDesc("count", ElementType.LONG) ); - private final StdDeviationDoubleAggregator.GroupingStdDeviationDoubleState state; + private final StdDeviationStates.GroupingStdDeviationState state; private final List channels; private final DriverContext driverContext; public StdDeviationDoubleGroupingAggregatorFunction(List channels, - StdDeviationDoubleAggregator.GroupingStdDeviationDoubleState state, - DriverContext driverContext) { + StdDeviationStates.GroupingStdDeviationState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; @@ -192,7 +191,7 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationDoubleAggregator.GroupingStdDeviationDoubleState inState = ((StdDeviationDoubleGroupingAggregatorFunction) input).state; + StdDeviationStates.GroupingStdDeviationState inState = ((StdDeviationDoubleGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); StdDeviationDoubleAggregator.combineStates(state, groupId, inState, position); } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java index 82f2edcf5b6da..17d19f67b91ca 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java @@ -33,12 +33,12 @@ public final class StdDeviationFloatAggregatorFunction implements AggregatorFunc private final DriverContext driverContext; - private final StdDeviationFloatAggregator.StdDeviationFloatState state; + private final StdDeviationStates.StdDeviationState state; private final List channels; public StdDeviationFloatAggregatorFunction(DriverContext driverContext, List channels, - StdDeviationFloatAggregator.StdDeviationFloatState state) { + StdDeviationStates.StdDeviationState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java index 82f568283378c..521564fa3942f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java @@ -32,15 +32,14 @@ public final class StdDeviationFloatGroupingAggregatorFunction implements Groupi new IntermediateStateDesc("m2", ElementType.DOUBLE), new IntermediateStateDesc("count", ElementType.LONG) ); - private final StdDeviationFloatAggregator.GroupingStdDeviationFloatState state; + private final StdDeviationStates.GroupingStdDeviationState state; private final List channels; private final DriverContext driverContext; public StdDeviationFloatGroupingAggregatorFunction(List channels, - StdDeviationFloatAggregator.GroupingStdDeviationFloatState state, - DriverContext driverContext) { + StdDeviationStates.GroupingStdDeviationState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; @@ -194,7 +193,7 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationFloatAggregator.GroupingStdDeviationFloatState inState = ((StdDeviationFloatGroupingAggregatorFunction) input).state; + StdDeviationStates.GroupingStdDeviationState inState = ((StdDeviationFloatGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); StdDeviationFloatAggregator.combineStates(state, groupId, inState, position); } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java index 90b1e163f4668..f941edd3f011e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java @@ -33,12 +33,12 @@ public final class StdDeviationIntAggregatorFunction implements AggregatorFuncti private final DriverContext driverContext; - private final StdDeviationIntAggregator.StdDeviationIntState state; + private final StdDeviationStates.StdDeviationState state; private final List channels; public StdDeviationIntAggregatorFunction(DriverContext driverContext, List channels, - StdDeviationIntAggregator.StdDeviationIntState state) { + StdDeviationStates.StdDeviationState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java index efac6e27f7ea5..53938feb15ff1 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java @@ -30,14 +30,14 @@ public final class StdDeviationIntGroupingAggregatorFunction implements Grouping new IntermediateStateDesc("m2", ElementType.DOUBLE), new IntermediateStateDesc("count", ElementType.LONG) ); - private final StdDeviationIntAggregator.GroupingStdDeviationIntState state; + private final StdDeviationStates.GroupingStdDeviationState state; private final List channels; private final DriverContext driverContext; public StdDeviationIntGroupingAggregatorFunction(List channels, - StdDeviationIntAggregator.GroupingStdDeviationIntState state, DriverContext driverContext) { + StdDeviationStates.GroupingStdDeviationState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; @@ -191,7 +191,7 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationIntAggregator.GroupingStdDeviationIntState inState = ((StdDeviationIntGroupingAggregatorFunction) input).state; + StdDeviationStates.GroupingStdDeviationState inState = ((StdDeviationIntGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); StdDeviationIntAggregator.combineStates(state, groupId, inState, position); } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java index fc801e35403b3..f7b59522db89d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java @@ -31,12 +31,12 @@ public final class StdDeviationLongAggregatorFunction implements AggregatorFunct private final DriverContext driverContext; - private final StdDeviationLongAggregator.StdDeviationLongState state; + private final StdDeviationStates.StdDeviationState state; private final List channels; public StdDeviationLongAggregatorFunction(DriverContext driverContext, List channels, - StdDeviationLongAggregator.StdDeviationLongState state) { + StdDeviationStates.StdDeviationState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java index 4154b70f43277..19700a071028b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java @@ -30,14 +30,14 @@ public final class StdDeviationLongGroupingAggregatorFunction implements Groupin new IntermediateStateDesc("m2", ElementType.DOUBLE), new IntermediateStateDesc("count", ElementType.LONG) ); - private final StdDeviationLongAggregator.GroupingStdDeviationLongState state; + private final StdDeviationStates.GroupingStdDeviationState state; private final List channels; private final DriverContext driverContext; public StdDeviationLongGroupingAggregatorFunction(List channels, - StdDeviationLongAggregator.GroupingStdDeviationLongState state, DriverContext driverContext) { + StdDeviationStates.GroupingStdDeviationState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; @@ -191,7 +191,7 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationLongAggregator.GroupingStdDeviationLongState inState = ((StdDeviationLongGroupingAggregatorFunction) input).state; + StdDeviationStates.GroupingStdDeviationState inState = ((StdDeviationLongGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); StdDeviationLongAggregator.combineStates(state, groupId, inState, position); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDeviationStates.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDeviationStates.java new file mode 100644 index 0000000000000..3a9ae6f1a7a58 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDeviationStates.java @@ -0,0 +1,180 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +public final class StdDeviationStates { + + private StdDeviationStates() {} + + static final class StdDeviationState implements AggregatorState { + + private WelfordAlgorithm welfordAlgorithm; + + StdDeviationState() { + this(0, 0, 0); + } + + StdDeviationState(double mean, double m2, long count) { + this.welfordAlgorithm = new WelfordAlgorithm(mean, m2, count); + } + + public void add(long value) { + welfordAlgorithm.add(value); + } + + public void add(double value) { + welfordAlgorithm.add(value); + } + + public void add(int value) { + welfordAlgorithm.add(value); + } + + public void combine(double mean, double m2, long count) { + welfordAlgorithm.add(mean, m2, count); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + assert blocks.length >= offset + 3; + BlockFactory blockFactory = driverContext.blockFactory(); + blocks[offset + 0] = blockFactory.newConstantDoubleBlockWith(mean(), 1); + blocks[offset + 1] = blockFactory.newConstantDoubleBlockWith(m2(), 1); + blocks[offset + 2] = blockFactory.newConstantLongBlockWith(count(), 1); + } + + @Override + public void close() {} + + public double mean() { + return welfordAlgorithm.mean(); + } + + public double m2() { + return welfordAlgorithm.m2(); + } + + public long count() { + return welfordAlgorithm.count(); + } + + public double evaluateFinal() { + return welfordAlgorithm.evaluate(); + } + } + + static final class GroupingStdDeviationState implements GroupingAggregatorState { + + private ObjectArray states; + private final BigArrays bigArrays; + + GroupingStdDeviationState(BigArrays bigArrays) { + this.states = bigArrays.newObjectArray(1); + this.bigArrays = bigArrays; + } + + StdDeviationState getOrNull(int position) { + if (position < states.size()) { + return states.get(position); + } else { + return null; + } + } + + public void combine(int groupId, StdDeviationState state) { + if (state == null) { + return; + } + combine(groupId, state.mean(), state.m2(), state.count()); + } + + public void combine(int groupId, double meanValue, double m2Value, long countValue) { + ensureCapacity(groupId); + var state = states.get(groupId); + if (state == null) { + state = new StdDeviationState(meanValue, m2Value, countValue); + states.set(groupId, state); + } else { + state.combine(meanValue, m2Value, countValue); + } + } + + public StdDeviationState getOrSet(int groupId) { + ensureCapacity(groupId); + var state = states.get(groupId); + if (state == null) { + state = new StdDeviationState(); + states.set(groupId, state); + } + return state; + } + + public void add(int groupId, long value) { + var state = getOrSet(groupId); + state.add(value); + } + + public void add(int groupId, double value) { + var state = getOrSet(groupId); + state.add(value); + } + + public void add(int groupId, int value) { + var state = getOrSet(groupId); + state.add(value); + } + + private void ensureCapacity(int groupId) { + states = bigArrays.grow(states, groupId + 1); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + assert blocks.length >= offset + 3 : "blocks=" + blocks.length + ",offset=" + offset; + try ( + var meanBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var m2Builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + var countBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + final var groupId = selected.getInt(i); + final var state = groupId < states.size() ? states.get(groupId) : null; + if (state != null) { + meanBuilder.appendDouble(state.mean()); + m2Builder.appendDouble(state.m2()); + countBuilder.appendLong(state.count()); + } else { + meanBuilder.appendNull(); + m2Builder.appendNull(); + countBuilder.appendNull(); + } + } + blocks[offset + 0] = meanBuilder.build(); + blocks[offset + 1] = m2Builder.build(); + blocks[offset + 2] = countBuilder.build(); + } + } + + @Override + public void close() { + Releasables.close(states); + } + + void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st index ff9d081358496..cd14799d0ef9f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st @@ -8,16 +8,13 @@ package org.elasticsearch.compute.aggregation; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasables; /** * A standard deviation aggregation definition for $type$. @@ -32,27 +29,19 @@ import org.elasticsearch.core.Releasables; @GroupingAggregator public class StdDeviation$Type$Aggregator { - public static StdDeviation$Type$State initSingle() { - return new StdDeviation$Type$State(); + public static StdDeviationStates.StdDeviationState initSingle() { + return new StdDeviationStates.StdDeviationState(); } - public static void combine(StdDeviation$Type$State state, $type$ value) { + public static void combine(StdDeviationStates.StdDeviationState state, $type$ value) { state.add(value); } - public static void combineIntermediate(StdDeviation$Type$State state, double mean, double m2, long count) { + public static void combineIntermediate(StdDeviationStates.StdDeviationState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static void evaluateIntermediate(StdDeviation$Type$State state, DriverContext driverContext, Block[] blocks, int offset) { - assert blocks.length >= offset + 3; - BlockFactory blockFactory = driverContext.blockFactory(); - blocks[offset + 0] = blockFactory.newConstantDoubleBlockWith(state.mean(), 1); - blocks[offset + 1] = blockFactory.newConstantDoubleBlockWith(state.m2(), 1); - blocks[offset + 2] = blockFactory.newConstantLongBlockWith(state.count(), 1); - } - - public static Block evaluateFinal(StdDeviation$Type$State state, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.StdDeviationState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -61,67 +50,38 @@ public class StdDeviation$Type$Aggregator { return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static GroupingStdDeviation$Type$State initGrouping(BigArrays bigArrays) { - return new GroupingStdDeviation$Type$State(bigArrays); + public static StdDeviationStates.GroupingStdDeviationState initGrouping(BigArrays bigArrays) { + return new StdDeviationStates.GroupingStdDeviationState(bigArrays); } - public static void combine(GroupingStdDeviation$Type$State current, int groupId, $type$ value) { + public static void combine(StdDeviationStates.GroupingStdDeviationState current, int groupId, $type$ value) { current.add(groupId, value); } public static void combineStates( - GroupingStdDeviation$Type$State current, + StdDeviationStates.GroupingStdDeviationState current, int groupId, - GroupingStdDeviation$Type$State state, + StdDeviationStates.GroupingStdDeviationState state, int statePosition ) { - var st = state.states.get(statePosition); - if (st != null) { - current.combine(groupId, st.mean(), st.m2(), st.count()); - } + current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate(GroupingStdDeviation$Type$State state, int groupId, double mean, double m2, long count) { - state.combine(groupId, mean, m2, count); - } - - public static void evaluateIntermediate( - GroupingStdDeviation$Type$State state, - Block[] blocks, - int offset, - IntVector selected, - DriverContext driverContext + public static void combineIntermediate( + StdDeviationStates.GroupingStdDeviationState state, + int groupId, + double mean, + double m2, + long count ) { - assert blocks.length >= offset + 3 : "blocks=" + blocks.length + ",offset=" + offset; - try ( - var meanBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); - var m2Builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); - var countBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); - ) { - for (int i = 0; i < selected.getPositionCount(); i++) { - final var groupId = selected.getInt(i); - final var st = groupId < state.states.size() ? state.states.get(groupId) : null; - if (st != null) { - meanBuilder.appendDouble(st.mean()); - m2Builder.appendDouble(st.m2()); - countBuilder.appendLong(st.count()); - } else { - meanBuilder.appendNull(); - m2Builder.appendNull(); - countBuilder.appendNull(); - } - } - blocks[offset + 0] = meanBuilder.build(); - blocks[offset + 1] = m2Builder.build(); - blocks[offset + 2] = countBuilder.build(); - } + state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(GroupingStdDeviation$Type$State state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.GroupingStdDeviationState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); - final var st = groupId < state.states.size() ? state.states.get(groupId) : null; + final var st = state.getOrNull(groupId); if (st != null) { final var m2 = st.m2(); if (Double.isFinite(m2) == false) { @@ -136,99 +96,4 @@ public class StdDeviation$Type$Aggregator { return builder.build(); } } - - static final class StdDeviation$Type$State implements AggregatorState { - - private WelfordAlgorithm welfordAlgorithm; - - StdDeviation$Type$State() { - this(0, 0, 0); - } - - StdDeviation$Type$State(double mean, double m2, long count) { - this.welfordAlgorithm = new WelfordAlgorithm(mean, m2, count); - } - - public void add($type$ value) { - welfordAlgorithm.add(value); - } - - public void combine(double mean, double m2, long count) { - welfordAlgorithm.add(mean, m2, count); - } - - @Override - public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { - StdDeviation$Type$Aggregator.evaluateIntermediate(this, driverContext, blocks, offset); - } - - @Override - public void close() {} - - public double mean() { - return welfordAlgorithm.mean(); - } - - public double m2() { - return welfordAlgorithm.m2(); - } - - public long count() { - return welfordAlgorithm.count(); - } - - public double evaluateFinal() { - return welfordAlgorithm.evaluate(); - } - } - - static final class GroupingStdDeviation$Type$State implements GroupingAggregatorState { - - private ObjectArray states; - private final BigArrays bigArrays; - - GroupingStdDeviation$Type$State(BigArrays bigArrays) { - this.states = bigArrays.newObjectArray(1); - this.bigArrays = bigArrays; - } - - public void combine(int groupId, double meanValue, double m2Value, long countValue) { - ensureCapacity(groupId); - var state = states.get(groupId); - if (state == null) { - state = new StdDeviation$Type$State(meanValue, m2Value, countValue); - states.set(groupId, state); - } else { - state.combine(meanValue, m2Value, countValue); - } - } - - public void add(int groupId, $type$ value) { - ensureCapacity(groupId); - var state = states.get(groupId); - if (state == null) { - state = new StdDeviation$Type$State(); - states.set(groupId, state); - } - state.add(value); - } - - private void ensureCapacity(int groupId) { - states = bigArrays.grow(states, groupId + 1); - } - - @Override - public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { - StdDeviation$Type$Aggregator.evaluateIntermediate(this, blocks, offset, selected, driverContext); - } - - @Override - public void close() { - Releasables.close(states); - } - - void enableGroupIdTracking(SeenGroupIds seenGroupIds) { - // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` - } - } } From 9f8b3be4ea478b513cb5735f0830098c754bbb5a Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Thu, 14 Nov 2024 11:37:30 -1000 Subject: [PATCH 06/20] change State names to SingleState and GroupingState --- .../StdDeviationDoubleAggregator.java | 30 ++++++++----------- .../StdDeviationFloatAggregator.java | 30 ++++++++----------- .../StdDeviationIntAggregator.java | 30 ++++++++----------- .../StdDeviationLongAggregator.java | 30 ++++++++----------- .../StdDeviationDoubleAggregatorFunction.java | 4 +-- ...ationDoubleGroupingAggregatorFunction.java | 6 ++-- .../StdDeviationFloatAggregatorFunction.java | 4 +-- ...iationFloatGroupingAggregatorFunction.java | 6 ++-- .../StdDeviationIntAggregatorFunction.java | 4 +-- ...eviationIntGroupingAggregatorFunction.java | 6 ++-- .../StdDeviationLongAggregatorFunction.java | 4 +-- ...viationLongGroupingAggregatorFunction.java | 6 ++-- .../aggregation/StdDeviationStates.java | 22 +++++++------- .../X-StdDeviationAggregator.java.st | 30 ++++++++----------- 14 files changed, 91 insertions(+), 121 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java index be1408915195d..cf06c36604f8d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java @@ -29,19 +29,19 @@ @GroupingAggregator public class StdDeviationDoubleAggregator { - public static StdDeviationStates.StdDeviationState initSingle() { - return new StdDeviationStates.StdDeviationState(); + public static StdDeviationStates.SingleState initSingle() { + return new StdDeviationStates.SingleState(); } - public static void combine(StdDeviationStates.StdDeviationState state, double value) { + public static void combine(StdDeviationStates.SingleState state, double value) { state.add(value); } - public static void combineIntermediate(StdDeviationStates.StdDeviationState state, double mean, double m2, long count) { + public static void combineIntermediate(StdDeviationStates.SingleState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.StdDeviationState state, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.SingleState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -50,34 +50,28 @@ public static Block evaluateFinal(StdDeviationStates.StdDeviationState state, Dr return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static StdDeviationStates.GroupingStdDeviationState initGrouping(BigArrays bigArrays) { - return new StdDeviationStates.GroupingStdDeviationState(bigArrays); + public static StdDeviationStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDeviationStates.GroupingState(bigArrays); } - public static void combine(StdDeviationStates.GroupingStdDeviationState current, int groupId, double value) { + public static void combine(StdDeviationStates.GroupingState current, int groupId, double value) { current.add(groupId, value); } public static void combineStates( - StdDeviationStates.GroupingStdDeviationState current, + StdDeviationStates.GroupingState current, int groupId, - StdDeviationStates.GroupingStdDeviationState state, + StdDeviationStates.GroupingState state, int statePosition ) { current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate( - StdDeviationStates.GroupingStdDeviationState state, - int groupId, - double mean, - double m2, - long count - ) { + public static void combineIntermediate(StdDeviationStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.GroupingStdDeviationState state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java index 84d58a61c5da7..321c76733120f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java @@ -29,19 +29,19 @@ @GroupingAggregator public class StdDeviationFloatAggregator { - public static StdDeviationStates.StdDeviationState initSingle() { - return new StdDeviationStates.StdDeviationState(); + public static StdDeviationStates.SingleState initSingle() { + return new StdDeviationStates.SingleState(); } - public static void combine(StdDeviationStates.StdDeviationState state, float value) { + public static void combine(StdDeviationStates.SingleState state, float value) { state.add(value); } - public static void combineIntermediate(StdDeviationStates.StdDeviationState state, double mean, double m2, long count) { + public static void combineIntermediate(StdDeviationStates.SingleState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.StdDeviationState state, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.SingleState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -50,34 +50,28 @@ public static Block evaluateFinal(StdDeviationStates.StdDeviationState state, Dr return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static StdDeviationStates.GroupingStdDeviationState initGrouping(BigArrays bigArrays) { - return new StdDeviationStates.GroupingStdDeviationState(bigArrays); + public static StdDeviationStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDeviationStates.GroupingState(bigArrays); } - public static void combine(StdDeviationStates.GroupingStdDeviationState current, int groupId, float value) { + public static void combine(StdDeviationStates.GroupingState current, int groupId, float value) { current.add(groupId, value); } public static void combineStates( - StdDeviationStates.GroupingStdDeviationState current, + StdDeviationStates.GroupingState current, int groupId, - StdDeviationStates.GroupingStdDeviationState state, + StdDeviationStates.GroupingState state, int statePosition ) { current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate( - StdDeviationStates.GroupingStdDeviationState state, - int groupId, - double mean, - double m2, - long count - ) { + public static void combineIntermediate(StdDeviationStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.GroupingStdDeviationState state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java index dc1dc0c716684..44d253c35848a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java @@ -29,19 +29,19 @@ @GroupingAggregator public class StdDeviationIntAggregator { - public static StdDeviationStates.StdDeviationState initSingle() { - return new StdDeviationStates.StdDeviationState(); + public static StdDeviationStates.SingleState initSingle() { + return new StdDeviationStates.SingleState(); } - public static void combine(StdDeviationStates.StdDeviationState state, int value) { + public static void combine(StdDeviationStates.SingleState state, int value) { state.add(value); } - public static void combineIntermediate(StdDeviationStates.StdDeviationState state, double mean, double m2, long count) { + public static void combineIntermediate(StdDeviationStates.SingleState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.StdDeviationState state, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.SingleState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -50,34 +50,28 @@ public static Block evaluateFinal(StdDeviationStates.StdDeviationState state, Dr return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static StdDeviationStates.GroupingStdDeviationState initGrouping(BigArrays bigArrays) { - return new StdDeviationStates.GroupingStdDeviationState(bigArrays); + public static StdDeviationStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDeviationStates.GroupingState(bigArrays); } - public static void combine(StdDeviationStates.GroupingStdDeviationState current, int groupId, int value) { + public static void combine(StdDeviationStates.GroupingState current, int groupId, int value) { current.add(groupId, value); } public static void combineStates( - StdDeviationStates.GroupingStdDeviationState current, + StdDeviationStates.GroupingState current, int groupId, - StdDeviationStates.GroupingStdDeviationState state, + StdDeviationStates.GroupingState state, int statePosition ) { current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate( - StdDeviationStates.GroupingStdDeviationState state, - int groupId, - double mean, - double m2, - long count - ) { + public static void combineIntermediate(StdDeviationStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.GroupingStdDeviationState state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java index aa9d4dfd43983..7b7c6b3dc19d7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java @@ -29,19 +29,19 @@ @GroupingAggregator public class StdDeviationLongAggregator { - public static StdDeviationStates.StdDeviationState initSingle() { - return new StdDeviationStates.StdDeviationState(); + public static StdDeviationStates.SingleState initSingle() { + return new StdDeviationStates.SingleState(); } - public static void combine(StdDeviationStates.StdDeviationState state, long value) { + public static void combine(StdDeviationStates.SingleState state, long value) { state.add(value); } - public static void combineIntermediate(StdDeviationStates.StdDeviationState state, double mean, double m2, long count) { + public static void combineIntermediate(StdDeviationStates.SingleState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.StdDeviationState state, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.SingleState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -50,34 +50,28 @@ public static Block evaluateFinal(StdDeviationStates.StdDeviationState state, Dr return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static StdDeviationStates.GroupingStdDeviationState initGrouping(BigArrays bigArrays) { - return new StdDeviationStates.GroupingStdDeviationState(bigArrays); + public static StdDeviationStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDeviationStates.GroupingState(bigArrays); } - public static void combine(StdDeviationStates.GroupingStdDeviationState current, int groupId, long value) { + public static void combine(StdDeviationStates.GroupingState current, int groupId, long value) { current.add(groupId, value); } public static void combineStates( - StdDeviationStates.GroupingStdDeviationState current, + StdDeviationStates.GroupingState current, int groupId, - StdDeviationStates.GroupingStdDeviationState state, + StdDeviationStates.GroupingState state, int statePosition ) { current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate( - StdDeviationStates.GroupingStdDeviationState state, - int groupId, - double mean, - double m2, - long count - ) { + public static void combineIntermediate(StdDeviationStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.GroupingStdDeviationState state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java index 324dc47d2be8e..437fddf081d3b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java @@ -31,12 +31,12 @@ public final class StdDeviationDoubleAggregatorFunction implements AggregatorFun private final DriverContext driverContext; - private final StdDeviationStates.StdDeviationState state; + private final StdDeviationStates.SingleState state; private final List channels; public StdDeviationDoubleAggregatorFunction(DriverContext driverContext, List channels, - StdDeviationStates.StdDeviationState state) { + StdDeviationStates.SingleState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java index 6c77b6a3b19bb..8d30b95c4c246 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java @@ -30,14 +30,14 @@ public final class StdDeviationDoubleGroupingAggregatorFunction implements Group new IntermediateStateDesc("m2", ElementType.DOUBLE), new IntermediateStateDesc("count", ElementType.LONG) ); - private final StdDeviationStates.GroupingStdDeviationState state; + private final StdDeviationStates.GroupingState state; private final List channels; private final DriverContext driverContext; public StdDeviationDoubleGroupingAggregatorFunction(List channels, - StdDeviationStates.GroupingStdDeviationState state, DriverContext driverContext) { + StdDeviationStates.GroupingState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; @@ -191,7 +191,7 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationStates.GroupingStdDeviationState inState = ((StdDeviationDoubleGroupingAggregatorFunction) input).state; + StdDeviationStates.GroupingState inState = ((StdDeviationDoubleGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); StdDeviationDoubleAggregator.combineStates(state, groupId, inState, position); } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java index 17d19f67b91ca..35bbbe3fa83d4 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java @@ -33,12 +33,12 @@ public final class StdDeviationFloatAggregatorFunction implements AggregatorFunc private final DriverContext driverContext; - private final StdDeviationStates.StdDeviationState state; + private final StdDeviationStates.SingleState state; private final List channels; public StdDeviationFloatAggregatorFunction(DriverContext driverContext, List channels, - StdDeviationStates.StdDeviationState state) { + StdDeviationStates.SingleState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java index 521564fa3942f..f22d6e866f7e7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java @@ -32,14 +32,14 @@ public final class StdDeviationFloatGroupingAggregatorFunction implements Groupi new IntermediateStateDesc("m2", ElementType.DOUBLE), new IntermediateStateDesc("count", ElementType.LONG) ); - private final StdDeviationStates.GroupingStdDeviationState state; + private final StdDeviationStates.GroupingState state; private final List channels; private final DriverContext driverContext; public StdDeviationFloatGroupingAggregatorFunction(List channels, - StdDeviationStates.GroupingStdDeviationState state, DriverContext driverContext) { + StdDeviationStates.GroupingState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; @@ -193,7 +193,7 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationStates.GroupingStdDeviationState inState = ((StdDeviationFloatGroupingAggregatorFunction) input).state; + StdDeviationStates.GroupingState inState = ((StdDeviationFloatGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); StdDeviationFloatAggregator.combineStates(state, groupId, inState, position); } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java index f941edd3f011e..5bfd4fb495f29 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java @@ -33,12 +33,12 @@ public final class StdDeviationIntAggregatorFunction implements AggregatorFuncti private final DriverContext driverContext; - private final StdDeviationStates.StdDeviationState state; + private final StdDeviationStates.SingleState state; private final List channels; public StdDeviationIntAggregatorFunction(DriverContext driverContext, List channels, - StdDeviationStates.StdDeviationState state) { + StdDeviationStates.SingleState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java index 53938feb15ff1..661bc068be26f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java @@ -30,14 +30,14 @@ public final class StdDeviationIntGroupingAggregatorFunction implements Grouping new IntermediateStateDesc("m2", ElementType.DOUBLE), new IntermediateStateDesc("count", ElementType.LONG) ); - private final StdDeviationStates.GroupingStdDeviationState state; + private final StdDeviationStates.GroupingState state; private final List channels; private final DriverContext driverContext; public StdDeviationIntGroupingAggregatorFunction(List channels, - StdDeviationStates.GroupingStdDeviationState state, DriverContext driverContext) { + StdDeviationStates.GroupingState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; @@ -191,7 +191,7 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationStates.GroupingStdDeviationState inState = ((StdDeviationIntGroupingAggregatorFunction) input).state; + StdDeviationStates.GroupingState inState = ((StdDeviationIntGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); StdDeviationIntAggregator.combineStates(state, groupId, inState, position); } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java index f7b59522db89d..d7644c8286b64 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java @@ -31,12 +31,12 @@ public final class StdDeviationLongAggregatorFunction implements AggregatorFunct private final DriverContext driverContext; - private final StdDeviationStates.StdDeviationState state; + private final StdDeviationStates.SingleState state; private final List channels; public StdDeviationLongAggregatorFunction(DriverContext driverContext, List channels, - StdDeviationStates.StdDeviationState state) { + StdDeviationStates.SingleState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java index 19700a071028b..5091e84840c76 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java @@ -30,14 +30,14 @@ public final class StdDeviationLongGroupingAggregatorFunction implements Groupin new IntermediateStateDesc("m2", ElementType.DOUBLE), new IntermediateStateDesc("count", ElementType.LONG) ); - private final StdDeviationStates.GroupingStdDeviationState state; + private final StdDeviationStates.GroupingState state; private final List channels; private final DriverContext driverContext; public StdDeviationLongGroupingAggregatorFunction(List channels, - StdDeviationStates.GroupingStdDeviationState state, DriverContext driverContext) { + StdDeviationStates.GroupingState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; @@ -191,7 +191,7 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationStates.GroupingStdDeviationState inState = ((StdDeviationLongGroupingAggregatorFunction) input).state; + StdDeviationStates.GroupingState inState = ((StdDeviationLongGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); StdDeviationLongAggregator.combineStates(state, groupId, inState, position); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDeviationStates.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDeviationStates.java index 3a9ae6f1a7a58..03cb6c49ce04f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDeviationStates.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDeviationStates.java @@ -19,15 +19,15 @@ public final class StdDeviationStates { private StdDeviationStates() {} - static final class StdDeviationState implements AggregatorState { + static final class SingleState implements AggregatorState { private WelfordAlgorithm welfordAlgorithm; - StdDeviationState() { + SingleState() { this(0, 0, 0); } - StdDeviationState(double mean, double m2, long count) { + SingleState(double mean, double m2, long count) { this.welfordAlgorithm = new WelfordAlgorithm(mean, m2, count); } @@ -76,17 +76,17 @@ public double evaluateFinal() { } } - static final class GroupingStdDeviationState implements GroupingAggregatorState { + static final class GroupingState implements GroupingAggregatorState { - private ObjectArray states; + private ObjectArray states; private final BigArrays bigArrays; - GroupingStdDeviationState(BigArrays bigArrays) { + GroupingState(BigArrays bigArrays) { this.states = bigArrays.newObjectArray(1); this.bigArrays = bigArrays; } - StdDeviationState getOrNull(int position) { + SingleState getOrNull(int position) { if (position < states.size()) { return states.get(position); } else { @@ -94,7 +94,7 @@ StdDeviationState getOrNull(int position) { } } - public void combine(int groupId, StdDeviationState state) { + public void combine(int groupId, SingleState state) { if (state == null) { return; } @@ -105,18 +105,18 @@ public void combine(int groupId, double meanValue, double m2Value, long countVal ensureCapacity(groupId); var state = states.get(groupId); if (state == null) { - state = new StdDeviationState(meanValue, m2Value, countValue); + state = new SingleState(meanValue, m2Value, countValue); states.set(groupId, state); } else { state.combine(meanValue, m2Value, countValue); } } - public StdDeviationState getOrSet(int groupId) { + public SingleState getOrSet(int groupId) { ensureCapacity(groupId); var state = states.get(groupId); if (state == null) { - state = new StdDeviationState(); + state = new SingleState(); states.set(groupId, state); } return state; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st index cd14799d0ef9f..3ec96eab43321 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st @@ -29,19 +29,19 @@ import org.elasticsearch.compute.operator.DriverContext; @GroupingAggregator public class StdDeviation$Type$Aggregator { - public static StdDeviationStates.StdDeviationState initSingle() { - return new StdDeviationStates.StdDeviationState(); + public static StdDeviationStates.SingleState initSingle() { + return new StdDeviationStates.SingleState(); } - public static void combine(StdDeviationStates.StdDeviationState state, $type$ value) { + public static void combine(StdDeviationStates.SingleState state, $type$ value) { state.add(value); } - public static void combineIntermediate(StdDeviationStates.StdDeviationState state, double mean, double m2, long count) { + public static void combineIntermediate(StdDeviationStates.SingleState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.StdDeviationState state, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.SingleState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -50,34 +50,28 @@ public class StdDeviation$Type$Aggregator { return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static StdDeviationStates.GroupingStdDeviationState initGrouping(BigArrays bigArrays) { - return new StdDeviationStates.GroupingStdDeviationState(bigArrays); + public static StdDeviationStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDeviationStates.GroupingState(bigArrays); } - public static void combine(StdDeviationStates.GroupingStdDeviationState current, int groupId, $type$ value) { + public static void combine(StdDeviationStates.GroupingState current, int groupId, $type$ value) { current.add(groupId, value); } public static void combineStates( - StdDeviationStates.GroupingStdDeviationState current, + StdDeviationStates.GroupingState current, int groupId, - StdDeviationStates.GroupingStdDeviationState state, + StdDeviationStates.GroupingState state, int statePosition ) { current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate( - StdDeviationStates.GroupingStdDeviationState state, - int groupId, - double mean, - double m2, - long count - ) { + public static void combineIntermediate(StdDeviationStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.GroupingStdDeviationState state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDeviationStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); From 5fb5b88c87ce102943c5e018024b86e2f45f467f Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Thu, 14 Nov 2024 12:20:05 -1000 Subject: [PATCH 07/20] change StdDeviation to StdDev --- .../functions/aggregation-functions.asciidoc | 2 +- ...td_deviation.asciidoc => std_dev.asciidoc} | 0 ...td_deviation.asciidoc => std_dev.asciidoc} | 6 +- .../{std_deviation.json => std_dev.json} | 6 +- .../docs/{std_deviation.md => std_dev.md} | 4 +- .../esql/functions/layout/std_dev.asciidoc | 15 +++ .../functions/layout/std_deviation.asciidoc | 15 --- ...td_deviation.asciidoc => std_dev.asciidoc} | 0 .../esql/functions/signature/std_dev.svg | 1 + .../functions/signature/std_deviation.svg | 1 - ...td_deviation.asciidoc => std_dev.asciidoc} | 0 x-pack/plugin/esql/compute/build.gradle | 10 +- ...gator.java => StdDevDoubleAggregator.java} | 2 +- ...egator.java => StdDevFloatAggregator.java} | 2 +- ...gregator.java => StdDevIntAggregator.java} | 2 +- ...regator.java => StdDevLongAggregator.java} | 2 +- ...va => StdDevDoubleAggregatorFunction.java} | 22 ++--- ...dDevDoubleAggregatorFunctionSupplier.java} | 16 ++-- ...dDevDoubleGroupingAggregatorFunction.java} | 26 +++--- ...ava => StdDevFloatAggregatorFunction.java} | 22 ++--- ...tdDevFloatAggregatorFunctionSupplier.java} | 17 ++-- ...tdDevFloatGroupingAggregatorFunction.java} | 26 +++--- ....java => StdDevIntAggregatorFunction.java} | 22 ++--- ... StdDevIntAggregatorFunctionSupplier.java} | 17 ++-- ... StdDevIntGroupingAggregatorFunction.java} | 26 +++--- ...java => StdDevLongAggregatorFunction.java} | 22 ++--- ...StdDevLongAggregatorFunctionSupplier.java} | 17 ++-- ...StdDevLongGroupingAggregatorFunction.java} | 26 +++--- ...tor.java.st => X-StdDevAggregator.java.st} | 2 +- .../src/main/resources/stats.csv-spec | 92 +++++++++---------- .../xpack/esql/action/EsqlCapabilities.java | 4 +- .../function/EsqlFunctionRegistry.java | 4 +- .../function/aggregate/AggregateFunction.java | 2 +- .../{StdDeviation.java => StdDev.java} | 40 ++++---- .../xpack/esql/planner/AggregateMapper.java | 9 +- ...tdDeviationTests.java => StdDevTests.java} | 12 +-- 36 files changed, 243 insertions(+), 249 deletions(-) rename docs/reference/esql/functions/description/{std_deviation.asciidoc => std_dev.asciidoc} (100%) rename docs/reference/esql/functions/examples/{std_deviation.asciidoc => std_dev.asciidoc} (69%) rename docs/reference/esql/functions/kibana/definition/{std_deviation.json => std_dev.json} (85%) rename docs/reference/esql/functions/kibana/docs/{std_deviation.md => std_dev.md} (79%) create mode 100644 docs/reference/esql/functions/layout/std_dev.asciidoc delete mode 100644 docs/reference/esql/functions/layout/std_deviation.asciidoc rename docs/reference/esql/functions/parameters/{std_deviation.asciidoc => std_dev.asciidoc} (100%) create mode 100644 docs/reference/esql/functions/signature/std_dev.svg delete mode 100644 docs/reference/esql/functions/signature/std_deviation.svg rename docs/reference/esql/functions/types/{std_deviation.asciidoc => std_dev.asciidoc} (100%) rename x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/{StdDeviationDoubleAggregator.java => StdDevDoubleAggregator.java} (98%) rename x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/{StdDeviationFloatAggregator.java => StdDevFloatAggregator.java} (98%) rename x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/{StdDeviationIntAggregator.java => StdDevIntAggregator.java} (98%) rename x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/{StdDeviationLongAggregator.java => StdDevLongAggregator.java} (98%) rename x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/{StdDeviationDoubleAggregatorFunction.java => StdDevDoubleAggregatorFunction.java} (83%) rename x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/{StdDeviationIntAggregatorFunctionSupplier.java => StdDevDoubleAggregatorFunctionSupplier.java} (55%) rename x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/{StdDeviationDoubleGroupingAggregatorFunction.java => StdDevDoubleGroupingAggregatorFunction.java} (84%) rename x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/{StdDeviationFloatAggregatorFunction.java => StdDevFloatAggregatorFunction.java} (83%) rename x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/{StdDeviationLongAggregatorFunctionSupplier.java => StdDevFloatAggregatorFunctionSupplier.java} (54%) rename x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/{StdDeviationFloatGroupingAggregatorFunction.java => StdDevFloatGroupingAggregatorFunction.java} (84%) rename x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/{StdDeviationIntAggregatorFunction.java => StdDevIntAggregatorFunction.java} (84%) rename x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/{StdDeviationFloatAggregatorFunctionSupplier.java => StdDevIntAggregatorFunctionSupplier.java} (54%) rename x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/{StdDeviationIntGroupingAggregatorFunction.java => StdDevIntGroupingAggregatorFunction.java} (84%) rename x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/{StdDeviationLongAggregatorFunction.java => StdDevLongAggregatorFunction.java} (83%) rename x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/{StdDeviationDoubleAggregatorFunctionSupplier.java => StdDevLongAggregatorFunctionSupplier.java} (53%) rename x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/{StdDeviationLongGroupingAggregatorFunction.java => StdDevLongGroupingAggregatorFunction.java} (84%) rename x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/{X-StdDeviationAggregator.java.st => X-StdDevAggregator.java.st} (98%) rename x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/{StdDeviation.java => StdDev.java} (64%) rename x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/{StdDeviationTests.java => StdDevTests.java} (87%) diff --git a/docs/reference/esql/functions/aggregation-functions.asciidoc b/docs/reference/esql/functions/aggregation-functions.asciidoc index 7777859e898ce..2c9d831dd05c2 100644 --- a/docs/reference/esql/functions/aggregation-functions.asciidoc +++ b/docs/reference/esql/functions/aggregation-functions.asciidoc @@ -17,7 +17,7 @@ The <> command supports these aggregate functions: * <> * <> * experimental:[] <> -* <> +* <> * <> * <> * <> diff --git a/docs/reference/esql/functions/description/std_deviation.asciidoc b/docs/reference/esql/functions/description/std_dev.asciidoc similarity index 100% rename from docs/reference/esql/functions/description/std_deviation.asciidoc rename to docs/reference/esql/functions/description/std_dev.asciidoc diff --git a/docs/reference/esql/functions/examples/std_deviation.asciidoc b/docs/reference/esql/functions/examples/std_dev.asciidoc similarity index 69% rename from docs/reference/esql/functions/examples/std_deviation.asciidoc rename to docs/reference/esql/functions/examples/std_dev.asciidoc index 741f5e886b945..2e6dc996aae9a 100644 --- a/docs/reference/esql/functions/examples/std_deviation.asciidoc +++ b/docs/reference/esql/functions/examples/std_dev.asciidoc @@ -10,13 +10,13 @@ include::{esql-specs}/stats.csv-spec[tag=stdev] |=== include::{esql-specs}/stats.csv-spec[tag=stdev-result] |=== -The expression can use inline functions. For example, to calculate the standard deviation of each employee's maximum salary changes, first use `MV_MAX` on each row, and then use `StdDeviation` on the result +The expression can use inline functions. For example, to calculate the standard deviation of each employee's maximum salary changes, first use `MV_MAX` on each row, and then use `STD_DEV` on the result [source.merge.styled,esql] ---- -include::{esql-specs}/stats.csv-spec[tag=docsStatsStdDeviationNestedExpression] +include::{esql-specs}/stats.csv-spec[tag=docsStatsStdDevNestedExpression] ---- [%header.monospaced.styled,format=dsv,separator=|] |=== -include::{esql-specs}/stats.csv-spec[tag=docsStatsStdDeviationNestedExpression-result] +include::{esql-specs}/stats.csv-spec[tag=docsStatsStdDevNestedExpression-result] |=== diff --git a/docs/reference/esql/functions/kibana/definition/std_deviation.json b/docs/reference/esql/functions/kibana/definition/std_dev.json similarity index 85% rename from docs/reference/esql/functions/kibana/definition/std_deviation.json rename to docs/reference/esql/functions/kibana/definition/std_dev.json index 0beb5c8b75ec9..f31d3345421d9 100644 --- a/docs/reference/esql/functions/kibana/definition/std_deviation.json +++ b/docs/reference/esql/functions/kibana/definition/std_dev.json @@ -1,7 +1,7 @@ { "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", "type" : "agg", - "name" : "std_deviation", + "name" : "std_dev", "description" : "The standard deviation of a numeric field.", "signatures" : [ { @@ -42,8 +42,8 @@ } ], "examples" : [ - "FROM employees\n| STATS STD_DEVIATION(height)", - "FROM employees\n| STATS stdev_salary_change = STD_DEVIATION(MV_MAX(salary_change))" + "FROM employees\n| STATS STD_DEV(height)", + "FROM employees\n| STATS stddev_salary_change = STD_DEV(MV_MAX(salary_change))" ], "preview" : false, "snapshot_only" : false diff --git a/docs/reference/esql/functions/kibana/docs/std_deviation.md b/docs/reference/esql/functions/kibana/docs/std_dev.md similarity index 79% rename from docs/reference/esql/functions/kibana/docs/std_deviation.md rename to docs/reference/esql/functions/kibana/docs/std_dev.md index d3dad54b3c5b4..a6afca7b8f6b3 100644 --- a/docs/reference/esql/functions/kibana/docs/std_deviation.md +++ b/docs/reference/esql/functions/kibana/docs/std_dev.md @@ -2,10 +2,10 @@ This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. --> -### STD_DEVIATION +### STD_DEV The standard deviation of a numeric field. ``` FROM employees -| STATS STD_DEVIATION(height) +| STATS STD_DEV(height) ``` diff --git a/docs/reference/esql/functions/layout/std_dev.asciidoc b/docs/reference/esql/functions/layout/std_dev.asciidoc new file mode 100644 index 0000000000000..a7a34b1331d17 --- /dev/null +++ b/docs/reference/esql/functions/layout/std_dev.asciidoc @@ -0,0 +1,15 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +[discrete] +[[esql-std_dev]] +=== `STD_DEV` + +*Syntax* + +[.text-center] +image::esql/functions/signature/std_dev.svg[Embedded,opts=inline] + +include::../parameters/std_dev.asciidoc[] +include::../description/std_dev.asciidoc[] +include::../types/std_dev.asciidoc[] +include::../examples/std_dev.asciidoc[] diff --git a/docs/reference/esql/functions/layout/std_deviation.asciidoc b/docs/reference/esql/functions/layout/std_deviation.asciidoc deleted file mode 100644 index 93fbcbb87aba6..0000000000000 --- a/docs/reference/esql/functions/layout/std_deviation.asciidoc +++ /dev/null @@ -1,15 +0,0 @@ -// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. - -[discrete] -[[esql-std_deviation]] -=== `STD_DEVIATION` - -*Syntax* - -[.text-center] -image::esql/functions/signature/std_deviation.svg[Embedded,opts=inline] - -include::../parameters/std_deviation.asciidoc[] -include::../description/std_deviation.asciidoc[] -include::../types/std_deviation.asciidoc[] -include::../examples/std_deviation.asciidoc[] diff --git a/docs/reference/esql/functions/parameters/std_deviation.asciidoc b/docs/reference/esql/functions/parameters/std_dev.asciidoc similarity index 100% rename from docs/reference/esql/functions/parameters/std_deviation.asciidoc rename to docs/reference/esql/functions/parameters/std_dev.asciidoc diff --git a/docs/reference/esql/functions/signature/std_dev.svg b/docs/reference/esql/functions/signature/std_dev.svg new file mode 100644 index 0000000000000..606d285154f59 --- /dev/null +++ b/docs/reference/esql/functions/signature/std_dev.svg @@ -0,0 +1 @@ +STD_DEV(number) \ No newline at end of file diff --git a/docs/reference/esql/functions/signature/std_deviation.svg b/docs/reference/esql/functions/signature/std_deviation.svg deleted file mode 100644 index af83594d04871..0000000000000 --- a/docs/reference/esql/functions/signature/std_deviation.svg +++ /dev/null @@ -1 +0,0 @@ -STD_DEVIATION(number) \ No newline at end of file diff --git a/docs/reference/esql/functions/types/std_deviation.asciidoc b/docs/reference/esql/functions/types/std_dev.asciidoc similarity index 100% rename from docs/reference/esql/functions/types/std_deviation.asciidoc rename to docs/reference/esql/functions/types/std_dev.asciidoc diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index f2e19fecf051e..da7c7d876ce6e 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -608,26 +608,26 @@ tasks.named('stringTemplates').configure { it.outputFile = "org/elasticsearch/compute/aggregation/RateDoubleAggregator.java" } - File stdDeviationAggregatorInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st") + File stdDeviationAggregatorInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st") template { it.properties = intProperties it.inputFile = stdDeviationAggregatorInputFile - it.outputFile = "org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java" + it.outputFile = "org/elasticsearch/compute/aggregation/StdDevIntAggregator.java" } template { it.properties = longProperties it.inputFile = stdDeviationAggregatorInputFile - it.outputFile = "org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java" + it.outputFile = "org/elasticsearch/compute/aggregation/StdDevLongAggregator.java" } template { it.properties = floatProperties it.inputFile = stdDeviationAggregatorInputFile - it.outputFile = "org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java" + it.outputFile = "org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java" } template { it.properties = doubleProperties it.inputFile = stdDeviationAggregatorInputFile - it.outputFile = "org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java" + it.outputFile = "org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java" } File topAggregatorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st") diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java similarity index 98% rename from x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java rename to x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java index cf06c36604f8d..babd7e25f5027 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java @@ -27,7 +27,7 @@ @IntermediateState(name = "count", type = "LONG") } ) @GroupingAggregator -public class StdDeviationDoubleAggregator { +public class StdDevDoubleAggregator { public static StdDeviationStates.SingleState initSingle() { return new StdDeviationStates.SingleState(); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java similarity index 98% rename from x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java rename to x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java index 321c76733120f..d118b913cc880 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java @@ -27,7 +27,7 @@ @IntermediateState(name = "count", type = "LONG") } ) @GroupingAggregator -public class StdDeviationFloatAggregator { +public class StdDevFloatAggregator { public static StdDeviationStates.SingleState initSingle() { return new StdDeviationStates.SingleState(); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java similarity index 98% rename from x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java rename to x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java index 44d253c35848a..10466b25cb75c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java @@ -27,7 +27,7 @@ @IntermediateState(name = "count", type = "LONG") } ) @GroupingAggregator -public class StdDeviationIntAggregator { +public class StdDevIntAggregator { public static StdDeviationStates.SingleState initSingle() { return new StdDeviationStates.SingleState(); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java similarity index 98% rename from x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java rename to x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java index 7b7c6b3dc19d7..f1f40e211d7b9 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDeviationLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java @@ -27,7 +27,7 @@ @IntermediateState(name = "count", type = "LONG") } ) @GroupingAggregator -public class StdDeviationLongAggregator { +public class StdDevLongAggregator { public static StdDeviationStates.SingleState initSingle() { return new StdDeviationStates.SingleState(); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunction.java similarity index 83% rename from x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java rename to x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunction.java index 437fddf081d3b..c6ba833b3499b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunction.java @@ -20,10 +20,10 @@ import org.elasticsearch.compute.operator.DriverContext; /** - * {@link AggregatorFunction} implementation for {@link StdDeviationDoubleAggregator}. + * {@link AggregatorFunction} implementation for {@link StdDevDoubleAggregator}. * This class is generated. Do not edit it. */ -public final class StdDeviationDoubleAggregatorFunction implements AggregatorFunction { +public final class StdDevDoubleAggregatorFunction implements AggregatorFunction { private static final List INTERMEDIATE_STATE_DESC = List.of( new IntermediateStateDesc("mean", ElementType.DOUBLE), new IntermediateStateDesc("m2", ElementType.DOUBLE), @@ -35,16 +35,16 @@ public final class StdDeviationDoubleAggregatorFunction implements AggregatorFun private final List channels; - public StdDeviationDoubleAggregatorFunction(DriverContext driverContext, List channels, + public StdDevDoubleAggregatorFunction(DriverContext driverContext, List channels, StdDeviationStates.SingleState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; } - public static StdDeviationDoubleAggregatorFunction create(DriverContext driverContext, + public static StdDevDoubleAggregatorFunction create(DriverContext driverContext, List channels) { - return new StdDeviationDoubleAggregatorFunction(driverContext, channels, StdDeviationDoubleAggregator.initSingle()); + return new StdDevDoubleAggregatorFunction(driverContext, channels, StdDevDoubleAggregator.initSingle()); } public static List intermediateStateDesc() { @@ -85,7 +85,7 @@ public void addRawInput(Page page, BooleanVector mask) { private void addRawVector(DoubleVector vector) { for (int i = 0; i < vector.getPositionCount(); i++) { - StdDeviationDoubleAggregator.combine(state, vector.getDouble(i)); + StdDevDoubleAggregator.combine(state, vector.getDouble(i)); } } @@ -94,7 +94,7 @@ private void addRawVector(DoubleVector vector, BooleanVector mask) { if (mask.getBoolean(i) == false) { continue; } - StdDeviationDoubleAggregator.combine(state, vector.getDouble(i)); + StdDevDoubleAggregator.combine(state, vector.getDouble(i)); } } @@ -106,7 +106,7 @@ private void addRawBlock(DoubleBlock block) { int start = block.getFirstValueIndex(p); int end = start + block.getValueCount(p); for (int i = start; i < end; i++) { - StdDeviationDoubleAggregator.combine(state, block.getDouble(i)); + StdDevDoubleAggregator.combine(state, block.getDouble(i)); } } } @@ -122,7 +122,7 @@ private void addRawBlock(DoubleBlock block, BooleanVector mask) { int start = block.getFirstValueIndex(p); int end = start + block.getValueCount(p); for (int i = start; i < end; i++) { - StdDeviationDoubleAggregator.combine(state, block.getDouble(i)); + StdDevDoubleAggregator.combine(state, block.getDouble(i)); } } } @@ -149,7 +149,7 @@ public void addIntermediateInput(Page page) { } LongVector count = ((LongBlock) countUncast).asVector(); assert count.getPositionCount() == 1; - StdDeviationDoubleAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); + StdDevDoubleAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); } @Override @@ -159,7 +159,7 @@ public void evaluateIntermediate(Block[] blocks, int offset, DriverContext drive @Override public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { - blocks[offset] = StdDeviationDoubleAggregator.evaluateFinal(state, driverContext); + blocks[offset] = StdDevDoubleAggregator.evaluateFinal(state, driverContext); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunctionSupplier.java similarity index 55% rename from x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunctionSupplier.java rename to x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunctionSupplier.java index 0d1c8d5c2415b..313eed4ae97ae 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunctionSupplier.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunctionSupplier.java @@ -11,28 +11,28 @@ import org.elasticsearch.compute.operator.DriverContext; /** - * {@link AggregatorFunctionSupplier} implementation for {@link StdDeviationIntAggregator}. + * {@link AggregatorFunctionSupplier} implementation for {@link StdDevDoubleAggregator}. * This class is generated. Do not edit it. */ -public final class StdDeviationIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier { +public final class StdDevDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier { private final List channels; - public StdDeviationIntAggregatorFunctionSupplier(List channels) { + public StdDevDoubleAggregatorFunctionSupplier(List channels) { this.channels = channels; } @Override - public StdDeviationIntAggregatorFunction aggregator(DriverContext driverContext) { - return StdDeviationIntAggregatorFunction.create(driverContext, channels); + public StdDevDoubleAggregatorFunction aggregator(DriverContext driverContext) { + return StdDevDoubleAggregatorFunction.create(driverContext, channels); } @Override - public StdDeviationIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { - return StdDeviationIntGroupingAggregatorFunction.create(channels, driverContext); + public StdDevDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return StdDevDoubleGroupingAggregatorFunction.create(channels, driverContext); } @Override public String describe() { - return "std_deviation of ints"; + return "std_dev of doubles"; } } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java similarity index 84% rename from x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java rename to x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java index 8d30b95c4c246..663b6661aaec8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java @@ -21,10 +21,10 @@ import org.elasticsearch.compute.operator.DriverContext; /** - * {@link GroupingAggregatorFunction} implementation for {@link StdDeviationDoubleAggregator}. + * {@link GroupingAggregatorFunction} implementation for {@link StdDevDoubleAggregator}. * This class is generated. Do not edit it. */ -public final class StdDeviationDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { +public final class StdDevDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { private static final List INTERMEDIATE_STATE_DESC = List.of( new IntermediateStateDesc("mean", ElementType.DOUBLE), new IntermediateStateDesc("m2", ElementType.DOUBLE), @@ -36,16 +36,16 @@ public final class StdDeviationDoubleGroupingAggregatorFunction implements Group private final DriverContext driverContext; - public StdDeviationDoubleGroupingAggregatorFunction(List channels, + public StdDevDoubleGroupingAggregatorFunction(List channels, StdDeviationStates.GroupingState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; } - public static StdDeviationDoubleGroupingAggregatorFunction create(List channels, + public static StdDevDoubleGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new StdDeviationDoubleGroupingAggregatorFunction(channels, StdDeviationDoubleAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new StdDevDoubleGroupingAggregatorFunction(channels, StdDevDoubleAggregator.initGrouping(driverContext.bigArrays()), driverContext); } public static List intermediateStateDesc() { @@ -108,7 +108,7 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleBlock value int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - StdDeviationDoubleAggregator.combine(state, groupId, values.getDouble(v)); + StdDevDoubleAggregator.combine(state, groupId, values.getDouble(v)); } } } @@ -116,7 +116,7 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleBlock value private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); - StdDeviationDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); + StdDevDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); } } @@ -135,7 +135,7 @@ private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - StdDeviationDoubleAggregator.combine(state, groupId, values.getDouble(v)); + StdDevDoubleAggregator.combine(state, groupId, values.getDouble(v)); } } } @@ -150,7 +150,7 @@ private void addRawInput(int positionOffset, IntBlock groups, DoubleVector value int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - StdDeviationDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); + StdDevDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); } } } @@ -182,7 +182,7 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); - StdDeviationDoubleAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + StdDevDoubleAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); } } @@ -191,9 +191,9 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationStates.GroupingState inState = ((StdDeviationDoubleGroupingAggregatorFunction) input).state; + StdDeviationStates.GroupingState inState = ((StdDevDoubleGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); - StdDeviationDoubleAggregator.combineStates(state, groupId, inState, position); + StdDevDoubleAggregator.combineStates(state, groupId, inState, position); } @Override @@ -204,7 +204,7 @@ public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) @Override public void evaluateFinal(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { - blocks[offset] = StdDeviationDoubleAggregator.evaluateFinal(state, selected, driverContext); + blocks[offset] = StdDevDoubleAggregator.evaluateFinal(state, selected, driverContext); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunction.java similarity index 83% rename from x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java rename to x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunction.java index 35bbbe3fa83d4..cd2f5b88932cb 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunction.java @@ -22,10 +22,10 @@ import org.elasticsearch.compute.operator.DriverContext; /** - * {@link AggregatorFunction} implementation for {@link StdDeviationFloatAggregator}. + * {@link AggregatorFunction} implementation for {@link StdDevFloatAggregator}. * This class is generated. Do not edit it. */ -public final class StdDeviationFloatAggregatorFunction implements AggregatorFunction { +public final class StdDevFloatAggregatorFunction implements AggregatorFunction { private static final List INTERMEDIATE_STATE_DESC = List.of( new IntermediateStateDesc("mean", ElementType.DOUBLE), new IntermediateStateDesc("m2", ElementType.DOUBLE), @@ -37,16 +37,16 @@ public final class StdDeviationFloatAggregatorFunction implements AggregatorFunc private final List channels; - public StdDeviationFloatAggregatorFunction(DriverContext driverContext, List channels, + public StdDevFloatAggregatorFunction(DriverContext driverContext, List channels, StdDeviationStates.SingleState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; } - public static StdDeviationFloatAggregatorFunction create(DriverContext driverContext, + public static StdDevFloatAggregatorFunction create(DriverContext driverContext, List channels) { - return new StdDeviationFloatAggregatorFunction(driverContext, channels, StdDeviationFloatAggregator.initSingle()); + return new StdDevFloatAggregatorFunction(driverContext, channels, StdDevFloatAggregator.initSingle()); } public static List intermediateStateDesc() { @@ -87,7 +87,7 @@ public void addRawInput(Page page, BooleanVector mask) { private void addRawVector(FloatVector vector) { for (int i = 0; i < vector.getPositionCount(); i++) { - StdDeviationFloatAggregator.combine(state, vector.getFloat(i)); + StdDevFloatAggregator.combine(state, vector.getFloat(i)); } } @@ -96,7 +96,7 @@ private void addRawVector(FloatVector vector, BooleanVector mask) { if (mask.getBoolean(i) == false) { continue; } - StdDeviationFloatAggregator.combine(state, vector.getFloat(i)); + StdDevFloatAggregator.combine(state, vector.getFloat(i)); } } @@ -108,7 +108,7 @@ private void addRawBlock(FloatBlock block) { int start = block.getFirstValueIndex(p); int end = start + block.getValueCount(p); for (int i = start; i < end; i++) { - StdDeviationFloatAggregator.combine(state, block.getFloat(i)); + StdDevFloatAggregator.combine(state, block.getFloat(i)); } } } @@ -124,7 +124,7 @@ private void addRawBlock(FloatBlock block, BooleanVector mask) { int start = block.getFirstValueIndex(p); int end = start + block.getValueCount(p); for (int i = start; i < end; i++) { - StdDeviationFloatAggregator.combine(state, block.getFloat(i)); + StdDevFloatAggregator.combine(state, block.getFloat(i)); } } } @@ -151,7 +151,7 @@ public void addIntermediateInput(Page page) { } LongVector count = ((LongBlock) countUncast).asVector(); assert count.getPositionCount() == 1; - StdDeviationFloatAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); + StdDevFloatAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); } @Override @@ -161,7 +161,7 @@ public void evaluateIntermediate(Block[] blocks, int offset, DriverContext drive @Override public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { - blocks[offset] = StdDeviationFloatAggregator.evaluateFinal(state, driverContext); + blocks[offset] = StdDevFloatAggregator.evaluateFinal(state, driverContext); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunctionSupplier.java similarity index 54% rename from x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunctionSupplier.java rename to x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunctionSupplier.java index b2c141647572a..25dfa54895eda 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunctionSupplier.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunctionSupplier.java @@ -11,29 +11,28 @@ import org.elasticsearch.compute.operator.DriverContext; /** - * {@link AggregatorFunctionSupplier} implementation for {@link StdDeviationLongAggregator}. + * {@link AggregatorFunctionSupplier} implementation for {@link StdDevFloatAggregator}. * This class is generated. Do not edit it. */ -public final class StdDeviationLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { +public final class StdDevFloatAggregatorFunctionSupplier implements AggregatorFunctionSupplier { private final List channels; - public StdDeviationLongAggregatorFunctionSupplier(List channels) { + public StdDevFloatAggregatorFunctionSupplier(List channels) { this.channels = channels; } @Override - public StdDeviationLongAggregatorFunction aggregator(DriverContext driverContext) { - return StdDeviationLongAggregatorFunction.create(driverContext, channels); + public StdDevFloatAggregatorFunction aggregator(DriverContext driverContext) { + return StdDevFloatAggregatorFunction.create(driverContext, channels); } @Override - public StdDeviationLongGroupingAggregatorFunction groupingAggregator( - DriverContext driverContext) { - return StdDeviationLongGroupingAggregatorFunction.create(channels, driverContext); + public StdDevFloatGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return StdDevFloatGroupingAggregatorFunction.create(channels, driverContext); } @Override public String describe() { - return "std_deviation of longs"; + return "std_dev of floats"; } } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java similarity index 84% rename from x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java rename to x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java index f22d6e866f7e7..cf79620e19ace 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java @@ -23,10 +23,10 @@ import org.elasticsearch.compute.operator.DriverContext; /** - * {@link GroupingAggregatorFunction} implementation for {@link StdDeviationFloatAggregator}. + * {@link GroupingAggregatorFunction} implementation for {@link StdDevFloatAggregator}. * This class is generated. Do not edit it. */ -public final class StdDeviationFloatGroupingAggregatorFunction implements GroupingAggregatorFunction { +public final class StdDevFloatGroupingAggregatorFunction implements GroupingAggregatorFunction { private static final List INTERMEDIATE_STATE_DESC = List.of( new IntermediateStateDesc("mean", ElementType.DOUBLE), new IntermediateStateDesc("m2", ElementType.DOUBLE), @@ -38,16 +38,16 @@ public final class StdDeviationFloatGroupingAggregatorFunction implements Groupi private final DriverContext driverContext; - public StdDeviationFloatGroupingAggregatorFunction(List channels, + public StdDevFloatGroupingAggregatorFunction(List channels, StdDeviationStates.GroupingState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; } - public static StdDeviationFloatGroupingAggregatorFunction create(List channels, + public static StdDevFloatGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new StdDeviationFloatGroupingAggregatorFunction(channels, StdDeviationFloatAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new StdDevFloatGroupingAggregatorFunction(channels, StdDevFloatAggregator.initGrouping(driverContext.bigArrays()), driverContext); } public static List intermediateStateDesc() { @@ -110,7 +110,7 @@ private void addRawInput(int positionOffset, IntVector groups, FloatBlock values int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - StdDeviationFloatAggregator.combine(state, groupId, values.getFloat(v)); + StdDevFloatAggregator.combine(state, groupId, values.getFloat(v)); } } } @@ -118,7 +118,7 @@ private void addRawInput(int positionOffset, IntVector groups, FloatBlock values private void addRawInput(int positionOffset, IntVector groups, FloatVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); - StdDeviationFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); + StdDevFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); } } @@ -137,7 +137,7 @@ private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values) int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - StdDeviationFloatAggregator.combine(state, groupId, values.getFloat(v)); + StdDevFloatAggregator.combine(state, groupId, values.getFloat(v)); } } } @@ -152,7 +152,7 @@ private void addRawInput(int positionOffset, IntBlock groups, FloatVector values int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - StdDeviationFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); + StdDevFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); } } } @@ -184,7 +184,7 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); - StdDeviationFloatAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + StdDevFloatAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); } } @@ -193,9 +193,9 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationStates.GroupingState inState = ((StdDeviationFloatGroupingAggregatorFunction) input).state; + StdDeviationStates.GroupingState inState = ((StdDevFloatGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); - StdDeviationFloatAggregator.combineStates(state, groupId, inState, position); + StdDevFloatAggregator.combineStates(state, groupId, inState, position); } @Override @@ -206,7 +206,7 @@ public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) @Override public void evaluateFinal(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { - blocks[offset] = StdDeviationFloatAggregator.evaluateFinal(state, selected, driverContext); + blocks[offset] = StdDevFloatAggregator.evaluateFinal(state, selected, driverContext); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunction.java similarity index 84% rename from x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java rename to x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunction.java index 5bfd4fb495f29..a499ae4698819 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunction.java @@ -22,10 +22,10 @@ import org.elasticsearch.compute.operator.DriverContext; /** - * {@link AggregatorFunction} implementation for {@link StdDeviationIntAggregator}. + * {@link AggregatorFunction} implementation for {@link StdDevIntAggregator}. * This class is generated. Do not edit it. */ -public final class StdDeviationIntAggregatorFunction implements AggregatorFunction { +public final class StdDevIntAggregatorFunction implements AggregatorFunction { private static final List INTERMEDIATE_STATE_DESC = List.of( new IntermediateStateDesc("mean", ElementType.DOUBLE), new IntermediateStateDesc("m2", ElementType.DOUBLE), @@ -37,16 +37,16 @@ public final class StdDeviationIntAggregatorFunction implements AggregatorFuncti private final List channels; - public StdDeviationIntAggregatorFunction(DriverContext driverContext, List channels, + public StdDevIntAggregatorFunction(DriverContext driverContext, List channels, StdDeviationStates.SingleState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; } - public static StdDeviationIntAggregatorFunction create(DriverContext driverContext, + public static StdDevIntAggregatorFunction create(DriverContext driverContext, List channels) { - return new StdDeviationIntAggregatorFunction(driverContext, channels, StdDeviationIntAggregator.initSingle()); + return new StdDevIntAggregatorFunction(driverContext, channels, StdDevIntAggregator.initSingle()); } public static List intermediateStateDesc() { @@ -87,7 +87,7 @@ public void addRawInput(Page page, BooleanVector mask) { private void addRawVector(IntVector vector) { for (int i = 0; i < vector.getPositionCount(); i++) { - StdDeviationIntAggregator.combine(state, vector.getInt(i)); + StdDevIntAggregator.combine(state, vector.getInt(i)); } } @@ -96,7 +96,7 @@ private void addRawVector(IntVector vector, BooleanVector mask) { if (mask.getBoolean(i) == false) { continue; } - StdDeviationIntAggregator.combine(state, vector.getInt(i)); + StdDevIntAggregator.combine(state, vector.getInt(i)); } } @@ -108,7 +108,7 @@ private void addRawBlock(IntBlock block) { int start = block.getFirstValueIndex(p); int end = start + block.getValueCount(p); for (int i = start; i < end; i++) { - StdDeviationIntAggregator.combine(state, block.getInt(i)); + StdDevIntAggregator.combine(state, block.getInt(i)); } } } @@ -124,7 +124,7 @@ private void addRawBlock(IntBlock block, BooleanVector mask) { int start = block.getFirstValueIndex(p); int end = start + block.getValueCount(p); for (int i = start; i < end; i++) { - StdDeviationIntAggregator.combine(state, block.getInt(i)); + StdDevIntAggregator.combine(state, block.getInt(i)); } } } @@ -151,7 +151,7 @@ public void addIntermediateInput(Page page) { } LongVector count = ((LongBlock) countUncast).asVector(); assert count.getPositionCount() == 1; - StdDeviationIntAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); + StdDevIntAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); } @Override @@ -161,7 +161,7 @@ public void evaluateIntermediate(Block[] blocks, int offset, DriverContext drive @Override public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { - blocks[offset] = StdDeviationIntAggregator.evaluateFinal(state, driverContext); + blocks[offset] = StdDevIntAggregator.evaluateFinal(state, driverContext); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunctionSupplier.java similarity index 54% rename from x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunctionSupplier.java rename to x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunctionSupplier.java index e761b32fc8f0c..5a762d6606a25 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationFloatAggregatorFunctionSupplier.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunctionSupplier.java @@ -11,29 +11,28 @@ import org.elasticsearch.compute.operator.DriverContext; /** - * {@link AggregatorFunctionSupplier} implementation for {@link StdDeviationFloatAggregator}. + * {@link AggregatorFunctionSupplier} implementation for {@link StdDevIntAggregator}. * This class is generated. Do not edit it. */ -public final class StdDeviationFloatAggregatorFunctionSupplier implements AggregatorFunctionSupplier { +public final class StdDevIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier { private final List channels; - public StdDeviationFloatAggregatorFunctionSupplier(List channels) { + public StdDevIntAggregatorFunctionSupplier(List channels) { this.channels = channels; } @Override - public StdDeviationFloatAggregatorFunction aggregator(DriverContext driverContext) { - return StdDeviationFloatAggregatorFunction.create(driverContext, channels); + public StdDevIntAggregatorFunction aggregator(DriverContext driverContext) { + return StdDevIntAggregatorFunction.create(driverContext, channels); } @Override - public StdDeviationFloatGroupingAggregatorFunction groupingAggregator( - DriverContext driverContext) { - return StdDeviationFloatGroupingAggregatorFunction.create(channels, driverContext); + public StdDevIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return StdDevIntGroupingAggregatorFunction.create(channels, driverContext); } @Override public String describe() { - return "std_deviation of floats"; + return "std_dev of ints"; } } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java similarity index 84% rename from x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java rename to x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java index 661bc068be26f..6196f89f3d300 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java @@ -21,10 +21,10 @@ import org.elasticsearch.compute.operator.DriverContext; /** - * {@link GroupingAggregatorFunction} implementation for {@link StdDeviationIntAggregator}. + * {@link GroupingAggregatorFunction} implementation for {@link StdDevIntAggregator}. * This class is generated. Do not edit it. */ -public final class StdDeviationIntGroupingAggregatorFunction implements GroupingAggregatorFunction { +public final class StdDevIntGroupingAggregatorFunction implements GroupingAggregatorFunction { private static final List INTERMEDIATE_STATE_DESC = List.of( new IntermediateStateDesc("mean", ElementType.DOUBLE), new IntermediateStateDesc("m2", ElementType.DOUBLE), @@ -36,16 +36,16 @@ public final class StdDeviationIntGroupingAggregatorFunction implements Grouping private final DriverContext driverContext; - public StdDeviationIntGroupingAggregatorFunction(List channels, + public StdDevIntGroupingAggregatorFunction(List channels, StdDeviationStates.GroupingState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; } - public static StdDeviationIntGroupingAggregatorFunction create(List channels, + public static StdDevIntGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new StdDeviationIntGroupingAggregatorFunction(channels, StdDeviationIntAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new StdDevIntGroupingAggregatorFunction(channels, StdDevIntAggregator.initGrouping(driverContext.bigArrays()), driverContext); } public static List intermediateStateDesc() { @@ -108,7 +108,7 @@ private void addRawInput(int positionOffset, IntVector groups, IntBlock values) int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - StdDeviationIntAggregator.combine(state, groupId, values.getInt(v)); + StdDevIntAggregator.combine(state, groupId, values.getInt(v)); } } } @@ -116,7 +116,7 @@ private void addRawInput(int positionOffset, IntVector groups, IntBlock values) private void addRawInput(int positionOffset, IntVector groups, IntVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); - StdDeviationIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); + StdDevIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); } } @@ -135,7 +135,7 @@ private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) { int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - StdDeviationIntAggregator.combine(state, groupId, values.getInt(v)); + StdDevIntAggregator.combine(state, groupId, values.getInt(v)); } } } @@ -150,7 +150,7 @@ private void addRawInput(int positionOffset, IntBlock groups, IntVector values) int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - StdDeviationIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); + StdDevIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); } } } @@ -182,7 +182,7 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); - StdDeviationIntAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + StdDevIntAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); } } @@ -191,9 +191,9 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationStates.GroupingState inState = ((StdDeviationIntGroupingAggregatorFunction) input).state; + StdDeviationStates.GroupingState inState = ((StdDevIntGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); - StdDeviationIntAggregator.combineStates(state, groupId, inState, position); + StdDevIntAggregator.combineStates(state, groupId, inState, position); } @Override @@ -204,7 +204,7 @@ public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) @Override public void evaluateFinal(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { - blocks[offset] = StdDeviationIntAggregator.evaluateFinal(state, selected, driverContext); + blocks[offset] = StdDevIntAggregator.evaluateFinal(state, selected, driverContext); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunction.java similarity index 83% rename from x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java rename to x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunction.java index d7644c8286b64..ec5f38833ff0b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunction.java @@ -20,10 +20,10 @@ import org.elasticsearch.compute.operator.DriverContext; /** - * {@link AggregatorFunction} implementation for {@link StdDeviationLongAggregator}. + * {@link AggregatorFunction} implementation for {@link StdDevLongAggregator}. * This class is generated. Do not edit it. */ -public final class StdDeviationLongAggregatorFunction implements AggregatorFunction { +public final class StdDevLongAggregatorFunction implements AggregatorFunction { private static final List INTERMEDIATE_STATE_DESC = List.of( new IntermediateStateDesc("mean", ElementType.DOUBLE), new IntermediateStateDesc("m2", ElementType.DOUBLE), @@ -35,16 +35,16 @@ public final class StdDeviationLongAggregatorFunction implements AggregatorFunct private final List channels; - public StdDeviationLongAggregatorFunction(DriverContext driverContext, List channels, + public StdDevLongAggregatorFunction(DriverContext driverContext, List channels, StdDeviationStates.SingleState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; } - public static StdDeviationLongAggregatorFunction create(DriverContext driverContext, + public static StdDevLongAggregatorFunction create(DriverContext driverContext, List channels) { - return new StdDeviationLongAggregatorFunction(driverContext, channels, StdDeviationLongAggregator.initSingle()); + return new StdDevLongAggregatorFunction(driverContext, channels, StdDevLongAggregator.initSingle()); } public static List intermediateStateDesc() { @@ -85,7 +85,7 @@ public void addRawInput(Page page, BooleanVector mask) { private void addRawVector(LongVector vector) { for (int i = 0; i < vector.getPositionCount(); i++) { - StdDeviationLongAggregator.combine(state, vector.getLong(i)); + StdDevLongAggregator.combine(state, vector.getLong(i)); } } @@ -94,7 +94,7 @@ private void addRawVector(LongVector vector, BooleanVector mask) { if (mask.getBoolean(i) == false) { continue; } - StdDeviationLongAggregator.combine(state, vector.getLong(i)); + StdDevLongAggregator.combine(state, vector.getLong(i)); } } @@ -106,7 +106,7 @@ private void addRawBlock(LongBlock block) { int start = block.getFirstValueIndex(p); int end = start + block.getValueCount(p); for (int i = start; i < end; i++) { - StdDeviationLongAggregator.combine(state, block.getLong(i)); + StdDevLongAggregator.combine(state, block.getLong(i)); } } } @@ -122,7 +122,7 @@ private void addRawBlock(LongBlock block, BooleanVector mask) { int start = block.getFirstValueIndex(p); int end = start + block.getValueCount(p); for (int i = start; i < end; i++) { - StdDeviationLongAggregator.combine(state, block.getLong(i)); + StdDevLongAggregator.combine(state, block.getLong(i)); } } } @@ -149,7 +149,7 @@ public void addIntermediateInput(Page page) { } LongVector count = ((LongBlock) countUncast).asVector(); assert count.getPositionCount() == 1; - StdDeviationLongAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); + StdDevLongAggregator.combineIntermediate(state, mean.getDouble(0), m2.getDouble(0), count.getLong(0)); } @Override @@ -159,7 +159,7 @@ public void evaluateIntermediate(Block[] blocks, int offset, DriverContext drive @Override public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { - blocks[offset] = StdDeviationLongAggregator.evaluateFinal(state, driverContext); + blocks[offset] = StdDevLongAggregator.evaluateFinal(state, driverContext); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunctionSupplier.java similarity index 53% rename from x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunctionSupplier.java rename to x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunctionSupplier.java index 28915079c4b7a..09b996201ef16 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationDoubleAggregatorFunctionSupplier.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunctionSupplier.java @@ -11,29 +11,28 @@ import org.elasticsearch.compute.operator.DriverContext; /** - * {@link AggregatorFunctionSupplier} implementation for {@link StdDeviationDoubleAggregator}. + * {@link AggregatorFunctionSupplier} implementation for {@link StdDevLongAggregator}. * This class is generated. Do not edit it. */ -public final class StdDeviationDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier { +public final class StdDevLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { private final List channels; - public StdDeviationDoubleAggregatorFunctionSupplier(List channels) { + public StdDevLongAggregatorFunctionSupplier(List channels) { this.channels = channels; } @Override - public StdDeviationDoubleAggregatorFunction aggregator(DriverContext driverContext) { - return StdDeviationDoubleAggregatorFunction.create(driverContext, channels); + public StdDevLongAggregatorFunction aggregator(DriverContext driverContext) { + return StdDevLongAggregatorFunction.create(driverContext, channels); } @Override - public StdDeviationDoubleGroupingAggregatorFunction groupingAggregator( - DriverContext driverContext) { - return StdDeviationDoubleGroupingAggregatorFunction.create(channels, driverContext); + public StdDevLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return StdDevLongGroupingAggregatorFunction.create(channels, driverContext); } @Override public String describe() { - return "std_deviation of doubles"; + return "std_dev of longs"; } } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java similarity index 84% rename from x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java rename to x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java index 5091e84840c76..7e16fa2967c74 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDeviationLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java @@ -21,10 +21,10 @@ import org.elasticsearch.compute.operator.DriverContext; /** - * {@link GroupingAggregatorFunction} implementation for {@link StdDeviationLongAggregator}. + * {@link GroupingAggregatorFunction} implementation for {@link StdDevLongAggregator}. * This class is generated. Do not edit it. */ -public final class StdDeviationLongGroupingAggregatorFunction implements GroupingAggregatorFunction { +public final class StdDevLongGroupingAggregatorFunction implements GroupingAggregatorFunction { private static final List INTERMEDIATE_STATE_DESC = List.of( new IntermediateStateDesc("mean", ElementType.DOUBLE), new IntermediateStateDesc("m2", ElementType.DOUBLE), @@ -36,16 +36,16 @@ public final class StdDeviationLongGroupingAggregatorFunction implements Groupin private final DriverContext driverContext; - public StdDeviationLongGroupingAggregatorFunction(List channels, + public StdDevLongGroupingAggregatorFunction(List channels, StdDeviationStates.GroupingState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; } - public static StdDeviationLongGroupingAggregatorFunction create(List channels, + public static StdDevLongGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new StdDeviationLongGroupingAggregatorFunction(channels, StdDeviationLongAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new StdDevLongGroupingAggregatorFunction(channels, StdDevLongAggregator.initGrouping(driverContext.bigArrays()), driverContext); } public static List intermediateStateDesc() { @@ -108,7 +108,7 @@ private void addRawInput(int positionOffset, IntVector groups, LongBlock values) int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - StdDeviationLongAggregator.combine(state, groupId, values.getLong(v)); + StdDevLongAggregator.combine(state, groupId, values.getLong(v)); } } } @@ -116,7 +116,7 @@ private void addRawInput(int positionOffset, IntVector groups, LongBlock values) private void addRawInput(int positionOffset, IntVector groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); - StdDeviationLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); + StdDevLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); } } @@ -135,7 +135,7 @@ private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - StdDeviationLongAggregator.combine(state, groupId, values.getLong(v)); + StdDevLongAggregator.combine(state, groupId, values.getLong(v)); } } } @@ -150,7 +150,7 @@ private void addRawInput(int positionOffset, IntBlock groups, LongVector values) int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - StdDeviationLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); + StdDevLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); } } } @@ -182,7 +182,7 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); - StdDeviationLongAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + StdDevLongAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); } } @@ -191,9 +191,9 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationStates.GroupingState inState = ((StdDeviationLongGroupingAggregatorFunction) input).state; + StdDeviationStates.GroupingState inState = ((StdDevLongGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); - StdDeviationLongAggregator.combineStates(state, groupId, inState, position); + StdDevLongAggregator.combineStates(state, groupId, inState, position); } @Override @@ -204,7 +204,7 @@ public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) @Override public void evaluateFinal(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { - blocks[offset] = StdDeviationLongAggregator.evaluateFinal(state, selected, driverContext); + blocks[offset] = StdDevLongAggregator.evaluateFinal(state, selected, driverContext); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st similarity index 98% rename from x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st rename to x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st index 3ec96eab43321..cf52ee572a727 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDeviationAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st @@ -27,7 +27,7 @@ import org.elasticsearch.compute.operator.DriverContext; @IntermediateState(name = "count", type = "LONG") } ) @GroupingAggregator -public class StdDeviation$Type$Aggregator { +public class StdDev$Type$Aggregator { public static StdDeviationStates.SingleState initSingle() { return new StdDeviationStates.SingleState(); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index 5435d13233db5..1fa3e71018325 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -2687,143 +2687,143 @@ max:integer | job_positions:keyword ; stdDeviation -required_capability: std_deviation +required_capability: std_dev // tag::stdev[] FROM employees -| STATS STD_DEVIATION(height) +| STATS STD_DEV(height) // end::stdev[] ; // tag::stdev-result[] -STD_DEVIATION(height):double +STD_DEV(height):double 0.20637044362020449 // end::stdev-result[] ; stdDeviationNested -required_capability: std_deviation -// tag::docsStatsStdDeviationNestedExpression[] +required_capability: std_dev +// tag::docsStatsStdDevNestedExpression[] FROM employees -| STATS stdev_salary_change = STD_DEVIATION(MV_MAX(salary_change)) -// end::docsStatsStdDeviationNestedExpression[] +| STATS stddev_salary_change = STD_DEV(MV_MAX(salary_change)) +// end::docsStatsStdDevNestedExpression[] ; -// tag::docsStatsStdDeviationNestedExpression-result[] -stdev_salary_change:double +// tag::docsStatsStdDevNestedExpression-result[] +stddev_salary_change:double 6.875829592924112 -// end::docsStatsStdDeviationNestedExpression-result[] +// end::docsStatsStdDevNestedExpression-result[] ; stdDeviationWithLongs -required_capability: std_deviation +required_capability: std_dev FROM employees -| STATS STD_DEVIATION(avg_worked_seconds) +| STATS STD_DEV(avg_worked_seconds) ; -STD_DEVIATION(avg_worked_seconds):double +STD_DEV(avg_worked_seconds):double 5.76010425971634E7 ; stdDeviationWithInts -required_capability: std_deviation +required_capability: std_dev FROM employees -| STATS STD_DEVIATION(salary) +| STATS STD_DEV(salary) ; -STD_DEVIATION(salary):double +STD_DEV(salary):double 13765.12550278783 ; stdDeviationConstantValue -required_capability: std_deviation +required_capability: std_dev FROM employees | WHERE languages == 2 -| STATS STD_DEVIATION(languages) +| STATS STD_DEV(languages) ; -STD_DEVIATION(languages):double +STD_DEV(languages):double 0.0 ; stdDeviationGrouped -required_capability: std_deviation +required_capability: std_dev FROM employees -| STATS STD_DEVIATION(height) BY languages +| STATS STD_DEV(height) BY languages | SORT languages asc ; -STD_DEVIATION(height):double | languages:integer -0.22106409327010415 | 1 -0.22797190865484734 | 2 -0.18893070075713295 | 3 -0.14656141004227627 | 4 -0.17733860152780256 | 5 -0.2486543786061287 | null +STD_DEV(height):double | languages:integer +0.22106409327010415 | 1 +0.22797190865484734 | 2 +0.18893070075713295 | 3 +0.14656141004227627 | 4 +0.17733860152780256 | 5 +0.2486543786061287 | null ; stdDeviationGrouped1 -required_capability: std_deviation +required_capability: std_dev FROM employees | WHERE languages == 1 -| STATS STD_DEVIATION(height) +| STATS STD_DEV(height) ; -STD_DEVIATION(height):double +STD_DEV(height):double 0.22106409327010415 ; stdDeviationGrouped2 -required_capability: std_deviation +required_capability: std_dev FROM employees | WHERE languages == 2 -| STATS STD_DEVIATION(height) +| STATS STD_DEV(height) ; -STD_DEVIATION(height):double +STD_DEV(height):double 0.22797190865484734 ; stdDeviationGrouped3 -required_capability: std_deviation +required_capability: std_dev FROM employees | WHERE languages == 3 -| STATS STD_DEVIATION(height) +| STATS STD_DEV(height) ; -STD_DEVIATION(height):double +STD_DEV(height):double 0.18893070075713295 ; stdDeviationGrouped4 -required_capability: std_deviation +required_capability: std_dev FROM employees | WHERE languages == 4 -| STATS STD_DEVIATION(height) +| STATS STD_DEV(height) ; -STD_DEVIATION(height):double +STD_DEV(height):double 0.14656141004227627 ; stdDeviationGrouped5 -required_capability: std_deviation +required_capability: std_dev FROM employees | WHERE languages == 5 -| STATS STD_DEVIATION(height) +| STATS STD_DEV(height) ; -STD_DEVIATION(height):double +STD_DEV(height):double 0.17733860152780256 ; stdDeviationNoRows -required_capability: std_deviation +required_capability: std_dev FROM employees | WHERE languages IS null -| STATS STD_DEVIATION(languages) +| STATS STD_DEV(languages) ; -STD_DEVIATION(languages):double +STD_DEV(languages):double null ; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index f88eb24f52c7e..f295898709319 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -431,9 +431,9 @@ public enum Cap { PER_AGG_FILTERING_ORDS, /** - * Support for {@code STD_DEVIATION} aggregation. + * Support for {@code STD_DEV} aggregation. */ - STD_DEVIATION, + STD_DEV, /** * Fix for https://github.com/elastic/elasticsearch/issues/114714 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index acf5ca3143cbc..f5261c9c63969 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -28,7 +28,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; -import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDeviation; +import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; import org.elasticsearch.xpack.esql.expression.function.aggregate.Top; import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; @@ -274,7 +274,7 @@ private FunctionDefinition[][] functions() { def(MedianAbsoluteDeviation.class, uni(MedianAbsoluteDeviation::new), "median_absolute_deviation"), def(Min.class, uni(Min::new), "min"), def(Percentile.class, bi(Percentile::new), "percentile"), - def(StdDeviation.class, uni(StdDeviation::new), "std_deviation"), + def(StdDev.class, uni(StdDev::new), "std_dev"), def(Sum.class, uni(Sum::new), "sum"), def(Top.class, tri(Top::new), "top"), def(Values.class, uni(Values::new), "values"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java index 5e661979ba3b9..19b037be23136 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java @@ -42,7 +42,7 @@ public static List getNamedWriteables() { Percentile.ENTRY, Rate.ENTRY, SpatialCentroid.ENTRY, - StdDeviation.ENTRY, + StdDev.ENTRY, Sum.ENTRY, Top.ENTRY, Values.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java similarity index 64% rename from x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDeviation.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java index 301bb5b519b52..03e9cfcb81d31 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java @@ -10,9 +10,9 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; -import org.elasticsearch.compute.aggregation.StdDeviationDoubleAggregatorFunctionSupplier; -import org.elasticsearch.compute.aggregation.StdDeviationIntAggregatorFunctionSupplier; -import org.elasticsearch.compute.aggregation.StdDeviationLongAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.StdDevDoubleAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.StdDevIntAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.StdDevLongAggregatorFunctionSupplier; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; @@ -29,11 +29,11 @@ import static java.util.Collections.emptyList; -public class StdDeviation extends AggregateFunction implements ToAggregator { +public class StdDev extends AggregateFunction implements ToAggregator { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, - "StdDeviation", - StdDeviation::new + "StdDev", + StdDev::new ); @FunctionInfo( @@ -45,20 +45,20 @@ public class StdDeviation extends AggregateFunction implements ToAggregator { @Example( description = "The expression can use inline functions. For example, to calculate the standard " + "deviation of each employee's maximum salary changes, first use `MV_MAX` on each row, " - + "and then use `StdDeviation` on the result", + + "and then use `STD_DEV` on the result", file = "stats", - tag = "docsStatsStdDeviationNestedExpression" + tag = "docsStatsStdDevNestedExpression" ) } ) - public StdDeviation(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { + public StdDev(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { this(source, field, Literal.TRUE); } - public StdDeviation(Source source, Expression field, Expression filter) { + public StdDev(Source source, Expression field, Expression filter) { super(source, field, filter, emptyList()); } - private StdDeviation(StreamInput in) throws IOException { + private StdDev(StreamInput in) throws IOException { super(in); } @@ -73,30 +73,30 @@ public DataType dataType() { } @Override - protected NodeInfo info() { - return NodeInfo.create(this, StdDeviation::new, field(), filter()); + protected NodeInfo info() { + return NodeInfo.create(this, StdDev::new, field(), filter()); } @Override - public StdDeviation replaceChildren(List newChildren) { - return new StdDeviation(source(), newChildren.get(0), newChildren.get(1)); + public StdDev replaceChildren(List newChildren) { + return new StdDev(source(), newChildren.get(0), newChildren.get(1)); } - public StdDeviation withFilter(Expression filter) { - return new StdDeviation(source(), field(), filter); + public StdDev withFilter(Expression filter) { + return new StdDev(source(), field(), filter); } @Override public final AggregatorFunctionSupplier supplier(List inputChannels) { DataType type = field().dataType(); if (type == DataType.LONG) { - return new StdDeviationLongAggregatorFunctionSupplier(inputChannels); + return new StdDevLongAggregatorFunctionSupplier(inputChannels); } if (type == DataType.INTEGER) { - return new StdDeviationIntAggregatorFunctionSupplier(inputChannels); + return new StdDevIntAggregatorFunctionSupplier(inputChannels); } if (type == DataType.DOUBLE) { - return new StdDeviationDoubleAggregatorFunctionSupplier(inputChannels); + return new StdDevDoubleAggregatorFunctionSupplier(inputChannels); } throw EsqlIllegalArgumentException.illegalDataType(type); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java index 5bcbf69ae63db..605e0d7c3109c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java @@ -34,7 +34,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; -import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDeviation; +import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; import org.elasticsearch.xpack.esql.expression.function.aggregate.ToPartial; import org.elasticsearch.xpack.esql.expression.function.aggregate.Top; @@ -49,9 +49,6 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT; -import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT; - /** * Static class used to convert aggregate expressions to the named expressions that represent their intermediate state. *

@@ -79,7 +76,7 @@ final class AggregateMapper { Min.class, Percentile.class, SpatialCentroid.class, - StdDeviation.class, + StdDev.class, Sum.class, Values.class, Top.class, @@ -173,7 +170,7 @@ private static Stream, Tuple>> typeAndNames(Class types = List.of("Int", "Long", "Double", "Boolean", "BytesRef"); } else if (Top.class.isAssignableFrom(clazz)) { types = List.of("Boolean", "Int", "Long", "Double", "Ip", "BytesRef"); - } else if (Rate.class.isAssignableFrom(clazz) || StdDeviation.class.isAssignableFrom(clazz)) { + } else if (Rate.class.isAssignableFrom(clazz) || StdDev.class.isAssignableFrom(clazz)) { types = List.of("Int", "Long", "Double"); } else if (FromPartial.class.isAssignableFrom(clazz) || ToPartial.class.isAssignableFrom(clazz)) { types = List.of(""); // no type diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDeviationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDevTests.java similarity index 87% rename from x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDeviationTests.java rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDevTests.java index f107c5623a35f..faec43dc563c2 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDeviationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDevTests.java @@ -27,8 +27,8 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.nullValue; -public class StdDeviationTests extends AbstractAggregationTestCase { - public StdDeviationTests(@Name("TestCase") Supplier testCaseSupplier) { +public class StdDevTests extends AbstractAggregationTestCase { + public StdDevTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @@ -40,7 +40,7 @@ public static Iterable parameters() { MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true) - ).flatMap(List::stream).map(StdDeviationTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers)); + ).flatMap(List::stream).map(StdDevTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers)); // No rows for (var dataType : List.of(DataType.INTEGER, DataType.LONG, DataType.DOUBLE)) { @@ -50,7 +50,7 @@ public static Iterable parameters() { List.of(dataType), () -> new TestCaseSupplier.TestCase( List.of(TestCaseSupplier.TypedData.multiRow(List.of(), dataType, "field")), - "StdDeviation[field=Attribute[channel=0]]", + "StdDev[field=Attribute[channel=0]]", DataType.DOUBLE, nullValue() ) @@ -62,7 +62,7 @@ public static Iterable parameters() { @Override protected Expression build(Source source, List args) { - return new StdDeviation(source, args.get(0)); + return new StdDev(source, args.get(0)); } private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { @@ -80,7 +80,7 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier var expected = Double.isInfinite(result) ? null : result; return new TestCaseSupplier.TestCase( List.of(fieldTypedData), - "StdDeviation[field=Attribute[channel=0]]", + "StdDev[field=Attribute[channel=0]]", DataType.DOUBLE, equalTo(expected) ); From 35039349e7dba87d38ccc3289ab991aa3a79471f Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Thu, 14 Nov 2024 12:23:36 -1000 Subject: [PATCH 08/20] fix parallel algorithm --- .../org/elasticsearch/compute/aggregation/WelfordAlgorithm.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java index 6d8d0beb51c22..29cb7e62e1a8c 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java @@ -69,8 +69,8 @@ public void add(double meanValue, double m2Value, long countValue) { } double delta = mean - meanValue; m2 += m2Value + delta * delta * count * countValue / (count + countValue); + mean = (mean * count + meanValue * countValue) / (count + countValue); count += countValue; - mean += delta * countValue / (count); } public double evaluate() { From 2b4f80daf71f790732cebdb17c1c2825940daba2 Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Thu, 14 Nov 2024 12:48:36 -1000 Subject: [PATCH 09/20] whoops docs --- docs/reference/esql/functions/aggregation-functions.asciidoc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/reference/esql/functions/aggregation-functions.asciidoc b/docs/reference/esql/functions/aggregation-functions.asciidoc index 2c9d831dd05c2..9e8ddb483f8f4 100644 --- a/docs/reference/esql/functions/aggregation-functions.asciidoc +++ b/docs/reference/esql/functions/aggregation-functions.asciidoc @@ -33,7 +33,7 @@ include::layout/median_absolute_deviation.asciidoc[] include::layout/min.asciidoc[] include::layout/percentile.asciidoc[] include::layout/st_centroid_agg.asciidoc[] -include::layout/std_deviation.asciidoc[] +include::layout/std_dev.asciidoc[] include::layout/sum.asciidoc[] include::layout/top.asciidoc[] include::layout/values.asciidoc[] From bb329478a3d7342d8307bedd3d5d68715ebf1168 Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Thu, 14 Nov 2024 13:01:51 -1000 Subject: [PATCH 10/20] lint --- .../xpack/esql/expression/function/aggregate/StdDev.java | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java index 03e9cfcb81d31..dbd1da348d8ba 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java @@ -30,11 +30,7 @@ import static java.util.Collections.emptyList; public class StdDev extends AggregateFunction implements ToAggregator { - public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( - Expression.class, - "StdDev", - StdDev::new - ); + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "StdDev", StdDev::new); @FunctionInfo( returnType = "double", From f69d640ce6db855f43b111f6078045ef465895a6 Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Thu, 14 Nov 2024 14:49:59 -1000 Subject: [PATCH 11/20] more renaming --- x-pack/plugin/esql/compute/build.gradle | 10 +++---- .../aggregation/StdDevDoubleAggregator.java | 26 +++++++++---------- .../aggregation/StdDevFloatAggregator.java | 26 +++++++++---------- .../aggregation/StdDevIntAggregator.java | 26 +++++++++---------- .../aggregation/StdDevLongAggregator.java | 26 +++++++++---------- .../StdDevDoubleAggregatorFunction.java | 4 +-- ...tdDevDoubleGroupingAggregatorFunction.java | 6 ++--- .../StdDevFloatAggregatorFunction.java | 4 +-- ...StdDevFloatGroupingAggregatorFunction.java | 6 ++--- .../StdDevIntAggregatorFunction.java | 4 +-- .../StdDevIntGroupingAggregatorFunction.java | 6 ++--- .../StdDevLongAggregatorFunction.java | 4 +-- .../StdDevLongGroupingAggregatorFunction.java | 6 ++--- ...DeviationStates.java => StdDevStates.java} | 4 +-- .../aggregation/X-StdDevAggregator.java.st | 26 +++++++++---------- 15 files changed, 92 insertions(+), 92 deletions(-) rename x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/{StdDeviationStates.java => StdDevStates.java} (98%) diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index da7c7d876ce6e..609c778df5929 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -608,25 +608,25 @@ tasks.named('stringTemplates').configure { it.outputFile = "org/elasticsearch/compute/aggregation/RateDoubleAggregator.java" } - File stdDeviationAggregatorInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st") + File stdDevAggregatorInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st") template { it.properties = intProperties - it.inputFile = stdDeviationAggregatorInputFile + it.inputFile = stdDevAggregatorInputFile it.outputFile = "org/elasticsearch/compute/aggregation/StdDevIntAggregator.java" } template { it.properties = longProperties - it.inputFile = stdDeviationAggregatorInputFile + it.inputFile = stdDevAggregatorInputFile it.outputFile = "org/elasticsearch/compute/aggregation/StdDevLongAggregator.java" } template { it.properties = floatProperties - it.inputFile = stdDeviationAggregatorInputFile + it.inputFile = stdDevAggregatorInputFile it.outputFile = "org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java" } template { it.properties = doubleProperties - it.inputFile = stdDeviationAggregatorInputFile + it.inputFile = stdDevAggregatorInputFile it.outputFile = "org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java" } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java index babd7e25f5027..67453422a6b74 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java @@ -18,7 +18,7 @@ /** * A standard deviation aggregation definition for double. - * This class is generated. Edit `X-StdDeviationAggregator.java.st` instead. + * This class is generated. Edit `X-StdDevAggregator.java.st` instead. */ @Aggregator( { @@ -29,19 +29,19 @@ @GroupingAggregator public class StdDevDoubleAggregator { - public static StdDeviationStates.SingleState initSingle() { - return new StdDeviationStates.SingleState(); + public static StdDevStates.SingleState initSingle() { + return new StdDevStates.SingleState(); } - public static void combine(StdDeviationStates.SingleState state, double value) { + public static void combine(StdDevStates.SingleState state, double value) { state.add(value); } - public static void combineIntermediate(StdDeviationStates.SingleState state, double mean, double m2, long count) { + public static void combineIntermediate(StdDevStates.SingleState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.SingleState state, DriverContext driverContext) { + public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -50,28 +50,28 @@ public static Block evaluateFinal(StdDeviationStates.SingleState state, DriverCo return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static StdDeviationStates.GroupingState initGrouping(BigArrays bigArrays) { - return new StdDeviationStates.GroupingState(bigArrays); + public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDevStates.GroupingState(bigArrays); } - public static void combine(StdDeviationStates.GroupingState current, int groupId, double value) { + public static void combine(StdDevStates.GroupingState current, int groupId, double value) { current.add(groupId, value); } public static void combineStates( - StdDeviationStates.GroupingState current, + StdDevStates.GroupingState current, int groupId, - StdDeviationStates.GroupingState state, + StdDevStates.GroupingState state, int statePosition ) { current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate(StdDeviationStates.GroupingState state, int groupId, double mean, double m2, long count) { + public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.GroupingState state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java index d118b913cc880..e24cf03752862 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java @@ -18,7 +18,7 @@ /** * A standard deviation aggregation definition for float. - * This class is generated. Edit `X-StdDeviationAggregator.java.st` instead. + * This class is generated. Edit `X-StdDevAggregator.java.st` instead. */ @Aggregator( { @@ -29,19 +29,19 @@ @GroupingAggregator public class StdDevFloatAggregator { - public static StdDeviationStates.SingleState initSingle() { - return new StdDeviationStates.SingleState(); + public static StdDevStates.SingleState initSingle() { + return new StdDevStates.SingleState(); } - public static void combine(StdDeviationStates.SingleState state, float value) { + public static void combine(StdDevStates.SingleState state, float value) { state.add(value); } - public static void combineIntermediate(StdDeviationStates.SingleState state, double mean, double m2, long count) { + public static void combineIntermediate(StdDevStates.SingleState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.SingleState state, DriverContext driverContext) { + public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -50,28 +50,28 @@ public static Block evaluateFinal(StdDeviationStates.SingleState state, DriverCo return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static StdDeviationStates.GroupingState initGrouping(BigArrays bigArrays) { - return new StdDeviationStates.GroupingState(bigArrays); + public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDevStates.GroupingState(bigArrays); } - public static void combine(StdDeviationStates.GroupingState current, int groupId, float value) { + public static void combine(StdDevStates.GroupingState current, int groupId, float value) { current.add(groupId, value); } public static void combineStates( - StdDeviationStates.GroupingState current, + StdDevStates.GroupingState current, int groupId, - StdDeviationStates.GroupingState state, + StdDevStates.GroupingState state, int statePosition ) { current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate(StdDeviationStates.GroupingState state, int groupId, double mean, double m2, long count) { + public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.GroupingState state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java index 10466b25cb75c..c773e6737f950 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java @@ -18,7 +18,7 @@ /** * A standard deviation aggregation definition for int. - * This class is generated. Edit `X-StdDeviationAggregator.java.st` instead. + * This class is generated. Edit `X-StdDevAggregator.java.st` instead. */ @Aggregator( { @@ -29,19 +29,19 @@ @GroupingAggregator public class StdDevIntAggregator { - public static StdDeviationStates.SingleState initSingle() { - return new StdDeviationStates.SingleState(); + public static StdDevStates.SingleState initSingle() { + return new StdDevStates.SingleState(); } - public static void combine(StdDeviationStates.SingleState state, int value) { + public static void combine(StdDevStates.SingleState state, int value) { state.add(value); } - public static void combineIntermediate(StdDeviationStates.SingleState state, double mean, double m2, long count) { + public static void combineIntermediate(StdDevStates.SingleState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.SingleState state, DriverContext driverContext) { + public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -50,28 +50,28 @@ public static Block evaluateFinal(StdDeviationStates.SingleState state, DriverCo return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static StdDeviationStates.GroupingState initGrouping(BigArrays bigArrays) { - return new StdDeviationStates.GroupingState(bigArrays); + public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDevStates.GroupingState(bigArrays); } - public static void combine(StdDeviationStates.GroupingState current, int groupId, int value) { + public static void combine(StdDevStates.GroupingState current, int groupId, int value) { current.add(groupId, value); } public static void combineStates( - StdDeviationStates.GroupingState current, + StdDevStates.GroupingState current, int groupId, - StdDeviationStates.GroupingState state, + StdDevStates.GroupingState state, int statePosition ) { current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate(StdDeviationStates.GroupingState state, int groupId, double mean, double m2, long count) { + public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.GroupingState state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java index f1f40e211d7b9..149b0c1eedb9d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java @@ -18,7 +18,7 @@ /** * A standard deviation aggregation definition for long. - * This class is generated. Edit `X-StdDeviationAggregator.java.st` instead. + * This class is generated. Edit `X-StdDevAggregator.java.st` instead. */ @Aggregator( { @@ -29,19 +29,19 @@ @GroupingAggregator public class StdDevLongAggregator { - public static StdDeviationStates.SingleState initSingle() { - return new StdDeviationStates.SingleState(); + public static StdDevStates.SingleState initSingle() { + return new StdDevStates.SingleState(); } - public static void combine(StdDeviationStates.SingleState state, long value) { + public static void combine(StdDevStates.SingleState state, long value) { state.add(value); } - public static void combineIntermediate(StdDeviationStates.SingleState state, double mean, double m2, long count) { + public static void combineIntermediate(StdDevStates.SingleState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.SingleState state, DriverContext driverContext) { + public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -50,28 +50,28 @@ public static Block evaluateFinal(StdDeviationStates.SingleState state, DriverCo return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static StdDeviationStates.GroupingState initGrouping(BigArrays bigArrays) { - return new StdDeviationStates.GroupingState(bigArrays); + public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDevStates.GroupingState(bigArrays); } - public static void combine(StdDeviationStates.GroupingState current, int groupId, long value) { + public static void combine(StdDevStates.GroupingState current, int groupId, long value) { current.add(groupId, value); } public static void combineStates( - StdDeviationStates.GroupingState current, + StdDevStates.GroupingState current, int groupId, - StdDeviationStates.GroupingState state, + StdDevStates.GroupingState state, int statePosition ) { current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate(StdDeviationStates.GroupingState state, int groupId, double mean, double m2, long count) { + public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.GroupingState state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunction.java index c6ba833b3499b..dd6cc89401a99 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleAggregatorFunction.java @@ -31,12 +31,12 @@ public final class StdDevDoubleAggregatorFunction implements AggregatorFunction private final DriverContext driverContext; - private final StdDeviationStates.SingleState state; + private final StdDevStates.SingleState state; private final List channels; public StdDevDoubleAggregatorFunction(DriverContext driverContext, List channels, - StdDeviationStates.SingleState state) { + StdDevStates.SingleState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java index 663b6661aaec8..da49c254e353a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java @@ -30,14 +30,14 @@ public final class StdDevDoubleGroupingAggregatorFunction implements GroupingAgg new IntermediateStateDesc("m2", ElementType.DOUBLE), new IntermediateStateDesc("count", ElementType.LONG) ); - private final StdDeviationStates.GroupingState state; + private final StdDevStates.GroupingState state; private final List channels; private final DriverContext driverContext; public StdDevDoubleGroupingAggregatorFunction(List channels, - StdDeviationStates.GroupingState state, DriverContext driverContext) { + StdDevStates.GroupingState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; @@ -191,7 +191,7 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationStates.GroupingState inState = ((StdDevDoubleGroupingAggregatorFunction) input).state; + StdDevStates.GroupingState inState = ((StdDevDoubleGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); StdDevDoubleAggregator.combineStates(state, groupId, inState, position); } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunction.java index cd2f5b88932cb..bf8c4854f6b93 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatAggregatorFunction.java @@ -33,12 +33,12 @@ public final class StdDevFloatAggregatorFunction implements AggregatorFunction { private final DriverContext driverContext; - private final StdDeviationStates.SingleState state; + private final StdDevStates.SingleState state; private final List channels; public StdDevFloatAggregatorFunction(DriverContext driverContext, List channels, - StdDeviationStates.SingleState state) { + StdDevStates.SingleState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java index cf79620e19ace..bf994aaf2840e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java @@ -32,14 +32,14 @@ public final class StdDevFloatGroupingAggregatorFunction implements GroupingAggr new IntermediateStateDesc("m2", ElementType.DOUBLE), new IntermediateStateDesc("count", ElementType.LONG) ); - private final StdDeviationStates.GroupingState state; + private final StdDevStates.GroupingState state; private final List channels; private final DriverContext driverContext; public StdDevFloatGroupingAggregatorFunction(List channels, - StdDeviationStates.GroupingState state, DriverContext driverContext) { + StdDevStates.GroupingState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; @@ -193,7 +193,7 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationStates.GroupingState inState = ((StdDevFloatGroupingAggregatorFunction) input).state; + StdDevStates.GroupingState inState = ((StdDevFloatGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); StdDevFloatAggregator.combineStates(state, groupId, inState, position); } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunction.java index a499ae4698819..4a5585a7dd454 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntAggregatorFunction.java @@ -33,12 +33,12 @@ public final class StdDevIntAggregatorFunction implements AggregatorFunction { private final DriverContext driverContext; - private final StdDeviationStates.SingleState state; + private final StdDevStates.SingleState state; private final List channels; public StdDevIntAggregatorFunction(DriverContext driverContext, List channels, - StdDeviationStates.SingleState state) { + StdDevStates.SingleState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java index 6196f89f3d300..139cc24d3541f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java @@ -30,14 +30,14 @@ public final class StdDevIntGroupingAggregatorFunction implements GroupingAggreg new IntermediateStateDesc("m2", ElementType.DOUBLE), new IntermediateStateDesc("count", ElementType.LONG) ); - private final StdDeviationStates.GroupingState state; + private final StdDevStates.GroupingState state; private final List channels; private final DriverContext driverContext; public StdDevIntGroupingAggregatorFunction(List channels, - StdDeviationStates.GroupingState state, DriverContext driverContext) { + StdDevStates.GroupingState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; @@ -191,7 +191,7 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationStates.GroupingState inState = ((StdDevIntGroupingAggregatorFunction) input).state; + StdDevStates.GroupingState inState = ((StdDevIntGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); StdDevIntAggregator.combineStates(state, groupId, inState, position); } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunction.java index ec5f38833ff0b..b5ed31116a90c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongAggregatorFunction.java @@ -31,12 +31,12 @@ public final class StdDevLongAggregatorFunction implements AggregatorFunction { private final DriverContext driverContext; - private final StdDeviationStates.SingleState state; + private final StdDevStates.SingleState state; private final List channels; public StdDevLongAggregatorFunction(DriverContext driverContext, List channels, - StdDeviationStates.SingleState state) { + StdDevStates.SingleState state) { this.driverContext = driverContext; this.channels = channels; this.state = state; diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java index 7e16fa2967c74..da7a5f4bdea0d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java @@ -30,14 +30,14 @@ public final class StdDevLongGroupingAggregatorFunction implements GroupingAggre new IntermediateStateDesc("m2", ElementType.DOUBLE), new IntermediateStateDesc("count", ElementType.LONG) ); - private final StdDeviationStates.GroupingState state; + private final StdDevStates.GroupingState state; private final List channels; private final DriverContext driverContext; public StdDevLongGroupingAggregatorFunction(List channels, - StdDeviationStates.GroupingState state, DriverContext driverContext) { + StdDevStates.GroupingState state, DriverContext driverContext) { this.channels = channels; this.state = state; this.driverContext = driverContext; @@ -191,7 +191,7 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - StdDeviationStates.GroupingState inState = ((StdDevLongGroupingAggregatorFunction) input).state; + StdDevStates.GroupingState inState = ((StdDevLongGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); StdDevLongAggregator.combineStates(state, groupId, inState, position); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDeviationStates.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java similarity index 98% rename from x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDeviationStates.java rename to x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java index 03cb6c49ce04f..9e2a8f3381a65 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDeviationStates.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java @@ -15,9 +15,9 @@ import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; -public final class StdDeviationStates { +public final class StdDevStates { - private StdDeviationStates() {} + private StdDevStates() {} static final class SingleState implements AggregatorState { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st index cf52ee572a727..74385b5013308 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st @@ -18,7 +18,7 @@ import org.elasticsearch.compute.operator.DriverContext; /** * A standard deviation aggregation definition for $type$. - * This class is generated. Edit `X-StdDeviationAggregator.java.st` instead. + * This class is generated. Edit `X-StdDevAggregator.java.st` instead. */ @Aggregator( { @@ -29,19 +29,19 @@ import org.elasticsearch.compute.operator.DriverContext; @GroupingAggregator public class StdDev$Type$Aggregator { - public static StdDeviationStates.SingleState initSingle() { - return new StdDeviationStates.SingleState(); + public static StdDevStates.SingleState initSingle() { + return new StdDevStates.SingleState(); } - public static void combine(StdDeviationStates.SingleState state, $type$ value) { + public static void combine(StdDevStates.SingleState state, $type$ value) { state.add(value); } - public static void combineIntermediate(StdDeviationStates.SingleState state, double mean, double m2, long count) { + public static void combineIntermediate(StdDevStates.SingleState state, double mean, double m2, long count) { state.combine(mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.SingleState state, DriverContext driverContext) { + public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { final long count = state.count(); final double m2 = state.m2(); if (count == 0 || Double.isFinite(m2) == false) { @@ -50,28 +50,28 @@ public class StdDev$Type$Aggregator { return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); } - public static StdDeviationStates.GroupingState initGrouping(BigArrays bigArrays) { - return new StdDeviationStates.GroupingState(bigArrays); + public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { + return new StdDevStates.GroupingState(bigArrays); } - public static void combine(StdDeviationStates.GroupingState current, int groupId, $type$ value) { + public static void combine(StdDevStates.GroupingState current, int groupId, $type$ value) { current.add(groupId, value); } public static void combineStates( - StdDeviationStates.GroupingState current, + StdDevStates.GroupingState current, int groupId, - StdDeviationStates.GroupingState state, + StdDevStates.GroupingState state, int statePosition ) { current.combine(groupId, state.getOrNull(statePosition)); } - public static void combineIntermediate(StdDeviationStates.GroupingState state, int groupId, double mean, double m2, long count) { + public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } - public static Block evaluateFinal(StdDeviationStates.GroupingState state, IntVector selected, DriverContext driverContext) { + public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { final var groupId = selected.getInt(i); From 8a5ac5c13e61cd8af27b0aba4d335ca6dfd7c464 Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Thu, 14 Nov 2024 15:45:42 -1000 Subject: [PATCH 12/20] linting continues --- .../compute/aggregation/StdDevDoubleAggregator.java | 7 +------ .../compute/aggregation/StdDevFloatAggregator.java | 7 +------ .../compute/aggregation/StdDevIntAggregator.java | 7 +------ .../compute/aggregation/StdDevLongAggregator.java | 7 +------ .../compute/aggregation/X-StdDevAggregator.java.st | 7 +------ 5 files changed, 5 insertions(+), 30 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java index 67453422a6b74..82d7082858dc1 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java @@ -58,12 +58,7 @@ public static void combine(StdDevStates.GroupingState current, int groupId, doub current.add(groupId, value); } - public static void combineStates( - StdDevStates.GroupingState current, - int groupId, - StdDevStates.GroupingState state, - int statePosition - ) { + public static void combineStates(StdDevStates.GroupingState current, int groupId, StdDevStates.GroupingState state, int statePosition) { current.combine(groupId, state.getOrNull(statePosition)); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java index e24cf03752862..bb5d1ef81523e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java @@ -58,12 +58,7 @@ public static void combine(StdDevStates.GroupingState current, int groupId, floa current.add(groupId, value); } - public static void combineStates( - StdDevStates.GroupingState current, - int groupId, - StdDevStates.GroupingState state, - int statePosition - ) { + public static void combineStates(StdDevStates.GroupingState current, int groupId, StdDevStates.GroupingState state, int statePosition) { current.combine(groupId, state.getOrNull(statePosition)); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java index c773e6737f950..f2eba851c7523 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java @@ -58,12 +58,7 @@ public static void combine(StdDevStates.GroupingState current, int groupId, int current.add(groupId, value); } - public static void combineStates( - StdDevStates.GroupingState current, - int groupId, - StdDevStates.GroupingState state, - int statePosition - ) { + public static void combineStates(StdDevStates.GroupingState current, int groupId, StdDevStates.GroupingState state, int statePosition) { current.combine(groupId, state.getOrNull(statePosition)); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java index 149b0c1eedb9d..d2bb8b9730d56 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java @@ -58,12 +58,7 @@ public static void combine(StdDevStates.GroupingState current, int groupId, long current.add(groupId, value); } - public static void combineStates( - StdDevStates.GroupingState current, - int groupId, - StdDevStates.GroupingState state, - int statePosition - ) { + public static void combineStates(StdDevStates.GroupingState current, int groupId, StdDevStates.GroupingState state, int statePosition) { current.combine(groupId, state.getOrNull(statePosition)); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st index 74385b5013308..6f92a55986d10 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st @@ -58,12 +58,7 @@ public class StdDev$Type$Aggregator { current.add(groupId, value); } - public static void combineStates( - StdDevStates.GroupingState current, - int groupId, - StdDevStates.GroupingState state, - int statePosition - ) { + public static void combineStates(StdDevStates.GroupingState current, int groupId, StdDevStates.GroupingState state, int statePosition) { current.combine(groupId, state.getOrNull(statePosition)); } From 93b36db7befe828559dec7df80304ae5da5cf5b1 Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Tue, 19 Nov 2024 17:24:27 -1000 Subject: [PATCH 13/20] move evaluate final to states and fix entry --- .../aggregation/StdDevDoubleAggregator.java | 26 ++------------- .../aggregation/StdDevFloatAggregator.java | 26 ++------------- .../aggregation/StdDevIntAggregator.java | 26 ++------------- .../aggregation/StdDevLongAggregator.java | 26 ++------------- .../compute/aggregation/StdDevStates.java | 32 ++++++++++++++++++- .../aggregation/X-StdDevAggregator.java.st | 26 ++------------- .../aggregate/AggregateWritables.java | 1 + 7 files changed, 42 insertions(+), 121 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java index 82d7082858dc1..3a1185d34fa23 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java @@ -12,7 +12,6 @@ import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; @@ -42,12 +41,7 @@ public static void combineIntermediate(StdDevStates.SingleState state, double me } public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { - final long count = state.count(); - final double m2 = state.m2(); - if (count == 0 || Double.isFinite(m2) == false) { - return driverContext.blockFactory().newConstantNullBlock(1); - } - return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); + return state.evaluateFinal(driverContext); } public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { @@ -67,22 +61,6 @@ public static void combineIntermediate(StdDevStates.GroupingState state, int gro } public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { - try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { - for (int i = 0; i < selected.getPositionCount(); i++) { - final var groupId = selected.getInt(i); - final var st = state.getOrNull(groupId); - if (st != null) { - final var m2 = st.m2(); - if (Double.isFinite(m2) == false) { - builder.appendNull(); - } else { - builder.appendDouble(st.evaluateFinal()); - } - } else { - builder.appendNull(); - } - } - return builder.build(); - } + return state.evaluateFinal(selected, driverContext); } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java index bb5d1ef81523e..51c22e7e29c1e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java @@ -12,7 +12,6 @@ import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; @@ -42,12 +41,7 @@ public static void combineIntermediate(StdDevStates.SingleState state, double me } public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { - final long count = state.count(); - final double m2 = state.m2(); - if (count == 0 || Double.isFinite(m2) == false) { - return driverContext.blockFactory().newConstantNullBlock(1); - } - return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); + return state.evaluateFinal(driverContext); } public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { @@ -67,22 +61,6 @@ public static void combineIntermediate(StdDevStates.GroupingState state, int gro } public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { - try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { - for (int i = 0; i < selected.getPositionCount(); i++) { - final var groupId = selected.getInt(i); - final var st = state.getOrNull(groupId); - if (st != null) { - final var m2 = st.m2(); - if (Double.isFinite(m2) == false) { - builder.appendNull(); - } else { - builder.appendDouble(st.evaluateFinal()); - } - } else { - builder.appendNull(); - } - } - return builder.build(); - } + return state.evaluateFinal(selected, driverContext); } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java index f2eba851c7523..24eae35cb3249 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java @@ -12,7 +12,6 @@ import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; @@ -42,12 +41,7 @@ public static void combineIntermediate(StdDevStates.SingleState state, double me } public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { - final long count = state.count(); - final double m2 = state.m2(); - if (count == 0 || Double.isFinite(m2) == false) { - return driverContext.blockFactory().newConstantNullBlock(1); - } - return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); + return state.evaluateFinal(driverContext); } public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { @@ -67,22 +61,6 @@ public static void combineIntermediate(StdDevStates.GroupingState state, int gro } public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { - try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { - for (int i = 0; i < selected.getPositionCount(); i++) { - final var groupId = selected.getInt(i); - final var st = state.getOrNull(groupId); - if (st != null) { - final var m2 = st.m2(); - if (Double.isFinite(m2) == false) { - builder.appendNull(); - } else { - builder.appendDouble(st.evaluateFinal()); - } - } else { - builder.appendNull(); - } - } - return builder.build(); - } + return state.evaluateFinal(selected, driverContext); } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java index d2bb8b9730d56..888ace30a0c8e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java @@ -12,7 +12,6 @@ import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; @@ -42,12 +41,7 @@ public static void combineIntermediate(StdDevStates.SingleState state, double me } public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { - final long count = state.count(); - final double m2 = state.m2(); - if (count == 0 || Double.isFinite(m2) == false) { - return driverContext.blockFactory().newConstantNullBlock(1); - } - return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); + return state.evaluateFinal(driverContext); } public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { @@ -67,22 +61,6 @@ public static void combineIntermediate(StdDevStates.GroupingState state, int gro } public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { - try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { - for (int i = 0; i < selected.getPositionCount(); i++) { - final var groupId = selected.getInt(i); - final var st = state.getOrNull(groupId); - if (st != null) { - final var m2 = st.m2(); - if (Double.isFinite(m2) == false) { - builder.appendNull(); - } else { - builder.appendDouble(st.evaluateFinal()); - } - } else { - builder.appendNull(); - } - } - return builder.build(); - } + return state.evaluateFinal(selected, driverContext); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java index 9e2a8f3381a65..67dbb7a7bfd1b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; @@ -21,7 +22,7 @@ private StdDevStates() {} static final class SingleState implements AggregatorState { - private WelfordAlgorithm welfordAlgorithm; + private final WelfordAlgorithm welfordAlgorithm; SingleState() { this(0, 0, 0); @@ -74,6 +75,15 @@ public long count() { public double evaluateFinal() { return welfordAlgorithm.evaluate(); } + + public Block evaluateFinal(DriverContext driverContext) { + final long count = count(); + final double m2 = m2(); + if (count == 0 || Double.isFinite(m2) == false) { + return driverContext.blockFactory().newConstantNullBlock(1); + } + return driverContext.blockFactory().newConstantDoubleBlockWith(evaluateFinal(), 1); + } } static final class GroupingState implements GroupingAggregatorState { @@ -168,6 +178,26 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive } } + public Block evaluateFinal(IntVector selected, DriverContext driverContext) { + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + final var groupId = selected.getInt(i); + final var st = getOrNull(groupId); + if (st != null) { + final var m2 = st.m2(); + if (Double.isFinite(m2) == false) { + builder.appendNull(); + } else { + builder.appendDouble(st.evaluateFinal()); + } + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + @Override public void close() { Releasables.close(states); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st index 6f92a55986d10..510d770f90d62 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st @@ -12,7 +12,6 @@ import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; @@ -42,12 +41,7 @@ public class StdDev$Type$Aggregator { } public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) { - final long count = state.count(); - final double m2 = state.m2(); - if (count == 0 || Double.isFinite(m2) == false) { - return driverContext.blockFactory().newConstantNullBlock(1); - } - return driverContext.blockFactory().newConstantDoubleBlockWith(state.evaluateFinal(), 1); + return state.evaluateFinal(driverContext); } public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) { @@ -67,22 +61,6 @@ public class StdDev$Type$Aggregator { } public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) { - try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { - for (int i = 0; i < selected.getPositionCount(); i++) { - final var groupId = selected.getInt(i); - final var st = state.getOrNull(groupId); - if (st != null) { - final var m2 = st.m2(); - if (Double.isFinite(m2) == false) { - builder.appendNull(); - } else { - builder.appendDouble(st.evaluateFinal()); - } - } else { - builder.appendNull(); - } - } - return builder.build(); - } + return state.evaluateFinal(selected, driverContext); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java index b9cfd8892dd69..d74b5c8b386b8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java @@ -25,6 +25,7 @@ public static List getNamedWriteables() { Percentile.ENTRY, Rate.ENTRY, SpatialCentroid.ENTRY, + StdDev.ENTRY, Sum.ENTRY, Top.ENTRY, Values.ENTRY, From 99411a917646ee3ac5d6cf21905718a0c80808e0 Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Tue, 19 Nov 2024 18:22:50 -1000 Subject: [PATCH 14/20] Update docs/changelog/116531.yaml --- docs/changelog/116531.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/116531.yaml diff --git a/docs/changelog/116531.yaml b/docs/changelog/116531.yaml new file mode 100644 index 0000000000000..ef1dd7f83f49c --- /dev/null +++ b/docs/changelog/116531.yaml @@ -0,0 +1,5 @@ +pr: 116531 +summary: "[ES|QL] Add a standard deviation function" +area: ES|QL +type: enhancement +issues: [] From bc715f4fc049c2a92f808082b4bf45ae5fbeb407 Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Wed, 20 Nov 2024 14:18:37 -1000 Subject: [PATCH 15/20] change SingleState to WelfordAlgorithm and change tests --- .../compute/aggregation/StdDevStates.java | 16 ++--- .../src/main/resources/stats.csv-spec | 62 ++++--------------- .../function/aggregate/StdDevTests.java | 18 +----- 3 files changed, 21 insertions(+), 75 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java index 67dbb7a7bfd1b..d748728fc21de 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java @@ -88,7 +88,7 @@ public Block evaluateFinal(DriverContext driverContext) { static final class GroupingState implements GroupingAggregatorState { - private ObjectArray states; + private ObjectArray states; private final BigArrays bigArrays; GroupingState(BigArrays bigArrays) { @@ -96,7 +96,7 @@ static final class GroupingState implements GroupingAggregatorState { this.bigArrays = bigArrays; } - SingleState getOrNull(int position) { + WelfordAlgorithm getOrNull(int position) { if (position < states.size()) { return states.get(position); } else { @@ -104,7 +104,7 @@ SingleState getOrNull(int position) { } } - public void combine(int groupId, SingleState state) { + public void combine(int groupId, WelfordAlgorithm state) { if (state == null) { return; } @@ -115,18 +115,18 @@ public void combine(int groupId, double meanValue, double m2Value, long countVal ensureCapacity(groupId); var state = states.get(groupId); if (state == null) { - state = new SingleState(meanValue, m2Value, countValue); + state = new WelfordAlgorithm(meanValue, m2Value, countValue); states.set(groupId, state); } else { - state.combine(meanValue, m2Value, countValue); + state.add(meanValue, m2Value, countValue); } } - public SingleState getOrSet(int groupId) { + public WelfordAlgorithm getOrSet(int groupId) { ensureCapacity(groupId); var state = states.get(groupId); if (state == null) { - state = new SingleState(); + state = new WelfordAlgorithm(); states.set(groupId, state); } return state; @@ -188,7 +188,7 @@ public Block evaluateFinal(IntVector selected, DriverContext driverContext) { if (Double.isFinite(m2) == false) { builder.appendNull(); } else { - builder.appendDouble(st.evaluateFinal()); + builder.appendDouble(st.evaluate()); } } else { builder.appendNull(); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index 83b1772addbe3..de2678c70b693 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -2908,7 +2908,7 @@ STD_DEV(languages):double 0.0 ; -stdDeviationGrouped +stdDeviationGroupedDoublesOnly required_capability: std_dev FROM employees | STATS STD_DEV(height) BY languages @@ -2924,59 +2924,21 @@ STD_DEV(height):double | languages:integer 0.2486543786061287 | null ; -stdDeviationGrouped1 +stdDeviationGroupedAllTypes required_capability: std_dev FROM employees -| WHERE languages == 1 -| STATS STD_DEV(height) -; - -STD_DEV(height):double -0.22106409327010415 -; - -stdDeviationGrouped2 -required_capability: std_dev -FROM employees -| WHERE languages == 2 -| STATS STD_DEV(height) -; - -STD_DEV(height):double -0.22797190865484734 -; - -stdDeviationGrouped3 -required_capability: std_dev -FROM employees -| WHERE languages == 3 -| STATS STD_DEV(height) -; - -STD_DEV(height):double -0.18893070075713295 -; - -stdDeviationGrouped4 -required_capability: std_dev -FROM employees -| WHERE languages == 4 -| STATS STD_DEV(height) -; - -STD_DEV(height):double -0.14656141004227627 -; - -stdDeviationGrouped5 -required_capability: std_dev -FROM employees -| WHERE languages == 5 -| STATS STD_DEV(height) +| WHERE languages < 3 +| STATS + double_std_dev = STD_DEV(height), + int_std_dev = STD_DEV(salary), + long_std_dev = STD_DEV(avg_worked_seconds) + BY languages +| SORT languages asc ; -STD_DEV(height):double -0.17733860152780256 +double_std_dev:double | int_std_dev:double | long_std_dev:double | languages:integer +0.22106409327010415 | 15166.244178730898 | 5.1998715922156096E7 | 1 +0.22797190865484734 | 12139.61099378116 | 5.309085506583288E7 | 2 ; stdDeviationNoRows diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDevTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDevTests.java index faec43dc563c2..85b96e29d1f6a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDevTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDevTests.java @@ -25,7 +25,6 @@ import java.util.stream.Stream; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.nullValue; public class StdDevTests extends AbstractAggregationTestCase { public StdDevTests(@Name("TestCase") Supplier testCaseSupplier) { @@ -42,22 +41,7 @@ public static Iterable parameters() { MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true) ).flatMap(List::stream).map(StdDevTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers)); - // No rows - for (var dataType : List.of(DataType.INTEGER, DataType.LONG, DataType.DOUBLE)) { - suppliers.add( - new TestCaseSupplier( - "No rows (" + dataType + ")", - List.of(dataType), - () -> new TestCaseSupplier.TestCase( - List.of(TestCaseSupplier.TypedData.multiRow(List.of(), dataType, "field")), - "StdDev[field=Attribute[channel=0]]", - DataType.DOUBLE, - nullValue() - ) - ) - ); - } - return parameterSuppliersFromTypedData(randomizeBytesRefsOffset(suppliers)); + return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers); } @Override From a3477eff5800152723e9d67b6494d3dbc9b547eb Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Wed, 20 Nov 2024 14:28:48 -1000 Subject: [PATCH 16/20] dot --- .../org/elasticsearch/compute/aggregation/WelfordAlgorithm.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java index 29cb7e62e1a8c..8ccb985507247 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java @@ -8,7 +8,7 @@ package org.elasticsearch.compute.aggregation; /** - * Algorithm for calculating standard deviation, one value at a time + * Algorithm for calculating standard deviation, one value at a time. * * @see * Welford's_online_algorithm and From a476e0366e08a25a83bf237ff7b002448b8a0197 Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Thu, 21 Nov 2024 15:43:46 -1000 Subject: [PATCH 17/20] tests, resolveType, changelog --- docs/changelog/116531.yaml | 2 +- .../src/main/resources/stats.csv-spec | 35 +++++++++++++++++++ .../expression/function/aggregate/StdDev.java | 13 +++++++ 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/docs/changelog/116531.yaml b/docs/changelog/116531.yaml index ef1dd7f83f49c..8686fb8f37e84 100644 --- a/docs/changelog/116531.yaml +++ b/docs/changelog/116531.yaml @@ -1,5 +1,5 @@ pr: 116531 -summary: "[ES|QL] Add a standard deviation function" +summary: "Add a standard deviation function" area: ES|QL type: enhancement issues: [] diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index de2678c70b693..5505f428714b4 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -2951,3 +2951,38 @@ FROM employees STD_DEV(languages):double null ; + +stdDevMultiValue +required_capability: std_dev +FROM employees +| STATS STD_DEV(salary_change) +; + +STD_DEV(salary_change):double +7.062226788733394 +; + +stdDevFilter +required_capability: std_dev +FROM employees +| STATS greater_than = STD_DEV(salary_change) WHERE languages > 3 +, less_than = STD_DEV(salary_change) WHERE languages <= 3 +, salary = STD_DEV(salary * 2) +, count = COUNT(*) BY gender +; + +greater_than:double | less_than:double | salary:double | count:long | gender:keyword +6.949207097931448 | 7.127229475750027 | 27921.220736207077 | 10 | null +6.975232333891946 | 6.604807075547775 | 26171.331109641273 | 57 | M +6.4543266953142835 | 7.57786788789264 | 29045.770666969744 | 33 | F +; + +stdDevRow +required_capability: std_dev +ROW a = [1,2,3], b = 5 +| STATS STD_DEV(a), STD_DEV(b) +; + +STD_DEV(a):double | STD_DEV(b):double +0.816496580927726 | 0.0 +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java index dbd1da348d8ba..189b6a81912cb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java @@ -28,6 +28,8 @@ import java.util.List; import static java.util.Collections.emptyList; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; public class StdDev extends AggregateFunction implements ToAggregator { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "StdDev", StdDev::new); @@ -68,6 +70,17 @@ public DataType dataType() { return DataType.DOUBLE; } + @Override + protected Expression.TypeResolution resolveType() { + return isType( + field(), + dt -> dt.isNumeric() && dt != DataType.UNSIGNED_LONG, + sourceText(), + DEFAULT, + "numeric except unsigned_long or counter types" + ); + } + @Override protected NodeInfo info() { return NodeInfo.create(this, StdDev::new, field(), filter()); From c59efab5d8975802e63d3c6975b42302797cb57c Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Thu, 21 Nov 2024 16:35:53 -1000 Subject: [PATCH 18/20] specify order of filter test --- .../esql/qa/testFixtures/src/main/resources/stats.csv-spec | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index f1fbca6693ee0..eb88f4b9dd1ef 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -2989,12 +2989,13 @@ FROM employees , less_than = STD_DEV(salary_change) WHERE languages <= 3 , salary = STD_DEV(salary * 2) , count = COUNT(*) BY gender +| SORT gender asc ; greater_than:double | less_than:double | salary:double | count:long | gender:keyword -6.949207097931448 | 7.127229475750027 | 27921.220736207077 | 10 | null -6.975232333891946 | 6.604807075547775 | 26171.331109641273 | 57 | M 6.4543266953142835 | 7.57786788789264 | 29045.770666969744 | 33 | F +6.975232333891946 | 6.604807075547775 | 26171.331109641273 | 57 | M +6.949207097931448 | 7.127229475750027 | 27921.220736207077 | 10 | null ; stdDevRow From 03e2a70748626e68da8ee64ea0ef7d0f5f8039e9 Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Fri, 22 Nov 2024 09:46:40 -1000 Subject: [PATCH 19/20] fix changelog --- docs/changelog/116531.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog/116531.yaml b/docs/changelog/116531.yaml index 8686fb8f37e84..908bbff487973 100644 --- a/docs/changelog/116531.yaml +++ b/docs/changelog/116531.yaml @@ -1,5 +1,5 @@ pr: 116531 -summary: "Add a standard deviation function" +summary: "Add a standard deviation aggregating function: STD_DEV" area: ES|QL type: enhancement issues: [] From 827af56e1f868dea4caba847db36393fdb4146bb Mon Sep 17 00:00:00 2001 From: Larisa Motova Date: Fri, 22 Nov 2024 10:38:38 -1000 Subject: [PATCH 20/20] set intermediate values to 0 when null --- .../elasticsearch/compute/aggregation/StdDevStates.java | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java index d748728fc21de..bff8903fd3bec 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java @@ -167,9 +167,9 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive m2Builder.appendDouble(state.m2()); countBuilder.appendLong(state.count()); } else { - meanBuilder.appendNull(); - m2Builder.appendNull(); - countBuilder.appendNull(); + meanBuilder.appendDouble(0.0); + m2Builder.appendDouble(0.0); + countBuilder.appendLong(0); } } blocks[offset + 0] = meanBuilder.build(); @@ -185,7 +185,8 @@ public Block evaluateFinal(IntVector selected, DriverContext driverContext) { final var st = getOrNull(groupId); if (st != null) { final var m2 = st.m2(); - if (Double.isFinite(m2) == false) { + final var count = st.count(); + if (count == 0 || Double.isFinite(m2) == false) { builder.appendNull(); } else { builder.appendDouble(st.evaluate());