From 29742e6b96dea63f568e83a53e75164f3957d46c Mon Sep 17 00:00:00 2001 From: Pablo Date: Fri, 24 Oct 2025 20:37:54 -0700 Subject: [PATCH 01/27] Initial steps of implementing Derivative using linear regression --- .../DerivGroupingAggregatorFunction.java | 296 ++++++++++++++++++ .../SimpleLinearRegressionWithTimeseries.java | 45 +++ .../main/resources/k8s-timeseries.csv-spec | 21 ++ .../function/EsqlFunctionRegistry.java | 2 + .../aggregate/AggregateWritables.java | 1 + .../expression/function/aggregate/Deriv.java | 101 ++++++ 6 files changed, 466 insertions(+) create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..b7e6de3e0e74f --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java @@ -0,0 +1,296 @@ +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.aggregation.GroupingAggregatorEvaluationContext; +import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; +import org.elasticsearch.compute.aggregation.SeenGroupIds; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +import java.util.List; + +public class DerivGroupingAggregatorFunction implements GroupingAggregatorFunction { + + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sumVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG), + new IntermediateStateDesc("count", ElementType.LONG) + ); + + private final List channels; + private final DriverContext driverContext; + private ObjectArray states; + + public DerivGroupingAggregatorFunction(List channels, DriverContext driverContext) { + this.states = driverContext.bigArrays().newObjectArray(256); + this.channels = channels; + this.driverContext = driverContext; + } + + public static class Supplier implements AggregatorFunctionSupplier { + + @Override + public List nonGroupingIntermediateStateDesc() { + throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); + } + + @Override + public List 
groupingIntermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public AggregatorFunction aggregator(DriverContext driverContext, List channels) { + throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); + } + + @Override + public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { + return new DerivGroupingAggregatorFunction(channels, driverContext); + } + + @Override + public String describe() { + return "derivative"; + } + } + + @Override + public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { + final DoubleBlock valueBlock = page.getBlock(channels.get(0)); + final LongBlock timestampBlock = page.getBlock(channels.get(1)); + final DoubleVector valueVector = valueBlock.asVector(); + return new AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groupIds.getInt(groupPosition); + double vValue = valueVector.getDouble(valuesPosition); + long ts = timestampBlock.getLong(valuesPosition); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + state.add(ts, vValue); + } + } + + @Override + public void close() { + + } + + private void addBlockInput(int positionOffset, IntBlock groupIds, LongBlock timestampBlock, DoubleVector valueVector) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + if (groupIds.isNull(groupPosition)) { + continue; + } + int valuePosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuePosition)) { + continue; + } + int groupStart = groupIds.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groupIds.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groupIds.getInt(g); + int vStart = valueBlock.getFirstValueIndex(valuePosition); + int vEnd = vStart + valueBlock.getValueCount(valuePosition); + for (int v = vStart; v < vEnd; v++) { + long ts = timestampBlock.getLong(valuePosition); + double val = valueVector.getDouble(valuePosition); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + state.add(ts, val); + } + } + } + } + }; + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + // No-op + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { + addIntermediateBlockInput(positionOffset, groupIdVector, page); + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { + addIntermediateBlockInput(positionOffset, groupIdVector, page); + } + + private void addIntermediateBlockInput(int positionOffset, IntBlock groupIdVector, Page page) { + DoubleBlock sumValBlock = page.getBlock(channels.get(0)); + LongBlock 
sumTsBlock = page.getBlock(channels.get(1)); + DoubleBlock sumTsValBlock = page.getBlock(channels.get(2)); + LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); + LongBlock countBlock = page.getBlock(channels.get(4)); + + if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { + throw new IllegalStateException("Mismatched intermediate state block value counts"); + } + + for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { + int valuePos = groupPos + positionOffset; + + int firstGroup = groupIdVector.getFirstValueIndex(groupPos); + int lastGroup = firstGroup + groupIdVector.getValueCount(groupPos); + + for (int g = firstGroup; g < lastGroup; g++) { + int groupId = groupIdVector.getInt(g); + states = driverContext.bigArrays().grow(states, groupId + 1); + var state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + long sumTs = sumTsBlock.getLong(valuePos); + double sumVal = sumValBlock.getDouble(valuePos); + double sumTsVal = sumTsValBlock.getDouble(valuePos); + long sumTsSq = sumTsSqBlock.getLong(valuePos); + long count = countBlock.getLong(valuePos); + state.sumTs += sumTs; + state.sumVal += sumVal; + state.sumTsVal += sumTsVal; + state.sumTsSq += sumTsSq; + state.count += count; + } + } + + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { + DoubleBlock sumValBlock = page.getBlock(channels.get(0)); + LongBlock sumTsBlock = page.getBlock(channels.get(1)); + DoubleBlock sumTsValBlock = page.getBlock(channels.get(2)); + LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); + LongBlock countBlock = page.getBlock(channels.get(4)); + + if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { + throw new IllegalStateException("Mismatched intermediate state block value counts"); + } + + for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { + int valuePos = groupPos + positionOffset; + int groupId = groupIdVector.getInt(groupPos); + states = driverContext.bigArrays().grow(states, groupId + 1); + var state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + long sumTs = sumTsBlock.getLong(valuePos); + double sumVal = sumValBlock.getDouble(valuePos); + double sumTsVal = sumTsValBlock.getDouble(valuePos); + long sumTsSq = sumTsSqBlock.getLong(valuePos); + long count = countBlock.getLong(valuePos); + state.sumTs += sumTs; + state.sumVal += sumVal; + state.sumTsVal += sumTsVal; + state.sumTsSq += sumTsSq; + state.count += count; + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + BlockFactory blockFactory = driverContext.blockFactory(); + int positionCount = selected.getPositionCount(); + try ( + var sumValBuilder = blockFactory.newDoubleBlockBuilder(positionCount); + var sumTsBuilder = blockFactory.newLongBlockBuilder(positionCount); + var sumTsValBuilder = blockFactory.newDoubleBlockBuilder(positionCount); + var sumTsSqBuilder = blockFactory.newLongBlockBuilder(positionCount); + var countBuilder = 
blockFactory.newLongBlockBuilder(positionCount)
+        ) {
+            for (int p = 0; p < positionCount; p++) {
+                int groupId = selected.getInt(p);
+                SimpleLinearRegressionWithTimeseries state = states.get(groupId);
+                if (state == null) {
+                    sumValBuilder.appendNull();
+                    sumTsBuilder.appendNull();
+                    sumTsValBuilder.appendNull();
+                    sumTsSqBuilder.appendNull();
+                    countBuilder.appendNull();
+                } else {
+                    sumValBuilder.appendDouble(state.sumVal);
+                    sumTsBuilder.appendLong(state.sumTs);
+                    sumTsValBuilder.appendDouble(state.sumTsVal);
+                    sumTsSqBuilder.appendLong(state.sumTsSq);
+                    countBuilder.appendLong(state.count);
+                }
+            }
+            blocks[offset] = sumValBuilder.build();
+            blocks[offset + 1] = sumTsBuilder.build();
+            blocks[offset + 2] = sumTsValBuilder.build();
+            blocks[offset + 3] = sumTsSqBuilder.build();
+            blocks[offset + 4] = countBuilder.build();
+        }
+    }
+
+    @Override
+    public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) {
+        BlockFactory blockFactory = driverContext.blockFactory();
+        int positionCount = selected.getPositionCount();
+        try (var resultBuilder = blockFactory.newDoubleBlockBuilder(positionCount)) {
+            for (int p = 0; p < positionCount; p++) {
+                int groupId = selected.getInt(p);
+                SimpleLinearRegressionWithTimeseries state = states.get(groupId);
+                if (state == null) {
+                    resultBuilder.appendNull();
+                } else {
+                    double deriv = state.slope(); // slope of the fitted line: value units per unit of @timestamp
+                    resultBuilder.appendDouble(deriv);
+                }
+            }
+            blocks[offset] = resultBuilder.build();
+        }
+    }
+
+    @Override
+    public int intermediateBlockCount() {
+        return INTERMEDIATE_STATE_DESC.size();
+    }
+
+    @Override
+    public void close() {
+        Releasables.close(states);
+    }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java
new file mode 100644
index 0000000000000..f0cef837d5090
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java
@@ -0,0 +1,45 @@
+package org.elasticsearch.compute.aggregation;
+
+class SimpleLinearRegressionWithTimeseries {
+    long count;
+    double sumVal;
+    long sumTs;
+    double sumTsVal;
+    long sumTsSq;
+
+    SimpleLinearRegressionWithTimeseries() {
+        this.count = 0;
+        this.sumVal = 0.0;
+        this.sumTs = 0;
+        this.sumTsVal = 0.0;
+        this.sumTsSq = 0;
+    }
+
+    void add(long ts, double val) { // accumulate the running sums needed for an ordinary least-squares fit of val against ts
+        count++;
+        sumVal += val;
+        sumTs += ts;
+        sumTsVal += ts * val;
+        sumTsSq += ts * ts; // note: ts * ts can overflow a long for epoch-millisecond timestamps
+    }
+
+    double slope() { // least-squares slope: (n * sum(ts*val) - sum(ts) * sum(val)) / (n * sum(ts^2) - sum(ts)^2)
+        if (count <= 1) {
+            return Double.NaN; // fewer than two samples: no defined slope
+        }
+        double numerator = count * sumTsVal - sumTs * sumVal;
+        double denominator = count * sumTsSq - sumTs * sumTs;
+        if (denominator == 0) {
+            return Double.NaN; // all timestamps identical
+        }
+        return numerator / denominator;
+    }
+
+    double intercept() {
+        if (count == 0) {
+            return 0.0; // no samples: the intercept is undefined, fall back to 0.0
+        }
+        return (sumVal - slope() * sumTs) / count;
+    }
+
+}
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec
index 64cdae20e5635..f2a303fabfaaa 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec
@@ -594,6 +594,7 @@ max_cost:integer | pod:keyword | time_bucket:datetime | max_cluster_cost:int
 973 | two | 2024-05-10T00:00:00.000Z | 1209 | 
staging
 ;
+
 max_of_stddev_over_time
 required_capability: ts_command_v0
 required_capability: variance_stddev_over_time
@@ -720,4 +721,24 @@ mx:integer | tbucket:datetime
 1716 | 2024-05-10T00:00:00.000Z
 ;
+derivative_of_gauge_metric
+required_capability: ts_command_v0
+TS k8s
+| STATS max_deriv = max(deriv(network.cost)) BY time_bucket = bucket(@timestamp,5minute), pod
+| SORT pod, time_bucket
+| LIMIT 10
+;
+
+max_deriv:double | time_bucket:datetime | pod:keyword
+2.7573529411764707E-5 | 2024-05-10T00:00:00.000Z | one
+-9.657960185083621E-6 | 2024-05-10T00:05:00.000Z | one
+2.6771965095388827E-5 | 2024-05-10T00:10:00.000Z | one
+4.405946549223768E-5 | 2024-05-10T00:15:00.000Z | one
+8.564814814814814E-5 | 2024-05-10T00:20:00.000Z | one
+9.27740599107712E-5 | 2024-05-10T00:00:00.000Z | three
+3.80263223304832E-5 | 2024-05-10T00:05:00.000Z | three
+-1.9699890788329952E-5 | 2024-05-10T00:10:00.000Z | three
+-1.2087599544937429E-6 | 2024-05-10T00:15:00.000Z | three
+NaN | 2024-05-10T00:20:00.000Z | three
+;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
index 9031aed5b9e52..91c663d6d57e7 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
@@ -28,6 +28,7 @@
 import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinctOverTime;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.CountOverTime;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Delta;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.Deriv;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.First;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.FirstOverTime;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Idelta;
@@ -536,6 +537,7 @@ private static FunctionDefinition[][] functions() {
             defTS(Idelta.class, bi(Idelta::new), "idelta"),
             defTS(Delta.class, bi(Delta::new), "delta"),
             defTS(Increase.class, bi(Increase::new), "increase"),
+            def(Deriv.class, uni(Deriv::new), "deriv"),
             def(MaxOverTime.class, uni(MaxOverTime::new), "max_over_time"),
             def(MinOverTime.class, uni(MinOverTime::new), "min_over_time"),
             def(SumOverTime.class, uni(SumOverTime::new), "sum_over_time"),
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java
index 5b3bddd89d093..f6aba19c8c671 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java
@@ -29,6 +29,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
             Idelta.ENTRY,
             Increase.ENTRY,
             Delta.ENTRY,
+            Deriv.ENTRY,
             Sample.ENTRY,
             SpatialCentroid.ENTRY,
             SpatialExtent.ENTRY,
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java
new file mode 100644
index 0000000000000..27cbb47fcb727
--- /dev/null
+++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java
@@ -0,0 +1,101 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.aggregate;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
+import org.elasticsearch.compute.aggregation.DerivGroupingAggregatorFunction;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
+import org.elasticsearch.xpack.esql.expression.function.FunctionType;
+import org.elasticsearch.xpack.esql.expression.function.Param;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.planner.ToAggregator;
+
+import java.util.List;
+
+/**
+ * Calculates the derivative over time of a numeric field using linear regression.
+ */
+public class Deriv extends TimeSeriesAggregateFunction implements ToAggregator {
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Deriv", Deriv::new);
+    private final Expression timestamp;
+
+    @FunctionInfo(
+        type = FunctionType.TIME_SERIES_AGGREGATE,
+        returnType = { "double" },
+        description = "Calculates the derivative over time of a numeric field using linear regression."
+ ) + public Deriv(Source source, @Param(name = "field", type = { "long", "integer", "double" }) Expression field) { + this(source, field, new UnresolvedAttribute(source, "@timestamp")); + } + + public Deriv(Source source, Expression field, Expression timestamp) { + super(source, field, Literal.TRUE, List.of(timestamp)); + this.timestamp = timestamp; + } + + public Deriv(Source source, Expression field, Expression filter, Expression timestamp) { + super(source, field, filter, List.of(timestamp)); + this.timestamp = timestamp; + } + + private Deriv(org.elasticsearch.common.io.stream.StreamInput in) throws java.io.IOException { + this( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class) + ); + } + + @Override + public AggregateFunction perTimeSeriesAggregation() { + return this; + } + + @Override + public AggregateFunction withFilter(Expression filter) { + return new Deriv(source(), field(), filter, timestamp); + } + + @Override + public DataType dataType() { + return DataType.DOUBLE; + } + + @Override + public Expression replaceChildren(List newChildren) { + if (newChildren.size() == 3) { + return new Deriv(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2)); + } else { + assert newChildren.size() == 2; + return new Deriv(source(), newChildren.get(0), newChildren.get(1)); + } + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, Deriv::new, field(), filter(), timestamp); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + public AggregatorFunctionSupplier supplier() { + return new DerivGroupingAggregatorFunction.Supplier(); + } +} From 96ea0d305c3562d2bbf50176b1dffdb6a4303af0 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Sat, 25 Oct 2025 03:44:58 +0000 Subject: [PATCH 02/27] [CI] Auto commit changes from spotless --- .../DerivGroupingAggregatorFunction.java | 296 ------------------ 1 file changed, 296 deletions(-) delete mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java deleted file mode 100644 index b7e6de3e0e74f..0000000000000 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java +++ /dev/null @@ -1,296 +0,0 @@ -package org.elasticsearch.compute.aggregation; - -import org.elasticsearch.common.util.ObjectArray; -import org.elasticsearch.compute.aggregation.GroupingAggregatorEvaluationContext; -import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; -import org.elasticsearch.compute.aggregation.SeenGroupIds; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.DoubleBlock; -import org.elasticsearch.compute.data.DoubleVector; -import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.LongBlock; -import 
org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasables; - -import java.util.List; - -public class DerivGroupingAggregatorFunction implements GroupingAggregatorFunction { - - private static final List INTERMEDIATE_STATE_DESC = List.of( - new IntermediateStateDesc("sumVal", ElementType.DOUBLE), - new IntermediateStateDesc("sumTs", ElementType.LONG), - new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), - new IntermediateStateDesc("sumTsSq", ElementType.LONG), - new IntermediateStateDesc("count", ElementType.LONG) - ); - - private final List channels; - private final DriverContext driverContext; - private ObjectArray states; - - public DerivGroupingAggregatorFunction(List channels, DriverContext driverContext) { - this.states = driverContext.bigArrays().newObjectArray(256); - this.channels = channels; - this.driverContext = driverContext; - } - - public static class Supplier implements AggregatorFunctionSupplier { - - @Override - public List nonGroupingIntermediateStateDesc() { - throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); - } - - @Override - public List groupingIntermediateStateDesc() { - return INTERMEDIATE_STATE_DESC; - } - - @Override - public AggregatorFunction aggregator(DriverContext driverContext, List channels) { - throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); - } - - @Override - public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { - return new DerivGroupingAggregatorFunction(channels, driverContext); - } - - @Override - public String describe() { - return "derivative"; - } - } - - @Override - public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { - final DoubleBlock valueBlock = page.getBlock(channels.get(0)); - final LongBlock timestampBlock = page.getBlock(channels.get(1)); - final DoubleVector valueVector = valueBlock.asVector(); - return new AddInput() { - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); - } - - @Override - public void add(int positionOffset, IntVector groupIds) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - int valuesPosition = groupPosition + positionOffset; - int groupId = groupIds.getInt(groupPosition); - double vValue = valueVector.getDouble(valuesPosition); - long ts = timestampBlock.getLong(valuesPosition); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - state.add(ts, vValue); - } - } - - @Override - public void close() { - - } - - private void addBlockInput(int positionOffset, IntBlock groupIds, LongBlock timestampBlock, DoubleVector valueVector) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - if (groupIds.isNull(groupPosition)) { - continue; - } - int valuePosition = groupPosition + positionOffset; - if (valueBlock.isNull(valuePosition)) { - continue; - } - int groupStart = groupIds.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + 
groupIds.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groupIds.getInt(g); - int vStart = valueBlock.getFirstValueIndex(valuePosition); - int vEnd = vStart + valueBlock.getValueCount(valuePosition); - for (int v = vStart; v < vEnd; v++) { - long ts = timestampBlock.getLong(valuePosition); - double val = valueVector.getDouble(valuePosition); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - state.add(ts, val); - } - } - } - } - }; - } - - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - // No-op - } - - @Override - public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { - addIntermediateBlockInput(positionOffset, groupIdVector, page); - } - - @Override - public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { - addIntermediateBlockInput(positionOffset, groupIdVector, page); - } - - private void addIntermediateBlockInput(int positionOffset, IntBlock groupIdVector, Page page) { - DoubleBlock sumValBlock = page.getBlock(channels.get(0)); - LongBlock sumTsBlock = page.getBlock(channels.get(1)); - DoubleBlock sumTsValBlock = page.getBlock(channels.get(2)); - LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); - LongBlock countBlock = page.getBlock(channels.get(4)); - - if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { - throw new IllegalStateException("Mismatched intermediate state block value counts"); - } - - for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { - int valuePos = groupPos + positionOffset; - - int firstGroup = groupIdVector.getFirstValueIndex(groupPos); - int lastGroup = firstGroup + groupIdVector.getValueCount(groupPos); - - for (int g = firstGroup; g < lastGroup; g++) { - int groupId = groupIdVector.getInt(g); - states = driverContext.bigArrays().grow(states, groupId + 1); - var state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - long sumTs = sumTsBlock.getLong(valuePos); - double sumVal = sumValBlock.getDouble(valuePos); - double sumTsVal = sumTsValBlock.getDouble(valuePos); - long sumTsSq = sumTsSqBlock.getLong(valuePos); - long count = countBlock.getLong(valuePos); - state.sumTs += sumTs; - state.sumVal += sumVal; - state.sumTsVal += sumTsVal; - state.sumTsSq += sumTsSq; - state.count += count; - } - } - - } - - @Override - public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { - DoubleBlock sumValBlock = page.getBlock(channels.get(0)); - LongBlock sumTsBlock = page.getBlock(channels.get(1)); - DoubleBlock sumTsValBlock = page.getBlock(channels.get(2)); - LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); - LongBlock countBlock = page.getBlock(channels.get(4)); - - if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { - throw new IllegalStateException("Mismatched intermediate state block value counts"); - } - - for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); 
groupPos++) { - int valuePos = groupPos + positionOffset; - int groupId = groupIdVector.getInt(groupPos); - states = driverContext.bigArrays().grow(states, groupId + 1); - var state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - long sumTs = sumTsBlock.getLong(valuePos); - double sumVal = sumValBlock.getDouble(valuePos); - double sumTsVal = sumTsValBlock.getDouble(valuePos); - long sumTsSq = sumTsSqBlock.getLong(valuePos); - long count = countBlock.getLong(valuePos); - state.sumTs += sumTs; - state.sumVal += sumVal; - state.sumTsVal += sumTsVal; - state.sumTsSq += sumTsSq; - state.count += count; - } - } - - @Override - public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { - BlockFactory blockFactory = driverContext.blockFactory(); - int positionCount = selected.getPositionCount(); - try ( - var sumValBuilder = blockFactory.newDoubleBlockBuilder(positionCount); - var sumTsBuilder = blockFactory.newLongBlockBuilder(positionCount); - var sumTsValBuilder = blockFactory.newDoubleBlockBuilder(positionCount); - var sumTsSqBuilder = blockFactory.newLongBlockBuilder(positionCount); - var countBuilder = blockFactory.newLongBlockBuilder(positionCount) - ) { - for (int p = 0; p < positionCount; p++) { - int groupId = selected.getInt(p); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - sumValBuilder.appendNull(); - sumTsBuilder.appendNull(); - sumTsValBuilder.appendNull(); - sumTsSqBuilder.appendNull(); - countBuilder.appendNull(); - } else { - sumValBuilder.appendDouble(state.sumVal); - sumTsBuilder.appendLong(state.sumTs); - sumTsValBuilder.appendDouble(state.sumTsVal); - sumTsSqBuilder.appendLong(state.sumTsSq); - countBuilder.appendLong(state.count); - } - } - blocks[offset] = sumValBuilder.build(); - blocks[offset + 1] = sumTsBuilder.build(); - blocks[offset + 2] = sumTsValBuilder.build(); - blocks[offset + 3] = sumTsSqBuilder.build(); - blocks[offset + 4] = countBuilder.build(); - } - } - - @Override - public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) { - BlockFactory blockFactory = driverContext.blockFactory(); - int positionCount = selected.getPositionCount(); - try (var resultBuilder = blockFactory.newDoubleBlockBuilder(positionCount)) { - for (int p = 0; p < positionCount; p++) { - int groupId = selected.getInt(p); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - resultBuilder.appendNull(); - } else { - double deriv = state.slope(); - resultBuilder.appendDouble(deriv); - } - } - blocks[offset] = resultBuilder.build(); - } - } - - @Override - public int intermediateBlockCount() { - return INTERMEDIATE_STATE_DESC.size(); - } - - @Override - public void close() { - Releasables.close(states); - } -} From c1852ba5867a7dd93e2293ef83be6d08880e429a Mon Sep 17 00:00:00 2001 From: Pablo Date: Fri, 24 Oct 2025 21:13:58 -0700 Subject: [PATCH 03/27] Using code generation --- x-pack/plugin/esql/compute/build.gradle | 17 + ...DerivDoubleGroupingAggregatorFunction.java | 295 ++++++++++++++++++ .../DerivIntGroupingAggregatorFunction.java | 295 ++++++++++++++++++ .../DerivLongGroupingAggregatorFunction.java | 295 ++++++++++++++++++ .../DerivGroupingAggregatorFunction.java | 293 +++++++++++++++++ .../GroupingAggregatorFunction.java | 7 + .../X-DerivGroupingAggregatorFunction.java.st | 295 ++++++++++++++++++ 
.../expression/function/aggregate/Deriv.java | 10 +- 8 files changed, 1506 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index e792bad34f67a..77fe309569170 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -993,4 +993,21 @@ tasks.named('stringTemplates').configure { it.inputFile = rateAggregatorInputFile it.outputFile = "org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java" } + + File derivAggregatorInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st") + template { + it.properties = intProperties + it.inputFile = derivAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java" + } + template { + it.properties = doubleProperties + it.inputFile = derivAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java" + } + template { + it.properties = longProperties + it.inputFile = derivAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java" + } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..473a83488eb3b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java @@ -0,0 +1,295 @@ +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +import java.util.List; + +@SuppressWarnings("cast") +public class DerivDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { + + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sumVal", 
ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG), + new IntermediateStateDesc("count", ElementType.LONG) + ); + + private final List channels; + private final DriverContext driverContext; + private ObjectArray states; + + public DerivDoubleGroupingAggregatorFunction(List channels, DriverContext driverContext) { + this.states = driverContext.bigArrays().newObjectArray(256); + this.channels = channels; + this.driverContext = driverContext; + } + + public static class Supplier implements AggregatorFunctionSupplier { + + @Override + public List nonGroupingIntermediateStateDesc() { + throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); + } + + @Override + public List groupingIntermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public AggregatorFunction aggregator(DriverContext driverContext, List channels) { + throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); + } + + @Override + public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { + return new DerivGroupingAggregatorFunction(channels, driverContext); + } + + @Override + public String describe() { + return "derivative"; + } + } + + @Override + public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { + final DoubleBlock valueBlock = page.getBlock(channels.get(0)); + final LongBlock timestampBlock = page.getBlock(channels.get(1)); + final DoubleVector valueVector = valueBlock.asVector(); + return new AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groupIds.getInt(groupPosition); + double vValue = valueVector.getDouble(valuesPosition); + long ts = timestampBlock.getLong(valuesPosition); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + state.add(ts, (double) vValue); // TODO - value needs to be converted to double + } + } + + @Override + public void close() { + + } + + private void addBlockInput(int positionOffset, IntBlock groupIds, LongBlock timestampBlock, DoubleVector valueVector) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + if (groupIds.isNull(groupPosition)) { + continue; + } + int valuePosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuePosition)) { + continue; + } + int groupStart = groupIds.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groupIds.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groupIds.getInt(g); + int vStart = valueBlock.getFirstValueIndex(valuePosition); + int vEnd = vStart + valueBlock.getValueCount(valuePosition); + for (int v = vStart; v < vEnd; v++) { + long ts = 
timestampBlock.getLong(valuePosition); + double val = valueVector.getDouble(valuePosition); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + state.add(ts, val); // TODO - value needs to be converted to double + } + } + } + } + }; + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + // No-op + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { + addIntermediateBlockInput(positionOffset, groupIdVector, page); + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { + addIntermediateBlockInput(positionOffset, groupIdVector, page); + } + + private void addIntermediateBlockInput(int positionOffset, IntBlock groupIdVector, Page page) { + DoubleBlock sumValBlock = page.getBlock(channels.get(0)); + LongBlock sumTsBlock = page.getBlock(channels.get(1)); + DoubleBlock sumTsValBlock = page.getBlock(channels.get(2)); + LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); + LongBlock countBlock = page.getBlock(channels.get(4)); + + if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { + throw new IllegalStateException("Mismatched intermediate state block value counts"); + } + + for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { + int valuePos = groupPos + positionOffset; + + int firstGroup = groupIdVector.getFirstValueIndex(groupPos); + int lastGroup = firstGroup + groupIdVector.getValueCount(groupPos); + + for (int g = firstGroup; g < lastGroup; g++) { + int groupId = groupIdVector.getInt(g); + states = driverContext.bigArrays().grow(states, groupId + 1); + var state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); // TODO - what happens for int / long + states.set(groupId, state); + } + long sumTs = sumTsBlock.getLong(valuePos); + double sumVal = sumValBlock.getDouble(valuePos); + double sumTsVal = sumTsValBlock.getDouble(valuePos); + long sumTsSq = sumTsSqBlock.getLong(valuePos); + long count = countBlock.getLong(valuePos); + state.sumTs += sumTs; + state.sumVal += sumVal; + state.sumTsVal += sumTsVal; + state.sumTsSq += sumTsSq; + state.count += count; + } + } + + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { + DoubleBlock sumValBlock = page.getBlock(channels.get(0)); + LongBlock sumTsBlock = page.getBlock(channels.get(1)); + DoubleBlock sumTsValBlock = page.getBlock(channels.get(2)); + LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); + LongBlock countBlock = page.getBlock(channels.get(4)); + + if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { + throw new IllegalStateException("Mismatched intermediate state block value counts"); + } + + for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { + int valuePos = groupPos + positionOffset; + int groupId = groupIdVector.getInt(groupPos); + states = driverContext.bigArrays().grow(states, groupId + 1); + var state = states.get(groupId); + if (state == 
null) { + state = new SimpleLinearRegressionWithTimeseries(); // TODO: what about double conversion + states.set(groupId, state); + } + long sumTs = sumTsBlock.getLong(valuePos); + double sumVal = sumValBlock.getDouble(valuePos); + double sumTsVal = sumTsValBlock.getDouble(valuePos); + long sumTsSq = sumTsSqBlock.getLong(valuePos); + long count = countBlock.getLong(valuePos); + state.sumTs += sumTs; + state.sumVal += sumVal; + state.sumTsVal += sumTsVal; + state.sumTsSq += sumTsSq; + state.count += count; + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + BlockFactory blockFactory = driverContext.blockFactory(); + int positionCount = selected.getPositionCount(); + try ( + var sumValBuilder = blockFactory.newDoubleBlockBuilder(positionCount); + var sumTsBuilder = blockFactory.newLongBlockBuilder(positionCount); + var sumTsValBuilder = blockFactory.newDoubleBlockBuilder(positionCount); + var sumTsSqBuilder = blockFactory.newLongBlockBuilder(positionCount); + var countBuilder = blockFactory.newLongBlockBuilder(positionCount) + ) { + for (int p = 0; p < positionCount; p++) { + int groupId = selected.getInt(p); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + sumValBuilder.appendNull(); + sumTsBuilder.appendNull(); + sumTsValBuilder.appendNull(); + sumTsSqBuilder.appendNull(); + countBuilder.appendNull(); + } else { + sumValBuilder.appendDouble((double) state.sumVal); + sumTsBuilder.appendLong(state.sumTs); + sumTsValBuilder.appendDouble((double) state.sumTsVal); // TODO: fix this actually + sumTsSqBuilder.appendLong(state.sumTsSq); + countBuilder.appendLong(state.count); + } + } + blocks[offset] = sumValBuilder.build(); + blocks[offset + 1] = sumTsBuilder.build(); + blocks[offset + 2] = sumTsValBuilder.build(); + blocks[offset + 3] = sumTsSqBuilder.build(); + blocks[offset + 4] = countBuilder.build(); + } + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) { + BlockFactory blockFactory = driverContext.blockFactory(); + int positionCount = selected.getPositionCount(); + try (var resultBuilder = blockFactory.newDoubleBlockBuilder(positionCount)) { + for (int p = 0; p < positionCount; p++) { + int groupId = selected.getInt(p); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + resultBuilder.appendNull(); + } else { + double deriv = state.slope(); + resultBuilder.appendDouble(deriv); + } + } + blocks[offset] = resultBuilder.build(); + } + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void close() { + Releasables.close(states); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..de82eca0471c6 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java @@ -0,0 +1,295 @@ +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import 
org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +import java.util.List; + +@SuppressWarnings("cast") +public class DerivIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sumVal", ElementType.INT), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.INT), + new IntermediateStateDesc("sumTsSq", ElementType.LONG), + new IntermediateStateDesc("count", ElementType.LONG) + ); + + private final List channels; + private final DriverContext driverContext; + private ObjectArray states; + + public DerivIntGroupingAggregatorFunction(List channels, DriverContext driverContext) { + this.states = driverContext.bigArrays().newObjectArray(256); + this.channels = channels; + this.driverContext = driverContext; + } + + public static class Supplier implements AggregatorFunctionSupplier { + + @Override + public List nonGroupingIntermediateStateDesc() { + throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); + } + + @Override + public List groupingIntermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public AggregatorFunction aggregator(DriverContext driverContext, List channels) { + throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); + } + + @Override + public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { + return new DerivGroupingAggregatorFunction(channels, driverContext); + } + + @Override + public String describe() { + return "derivative"; + } + } + + @Override + public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { + final IntBlock valueBlock = page.getBlock(channels.get(0)); + final LongBlock timestampBlock = page.getBlock(channels.get(1)); + final IntVector valueVector = valueBlock.asVector(); + return new AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groupIds.getInt(groupPosition); + int vValue = valueVector.getInt(valuesPosition); + long ts = timestampBlock.getLong(valuesPosition); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + state.add(ts, (double) vValue); // TODO - value needs to be converted to double + } + } + + @Override + public void close() { + + } + + private 
void addBlockInput(int positionOffset, IntBlock groupIds, LongBlock timestampBlock, IntVector valueVector) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + if (groupIds.isNull(groupPosition)) { + continue; + } + int valuePosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuePosition)) { + continue; + } + int groupStart = groupIds.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groupIds.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groupIds.getInt(g); + int vStart = valueBlock.getFirstValueIndex(valuePosition); + int vEnd = vStart + valueBlock.getValueCount(valuePosition); + for (int v = vStart; v < vEnd; v++) { + long ts = timestampBlock.getLong(valuePosition); + int val = valueVector.getInt(valuePosition); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + state.add(ts, val); // TODO - value needs to be converted to double + } + } + } + } + }; + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + // No-op + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { + addIntermediateBlockInput(positionOffset, groupIdVector, page); + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { + addIntermediateBlockInput(positionOffset, groupIdVector, page); + } + + private void addIntermediateBlockInput(int positionOffset, IntBlock groupIdVector, Page page) { + IntBlock sumValBlock = page.getBlock(channels.get(0)); + LongBlock sumTsBlock = page.getBlock(channels.get(1)); + IntBlock sumTsValBlock = page.getBlock(channels.get(2)); + LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); + LongBlock countBlock = page.getBlock(channels.get(4)); + + if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { + throw new IllegalStateException("Mismatched intermediate state block value counts"); + } + + for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { + int valuePos = groupPos + positionOffset; + + int firstGroup = groupIdVector.getFirstValueIndex(groupPos); + int lastGroup = firstGroup + groupIdVector.getValueCount(groupPos); + + for (int g = firstGroup; g < lastGroup; g++) { + int groupId = groupIdVector.getInt(g); + states = driverContext.bigArrays().grow(states, groupId + 1); + var state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); // TODO - what happens for int / long + states.set(groupId, state); + } + long sumTs = sumTsBlock.getLong(valuePos); + int sumVal = sumValBlock.getInt(valuePos); + int sumTsVal = sumTsValBlock.getInt(valuePos); + long sumTsSq = sumTsSqBlock.getLong(valuePos); + long count = countBlock.getLong(valuePos); + state.sumTs += sumTs; + state.sumVal += sumVal; + state.sumTsVal += sumTsVal; + state.sumTsSq += sumTsSq; + state.count += count; + } + } + + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { + IntBlock sumValBlock = page.getBlock(channels.get(0)); + LongBlock sumTsBlock = page.getBlock(channels.get(1)); + IntBlock sumTsValBlock = 
page.getBlock(channels.get(2)); + LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); + LongBlock countBlock = page.getBlock(channels.get(4)); + + if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { + throw new IllegalStateException("Mismatched intermediate state block value counts"); + } + + for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { + int valuePos = groupPos + positionOffset; + int groupId = groupIdVector.getInt(groupPos); + states = driverContext.bigArrays().grow(states, groupId + 1); + var state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); // TODO: what about double conversion + states.set(groupId, state); + } + long sumTs = sumTsBlock.getLong(valuePos); + int sumVal = sumValBlock.getInt(valuePos); + int sumTsVal = sumTsValBlock.getInt(valuePos); + long sumTsSq = sumTsSqBlock.getLong(valuePos); + long count = countBlock.getLong(valuePos); + state.sumTs += sumTs; + state.sumVal += sumVal; + state.sumTsVal += sumTsVal; + state.sumTsSq += sumTsSq; + state.count += count; + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + BlockFactory blockFactory = driverContext.blockFactory(); + int positionCount = selected.getPositionCount(); + try ( + var sumValBuilder = blockFactory.newIntBlockBuilder(positionCount); + var sumTsBuilder = blockFactory.newLongBlockBuilder(positionCount); + var sumTsValBuilder = blockFactory.newIntBlockBuilder(positionCount); + var sumTsSqBuilder = blockFactory.newLongBlockBuilder(positionCount); + var countBuilder = blockFactory.newLongBlockBuilder(positionCount) + ) { + for (int p = 0; p < positionCount; p++) { + int groupId = selected.getInt(p); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + sumValBuilder.appendNull(); + sumTsBuilder.appendNull(); + sumTsValBuilder.appendNull(); + sumTsSqBuilder.appendNull(); + countBuilder.appendNull(); + } else { + sumValBuilder.appendInt((int) state.sumVal); + sumTsBuilder.appendLong(state.sumTs); + sumTsValBuilder.appendInt((int) state.sumTsVal); // TODO: fix this actually + sumTsSqBuilder.appendLong(state.sumTsSq); + countBuilder.appendLong(state.count); + } + } + blocks[offset] = sumValBuilder.build(); + blocks[offset + 1] = sumTsBuilder.build(); + blocks[offset + 2] = sumTsValBuilder.build(); + blocks[offset + 3] = sumTsSqBuilder.build(); + blocks[offset + 4] = countBuilder.build(); + } + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) { + BlockFactory blockFactory = driverContext.blockFactory(); + int positionCount = selected.getPositionCount(); + try (var resultBuilder = blockFactory.newDoubleBlockBuilder(positionCount)) { + for (int p = 0; p < positionCount; p++) { + int groupId = selected.getInt(p); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + resultBuilder.appendNull(); + } else { + double deriv = state.slope(); + resultBuilder.appendDouble(deriv); + } + } + blocks[offset] = resultBuilder.build(); + } + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void close() { + Releasables.close(states); + } +} diff --git 
a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..5290843ce2635 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java @@ -0,0 +1,295 @@ +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +import java.util.List; + +@SuppressWarnings("cast") +public class DerivLongGroupingAggregatorFunction implements GroupingAggregatorFunction { + + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sumVal", ElementType.LONG), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.LONG), + new IntermediateStateDesc("sumTsSq", ElementType.LONG), + new IntermediateStateDesc("count", ElementType.LONG) + ); + + private final List channels; + private final DriverContext driverContext; + private ObjectArray states; + + public DerivLongGroupingAggregatorFunction(List channels, DriverContext driverContext) { + this.states = driverContext.bigArrays().newObjectArray(256); + this.channels = channels; + this.driverContext = driverContext; + } + + public static class Supplier implements AggregatorFunctionSupplier { + + @Override + public List nonGroupingIntermediateStateDesc() { + throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); + } + + @Override + public List groupingIntermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public AggregatorFunction aggregator(DriverContext driverContext, List channels) { + throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); + } + + @Override + public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { + return new DerivGroupingAggregatorFunction(channels, driverContext); + } + + @Override + public String describe() { + return "derivative"; + } + } + + @Override + public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { + final LongBlock valueBlock = page.getBlock(channels.get(0)); + final LongBlock timestampBlock = page.getBlock(channels.get(1)); + final LongVector valueVector = valueBlock.asVector(); + return new AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + 
addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groupIds.getInt(groupPosition); + long vValue = valueVector.getLong(valuesPosition); + long ts = timestampBlock.getLong(valuesPosition); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + state.add(ts, (double) vValue); // TODO - value needs to be converted to double + } + } + + @Override + public void close() { + + } + + private void addBlockInput(int positionOffset, IntBlock groupIds, LongBlock timestampBlock, LongVector valueVector) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + if (groupIds.isNull(groupPosition)) { + continue; + } + int valuePosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuePosition)) { + continue; + } + int groupStart = groupIds.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groupIds.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groupIds.getInt(g); + int vStart = valueBlock.getFirstValueIndex(valuePosition); + int vEnd = vStart + valueBlock.getValueCount(valuePosition); + for (int v = vStart; v < vEnd; v++) { + long ts = timestampBlock.getLong(valuePosition); + long val = valueVector.getLong(valuePosition); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + state.add(ts, val); // TODO - value needs to be converted to double + } + } + } + } + }; + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + // No-op + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { + addIntermediateBlockInput(positionOffset, groupIdVector, page); + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { + addIntermediateBlockInput(positionOffset, groupIdVector, page); + } + + private void addIntermediateBlockInput(int positionOffset, IntBlock groupIdVector, Page page) { + LongBlock sumValBlock = page.getBlock(channels.get(0)); + LongBlock sumTsBlock = page.getBlock(channels.get(1)); + LongBlock sumTsValBlock = page.getBlock(channels.get(2)); + LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); + LongBlock countBlock = page.getBlock(channels.get(4)); + + if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { + throw new IllegalStateException("Mismatched intermediate state block value counts"); + } + + for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { + int valuePos = groupPos + positionOffset; + + int firstGroup = groupIdVector.getFirstValueIndex(groupPos); + int lastGroup = firstGroup + groupIdVector.getValueCount(groupPos); + + for (int g = firstGroup; g < lastGroup; g++) { + int groupId = groupIdVector.getInt(g); + states = driverContext.bigArrays().grow(states, groupId + 1); + var state = states.get(groupId); + if (state == null) 
{ + state = new SimpleLinearRegressionWithTimeseries(); // TODO - what happens for int / long + states.set(groupId, state); + } + long sumTs = sumTsBlock.getLong(valuePos); + long sumVal = sumValBlock.getLong(valuePos); + long sumTsVal = sumTsValBlock.getLong(valuePos); + long sumTsSq = sumTsSqBlock.getLong(valuePos); + long count = countBlock.getLong(valuePos); + state.sumTs += sumTs; + state.sumVal += sumVal; + state.sumTsVal += sumTsVal; + state.sumTsSq += sumTsSq; + state.count += count; + } + } + + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { + LongBlock sumValBlock = page.getBlock(channels.get(0)); + LongBlock sumTsBlock = page.getBlock(channels.get(1)); + LongBlock sumTsValBlock = page.getBlock(channels.get(2)); + LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); + LongBlock countBlock = page.getBlock(channels.get(4)); + + if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { + throw new IllegalStateException("Mismatched intermediate state block value counts"); + } + + for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { + int valuePos = groupPos + positionOffset; + int groupId = groupIdVector.getInt(groupPos); + states = driverContext.bigArrays().grow(states, groupId + 1); + var state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); // TODO: what about double conversion + states.set(groupId, state); + } + long sumTs = sumTsBlock.getLong(valuePos); + long sumVal = sumValBlock.getLong(valuePos); + long sumTsVal = sumTsValBlock.getLong(valuePos); + long sumTsSq = sumTsSqBlock.getLong(valuePos); + long count = countBlock.getLong(valuePos); + state.sumTs += sumTs; + state.sumVal += sumVal; + state.sumTsVal += sumTsVal; + state.sumTsSq += sumTsSq; + state.count += count; + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + BlockFactory blockFactory = driverContext.blockFactory(); + int positionCount = selected.getPositionCount(); + try ( + var sumValBuilder = blockFactory.newLongBlockBuilder(positionCount); + var sumTsBuilder = blockFactory.newLongBlockBuilder(positionCount); + var sumTsValBuilder = blockFactory.newLongBlockBuilder(positionCount); + var sumTsSqBuilder = blockFactory.newLongBlockBuilder(positionCount); + var countBuilder = blockFactory.newLongBlockBuilder(positionCount) + ) { + for (int p = 0; p < positionCount; p++) { + int groupId = selected.getInt(p); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + sumValBuilder.appendNull(); + sumTsBuilder.appendNull(); + sumTsValBuilder.appendNull(); + sumTsSqBuilder.appendNull(); + countBuilder.appendNull(); + } else { + sumValBuilder.appendLong((long) state.sumVal); + sumTsBuilder.appendLong(state.sumTs); + sumTsValBuilder.appendLong((long) state.sumTsVal); // TODO: fix this actually + sumTsSqBuilder.appendLong(state.sumTsSq); + countBuilder.appendLong(state.count); + } + } + blocks[offset] = sumValBuilder.build(); + blocks[offset + 1] = sumTsBuilder.build(); + blocks[offset + 2] = sumTsValBuilder.build(); + blocks[offset + 3] = sumTsSqBuilder.build(); + blocks[offset + 4] = countBuilder.build(); + } + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, 
GroupingAggregatorEvaluationContext evaluationContext) { + BlockFactory blockFactory = driverContext.blockFactory(); + int positionCount = selected.getPositionCount(); + try (var resultBuilder = blockFactory.newDoubleBlockBuilder(positionCount)) { + for (int p = 0; p < positionCount; p++) { + int groupId = selected.getInt(p); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + resultBuilder.appendNull(); + } else { + double deriv = state.slope(); + resultBuilder.appendDouble(deriv); + } + } + blocks[offset] = resultBuilder.build(); + } + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void close() { + Releasables.close(states); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..769c9bfd5f919 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java @@ -0,0 +1,293 @@ +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +import java.util.List; + +public class DerivGroupingAggregatorFunction implements GroupingAggregatorFunction { + + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sumVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG), + new IntermediateStateDesc("count", ElementType.LONG) + ); + + private final List channels; + private final DriverContext driverContext; + private ObjectArray states; + + public DerivGroupingAggregatorFunction(List channels, DriverContext driverContext) { + this.states = driverContext.bigArrays().newObjectArray(256); + this.channels = channels; + this.driverContext = driverContext; + } + + public static class Supplier implements AggregatorFunctionSupplier { + + @Override + public List nonGroupingIntermediateStateDesc() { + throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); + } + + @Override + public List groupingIntermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public AggregatorFunction aggregator(DriverContext driverContext, List channels) { + throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); + } + + @Override + public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { + return new DerivGroupingAggregatorFunction(channels, driverContext); 
+ } + + @Override + public String describe() { + return "derivative"; + } + } + + @Override + public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { + final DoubleBlock valueBlock = page.getBlock(channels.get(0)); + final LongBlock timestampBlock = page.getBlock(channels.get(1)); + final DoubleVector valueVector = valueBlock.asVector(); + return new AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groupIds.getInt(groupPosition); + double vValue = valueVector.getDouble(valuesPosition); + long ts = timestampBlock.getLong(valuesPosition); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + state.add(ts, vValue); + } + } + + @Override + public void close() { + + } + + private void addBlockInput(int positionOffset, IntBlock groupIds, LongBlock timestampBlock, DoubleVector valueVector) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + if (groupIds.isNull(groupPosition)) { + continue; + } + int valuePosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuePosition)) { + continue; + } + int groupStart = groupIds.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groupIds.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groupIds.getInt(g); + int vStart = valueBlock.getFirstValueIndex(valuePosition); + int vEnd = vStart + valueBlock.getValueCount(valuePosition); + for (int v = vStart; v < vEnd; v++) { + long ts = timestampBlock.getLong(valuePosition); + double val = valueVector.getDouble(valuePosition); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + state.add(ts, val); + } + } + } + } + }; + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + // No-op + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { + addIntermediateBlockInput(positionOffset, groupIdVector, page); + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { + addIntermediateBlockInput(positionOffset, groupIdVector, page); + } + + private void addIntermediateBlockInput(int positionOffset, IntBlock groupIdVector, Page page) { + DoubleBlock sumValBlock = page.getBlock(channels.get(0)); + LongBlock sumTsBlock = page.getBlock(channels.get(1)); + DoubleBlock sumTsValBlock = page.getBlock(channels.get(2)); + LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); + LongBlock countBlock = page.getBlock(channels.get(4)); + + if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { + throw new 
IllegalStateException("Mismatched intermediate state block value counts"); + } + + for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { + int valuePos = groupPos + positionOffset; + + int firstGroup = groupIdVector.getFirstValueIndex(groupPos); + int lastGroup = firstGroup + groupIdVector.getValueCount(groupPos); + + for (int g = firstGroup; g < lastGroup; g++) { + int groupId = groupIdVector.getInt(g); + states = driverContext.bigArrays().grow(states, groupId + 1); + var state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + long sumTs = sumTsBlock.getLong(valuePos); + double sumVal = sumValBlock.getDouble(valuePos); + double sumTsVal = sumTsValBlock.getDouble(valuePos); + long sumTsSq = sumTsSqBlock.getLong(valuePos); + long count = countBlock.getLong(valuePos); + state.sumTs += sumTs; + state.sumVal += sumVal; + state.sumTsVal += sumTsVal; + state.sumTsSq += sumTsSq; + state.count += count; + } + } + + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { + DoubleBlock sumValBlock = page.getBlock(channels.get(0)); + LongBlock sumTsBlock = page.getBlock(channels.get(1)); + DoubleBlock sumTsValBlock = page.getBlock(channels.get(2)); + LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); + LongBlock countBlock = page.getBlock(channels.get(4)); + + if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { + throw new IllegalStateException("Mismatched intermediate state block value counts"); + } + + for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { + int valuePos = groupPos + positionOffset; + int groupId = groupIdVector.getInt(groupPos); + states = driverContext.bigArrays().grow(states, groupId + 1); + var state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + long sumTs = sumTsBlock.getLong(valuePos); + double sumVal = sumValBlock.getDouble(valuePos); + double sumTsVal = sumTsValBlock.getDouble(valuePos); + long sumTsSq = sumTsSqBlock.getLong(valuePos); + long count = countBlock.getLong(valuePos); + state.sumTs += sumTs; + state.sumVal += sumVal; + state.sumTsVal += sumTsVal; + state.sumTsSq += sumTsSq; + state.count += count; + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + BlockFactory blockFactory = driverContext.blockFactory(); + int positionCount = selected.getPositionCount(); + try ( + var sumValBuilder = blockFactory.newDoubleBlockBuilder(positionCount); + var sumTsBuilder = blockFactory.newLongBlockBuilder(positionCount); + var sumTsValBuilder = blockFactory.newDoubleBlockBuilder(positionCount); + var sumTsSqBuilder = blockFactory.newLongBlockBuilder(positionCount); + var countBuilder = blockFactory.newLongBlockBuilder(positionCount) + ) { + for (int p = 0; p < positionCount; p++) { + int groupId = selected.getInt(p); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + sumValBuilder.appendNull(); + sumTsBuilder.appendNull(); + sumTsValBuilder.appendNull(); + sumTsSqBuilder.appendNull(); + countBuilder.appendNull(); + } else { + sumValBuilder.appendDouble(state.sumVal); + sumTsBuilder.appendLong(state.sumTs); + 
sumTsValBuilder.appendDouble(state.sumTsVal); + sumTsSqBuilder.appendLong(state.sumTsSq); + countBuilder.appendLong(state.count); + } + } + blocks[offset] = sumValBuilder.build(); + blocks[offset + 1] = sumTsBuilder.build(); + blocks[offset + 2] = sumTsValBuilder.build(); + blocks[offset + 3] = sumTsSqBuilder.build(); + blocks[offset + 4] = countBuilder.build(); + } + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) { + BlockFactory blockFactory = driverContext.blockFactory(); + int positionCount = selected.getPositionCount(); + try (var resultBuilder = blockFactory.newDoubleBlockBuilder(positionCount)) { + for (int p = 0; p < positionCount; p++) { + int groupId = selected.getInt(p); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + resultBuilder.appendNull(); + } else { + double deriv = state.slope(); + resultBuilder.appendDouble(deriv); + } + } + blocks[offset] = resultBuilder.build(); + } + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void close() { + Releasables.close(states); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java index a60bcb1523ffc..2170b587172fb 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java @@ -142,6 +142,9 @@ default void add(int positionOffset, IntBlock groupIds) { * Build the intermediate results for this aggregation. * @param selected the groupIds that have been selected to be included in * the results. Always ascending. + * + *

This function is called on the coordinator node after all intermediate + * blocks have been gathered from the data nodes.

*/ void evaluateIntermediate(Block[] blocks, int offset, IntVector selected); @@ -149,6 +152,10 @@ default void add(int positionOffset, IntBlock groupIds) { * Build the final results for this aggregation. * @param selected the groupIds that have been selected to be included in * the results. Always ascending. + * + *

This function is called on the coordinator node after all intermediate + * results have been gathered from the data nodes and aggregated into + * intermediate blocks.

*/ void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st new file mode 100644 index 0000000000000..9559f71c23740 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st @@ -0,0 +1,295 @@ +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +import java.util.List; + +@SuppressWarnings("cast") +public class Deriv$Type$GroupingAggregatorFunction implements GroupingAggregatorFunction { + + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sumVal", ElementType.$TYPE$), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.$TYPE$), + new IntermediateStateDesc("sumTsSq", ElementType.LONG), + new IntermediateStateDesc("count", ElementType.LONG) + ); + + private final List channels; + private final DriverContext driverContext; + private ObjectArray states; + + public Deriv$Type$GroupingAggregatorFunction(List channels, DriverContext driverContext) { + this.states = driverContext.bigArrays().newObjectArray(256); + this.channels = channels; + this.driverContext = driverContext; + } + + public static class Supplier implements AggregatorFunctionSupplier { + + @Override + public List nonGroupingIntermediateStateDesc() { + throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); + } + + @Override + public List groupingIntermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public AggregatorFunction aggregator(DriverContext driverContext, List channels) { + throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); + } + + @Override + public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { + return new DerivGroupingAggregatorFunction(channels, driverContext); + } + + @Override + public String describe() { + return "derivative"; + } + } + + @Override + public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { + final $Type$Block valueBlock = page.getBlock(channels.get(0)); + final LongBlock timestampBlock = page.getBlock(channels.get(1)); + final $Type$Vector valueVector = valueBlock.asVector(); + return new AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addBlockInput(positionOffset, groupIds, timestampBlock, 
valueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groupIds.getInt(groupPosition); + $type$ vValue = valueVector.get$Type$(valuesPosition); + long ts = timestampBlock.getLong(valuesPosition); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + state.add(ts, (double) vValue); // TODO - value needs to be converted to double + } + } + + @Override + public void close() { + + } + + private void addBlockInput(int positionOffset, IntBlock groupIds, LongBlock timestampBlock, $Type$Vector valueVector) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + if (groupIds.isNull(groupPosition)) { + continue; + } + int valuePosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuePosition)) { + continue; + } + int groupStart = groupIds.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groupIds.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groupIds.getInt(g); + int vStart = valueBlock.getFirstValueIndex(valuePosition); + int vEnd = vStart + valueBlock.getValueCount(valuePosition); + for (int v = vStart; v < vEnd; v++) { + long ts = timestampBlock.getLong(valuePosition); + $type$ val = valueVector.get$Type$(valuePosition); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, state); + } + state.add(ts, val); // TODO - value needs to be converted to double + } + } + } + } + }; + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + // No-op + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { + addIntermediateBlockInput(positionOffset, groupIdVector, page); + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { + addIntermediateBlockInput(positionOffset, groupIdVector, page); + } + + private void addIntermediateBlockInput(int positionOffset, IntBlock groupIdVector, Page page) { + $Type$Block sumValBlock = page.getBlock(channels.get(0)); + LongBlock sumTsBlock = page.getBlock(channels.get(1)); + $Type$Block sumTsValBlock = page.getBlock(channels.get(2)); + LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); + LongBlock countBlock = page.getBlock(channels.get(4)); + + if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { + throw new IllegalStateException("Mismatched intermediate state block value counts"); + } + + for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { + int valuePos = groupPos + positionOffset; + + int firstGroup = groupIdVector.getFirstValueIndex(groupPos); + int lastGroup = firstGroup + groupIdVector.getValueCount(groupPos); + + for (int g = firstGroup; g < lastGroup; g++) { + int groupId = groupIdVector.getInt(g); + 
states = driverContext.bigArrays().grow(states, groupId + 1); + var state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); // TODO - what happens for int / long + states.set(groupId, state); + } + long sumTs = sumTsBlock.getLong(valuePos); + $type$ sumVal = sumValBlock.get$Type$(valuePos); + $type$ sumTsVal = sumTsValBlock.get$Type$(valuePos); + long sumTsSq = sumTsSqBlock.getLong(valuePos); + long count = countBlock.getLong(valuePos); + state.sumTs += sumTs; + state.sumVal += sumVal; + state.sumTsVal += sumTsVal; + state.sumTsSq += sumTsSq; + state.count += count; + } + } + + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { + $Type$Block sumValBlock = page.getBlock(channels.get(0)); + LongBlock sumTsBlock = page.getBlock(channels.get(1)); + $Type$Block sumTsValBlock = page.getBlock(channels.get(2)); + LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); + LongBlock countBlock = page.getBlock(channels.get(4)); + + if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() + || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { + throw new IllegalStateException("Mismatched intermediate state block value counts"); + } + + for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { + int valuePos = groupPos + positionOffset; + int groupId = groupIdVector.getInt(groupPos); + states = driverContext.bigArrays().grow(states, groupId + 1); + var state = states.get(groupId); + if (state == null) { + state = new SimpleLinearRegressionWithTimeseries(); // TODO: what about double conversion + states.set(groupId, state); + } + long sumTs = sumTsBlock.getLong(valuePos); + $type$ sumVal = sumValBlock.get$Type$(valuePos); + $type$ sumTsVal = sumTsValBlock.get$Type$(valuePos); + long sumTsSq = sumTsSqBlock.getLong(valuePos); + long count = countBlock.getLong(valuePos); + state.sumTs += sumTs; + state.sumVal += sumVal; + state.sumTsVal += sumTsVal; + state.sumTsSq += sumTsSq; + state.count += count; + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + BlockFactory blockFactory = driverContext.blockFactory(); + int positionCount = selected.getPositionCount(); + try ( + var sumValBuilder = blockFactory.new$Type$BlockBuilder(positionCount); + var sumTsBuilder = blockFactory.newLongBlockBuilder(positionCount); + var sumTsValBuilder = blockFactory.new$Type$BlockBuilder(positionCount); + var sumTsSqBuilder = blockFactory.newLongBlockBuilder(positionCount); + var countBuilder = blockFactory.newLongBlockBuilder(positionCount) + ) { + for (int p = 0; p < positionCount; p++) { + int groupId = selected.getInt(p); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + sumValBuilder.appendNull(); + sumTsBuilder.appendNull(); + sumTsValBuilder.appendNull(); + sumTsSqBuilder.appendNull(); + countBuilder.appendNull(); + } else { + sumValBuilder.append$Type$(($type$) state.sumVal); + sumTsBuilder.appendLong(state.sumTs); + sumTsValBuilder.append$Type$(($type$) state.sumTsVal); // TODO: fix this actually + sumTsSqBuilder.appendLong(state.sumTsSq); + countBuilder.appendLong(state.count); + } + } + blocks[offset] = sumValBuilder.build(); + blocks[offset + 1] = sumTsBuilder.build(); + blocks[offset + 2] = sumTsValBuilder.build(); + blocks[offset + 3] = sumTsSqBuilder.build(); + blocks[offset + 4] 
= countBuilder.build(); + } + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) { + BlockFactory blockFactory = driverContext.blockFactory(); + int positionCount = selected.getPositionCount(); + try (var resultBuilder = blockFactory.newDoubleBlockBuilder(positionCount)) { + for (int p = 0; p < positionCount; p++) { + int groupId = selected.getInt(p); + SimpleLinearRegressionWithTimeseries state = states.get(groupId); + if (state == null) { + resultBuilder.appendNull(); + } else { + double deriv = state.slope(); + resultBuilder.appendDouble(deriv); + } + } + blocks[offset] = resultBuilder.build(); + } + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void close() { + Releasables.close(states); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java index 27cbb47fcb727..65ed3ef410069 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java @@ -10,6 +10,8 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.DerivGroupingAggregatorFunction; +import org.elasticsearch.compute.aggregation.DerivIntGroupingAggregatorFunction; +import org.elasticsearch.compute.aggregation.DerivLongGroupingAggregatorFunction; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; @@ -96,6 +98,12 @@ public String getWriteableName() { @Override public AggregatorFunctionSupplier supplier() { - return new DerivGroupingAggregatorFunction.Supplier(); + final DataType type = field().dataType(); + return switch (type) { + case INTEGER -> new DerivIntGroupingAggregatorFunction.Supplier(); + case LONG -> new DerivLongGroupingAggregatorFunction.Supplier(); + case DOUBLE -> new DerivGroupingAggregatorFunction.Supplier(); + default -> throw new IllegalArgumentException("Unsupported data type for deriv aggregation: " + type); + }; } } From e1f5dfe6f32908058b440267cd1b0a17d0cf86d4 Mon Sep 17 00:00:00 2001 From: Pablo Date: Fri, 24 Oct 2025 21:17:27 -0700 Subject: [PATCH 04/27] fixup --- .../aggregation/DerivDoubleGroupingAggregatorFunction.java | 2 +- .../compute/aggregation/DerivIntGroupingAggregatorFunction.java | 2 +- .../aggregation/DerivLongGroupingAggregatorFunction.java | 2 +- .../aggregation/X-DerivGroupingAggregatorFunction.java.st | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java index 473a83488eb3b..b6d9879ae4d61 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java +++ 
b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public AggregatorFunction aggregator(DriverContext driverContext, List @Override public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { - return new DerivGroupingAggregatorFunction(channels, driverContext); + return new DerivDoubleGroupingAggregatorFunction(channels, driverContext); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java index de82eca0471c6..58cfe2b6e731c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public AggregatorFunction aggregator(DriverContext driverContext, List @Override public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { - return new DerivGroupingAggregatorFunction(channels, driverContext); + return new DerivIntGroupingAggregatorFunction(channels, driverContext); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java index 5290843ce2635..ca26cbcbe7569 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public AggregatorFunction aggregator(DriverContext driverContext, List @Override public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { - return new DerivGroupingAggregatorFunction(channels, driverContext); + return new DerivLongGroupingAggregatorFunction(channels, driverContext); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st index 9559f71c23740..9e975613ce802 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st @@ -58,7 +58,7 @@ public class Deriv$Type$GroupingAggregatorFunction implements GroupingAggregator @Override public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { - return new DerivGroupingAggregatorFunction(channels, driverContext); + return new Deriv$Type$GroupingAggregatorFunction(channels, driverContext); } @Override From a3c9610e60286f9940f169a590cb9acb2e8d598d Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Sat, 25 Oct 2025 04:24:14 +0000 Subject: [PATCH 05/27] [CI] Auto commit changes from spotless --- .../aggregation/DerivDoubleGroupingAggregatorFunction.java | 1 
- .../aggregation/DerivIntGroupingAggregatorFunction.java | 3 --- .../aggregation/DerivLongGroupingAggregatorFunction.java | 2 -- 3 files changed, 6 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java index b6d9879ae4d61..3550578673e32 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java @@ -11,7 +11,6 @@ import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; -import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java index 58cfe2b6e731c..2abd6dfcb0463 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java @@ -3,15 +3,12 @@ import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.DoubleBlock; -import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.IntArrayBlock; import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; -import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java index ca26cbcbe7569..3fb34babb61e5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java @@ -3,8 +3,6 @@ import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.DoubleBlock; -import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.IntArrayBlock; import org.elasticsearch.compute.data.IntBigArrayBlock; From 2ae4e423a930b419f81bc7eefae505158d32d6ec Mon Sep 17 00:00:00 2001 
From: Pablo Date: Fri, 24 Oct 2025 21:18:06 -0700 Subject: [PATCH 06/27] fixup --- .../DerivGroupingAggregatorFunction.java | 293 ------------------ 1 file changed, 293 deletions(-) delete mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java deleted file mode 100644 index 769c9bfd5f919..0000000000000 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivGroupingAggregatorFunction.java +++ /dev/null @@ -1,293 +0,0 @@ -package org.elasticsearch.compute.aggregation; - -import org.elasticsearch.common.util.ObjectArray; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.DoubleBlock; -import org.elasticsearch.compute.data.DoubleVector; -import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.LongBlock; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasables; - -import java.util.List; - -public class DerivGroupingAggregatorFunction implements GroupingAggregatorFunction { - - private static final List INTERMEDIATE_STATE_DESC = List.of( - new IntermediateStateDesc("sumVal", ElementType.DOUBLE), - new IntermediateStateDesc("sumTs", ElementType.LONG), - new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), - new IntermediateStateDesc("sumTsSq", ElementType.LONG), - new IntermediateStateDesc("count", ElementType.LONG) - ); - - private final List channels; - private final DriverContext driverContext; - private ObjectArray states; - - public DerivGroupingAggregatorFunction(List channels, DriverContext driverContext) { - this.states = driverContext.bigArrays().newObjectArray(256); - this.channels = channels; - this.driverContext = driverContext; - } - - public static class Supplier implements AggregatorFunctionSupplier { - - @Override - public List nonGroupingIntermediateStateDesc() { - throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); - } - - @Override - public List groupingIntermediateStateDesc() { - return INTERMEDIATE_STATE_DESC; - } - - @Override - public AggregatorFunction aggregator(DriverContext driverContext, List channels) { - throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); - } - - @Override - public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { - return new DerivGroupingAggregatorFunction(channels, driverContext); - } - - @Override - public String describe() { - return "derivative"; - } - } - - @Override - public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { - final DoubleBlock valueBlock = page.getBlock(channels.get(0)); - final LongBlock timestampBlock = page.getBlock(channels.get(1)); - final DoubleVector valueVector = valueBlock.asVector(); - return new AddInput() { - @Override - public void add(int 
positionOffset, IntArrayBlock groupIds) { - addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); - } - - @Override - public void add(int positionOffset, IntVector groupIds) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - int valuesPosition = groupPosition + positionOffset; - int groupId = groupIds.getInt(groupPosition); - double vValue = valueVector.getDouble(valuesPosition); - long ts = timestampBlock.getLong(valuesPosition); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - state.add(ts, vValue); - } - } - - @Override - public void close() { - - } - - private void addBlockInput(int positionOffset, IntBlock groupIds, LongBlock timestampBlock, DoubleVector valueVector) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - if (groupIds.isNull(groupPosition)) { - continue; - } - int valuePosition = groupPosition + positionOffset; - if (valueBlock.isNull(valuePosition)) { - continue; - } - int groupStart = groupIds.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groupIds.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groupIds.getInt(g); - int vStart = valueBlock.getFirstValueIndex(valuePosition); - int vEnd = vStart + valueBlock.getValueCount(valuePosition); - for (int v = vStart; v < vEnd; v++) { - long ts = timestampBlock.getLong(valuePosition); - double val = valueVector.getDouble(valuePosition); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - state.add(ts, val); - } - } - } - } - }; - } - - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - // No-op - } - - @Override - public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { - addIntermediateBlockInput(positionOffset, groupIdVector, page); - } - - @Override - public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { - addIntermediateBlockInput(positionOffset, groupIdVector, page); - } - - private void addIntermediateBlockInput(int positionOffset, IntBlock groupIdVector, Page page) { - DoubleBlock sumValBlock = page.getBlock(channels.get(0)); - LongBlock sumTsBlock = page.getBlock(channels.get(1)); - DoubleBlock sumTsValBlock = page.getBlock(channels.get(2)); - LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); - LongBlock countBlock = page.getBlock(channels.get(4)); - - if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { - throw new IllegalStateException("Mismatched intermediate state block value counts"); - } - - for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { - int valuePos = groupPos + positionOffset; - - int firstGroup = groupIdVector.getFirstValueIndex(groupPos); - int lastGroup = firstGroup + groupIdVector.getValueCount(groupPos); - - for (int g = firstGroup; g < lastGroup; g++) { - int groupId = groupIdVector.getInt(g); - states = 
driverContext.bigArrays().grow(states, groupId + 1); - var state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - long sumTs = sumTsBlock.getLong(valuePos); - double sumVal = sumValBlock.getDouble(valuePos); - double sumTsVal = sumTsValBlock.getDouble(valuePos); - long sumTsSq = sumTsSqBlock.getLong(valuePos); - long count = countBlock.getLong(valuePos); - state.sumTs += sumTs; - state.sumVal += sumVal; - state.sumTsVal += sumTsVal; - state.sumTsSq += sumTsSq; - state.count += count; - } - } - - } - - @Override - public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { - DoubleBlock sumValBlock = page.getBlock(channels.get(0)); - LongBlock sumTsBlock = page.getBlock(channels.get(1)); - DoubleBlock sumTsValBlock = page.getBlock(channels.get(2)); - LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); - LongBlock countBlock = page.getBlock(channels.get(4)); - - if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { - throw new IllegalStateException("Mismatched intermediate state block value counts"); - } - - for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { - int valuePos = groupPos + positionOffset; - int groupId = groupIdVector.getInt(groupPos); - states = driverContext.bigArrays().grow(states, groupId + 1); - var state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - long sumTs = sumTsBlock.getLong(valuePos); - double sumVal = sumValBlock.getDouble(valuePos); - double sumTsVal = sumTsValBlock.getDouble(valuePos); - long sumTsSq = sumTsSqBlock.getLong(valuePos); - long count = countBlock.getLong(valuePos); - state.sumTs += sumTs; - state.sumVal += sumVal; - state.sumTsVal += sumTsVal; - state.sumTsSq += sumTsSq; - state.count += count; - } - } - - @Override - public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { - BlockFactory blockFactory = driverContext.blockFactory(); - int positionCount = selected.getPositionCount(); - try ( - var sumValBuilder = blockFactory.newDoubleBlockBuilder(positionCount); - var sumTsBuilder = blockFactory.newLongBlockBuilder(positionCount); - var sumTsValBuilder = blockFactory.newDoubleBlockBuilder(positionCount); - var sumTsSqBuilder = blockFactory.newLongBlockBuilder(positionCount); - var countBuilder = blockFactory.newLongBlockBuilder(positionCount) - ) { - for (int p = 0; p < positionCount; p++) { - int groupId = selected.getInt(p); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - sumValBuilder.appendNull(); - sumTsBuilder.appendNull(); - sumTsValBuilder.appendNull(); - sumTsSqBuilder.appendNull(); - countBuilder.appendNull(); - } else { - sumValBuilder.appendDouble(state.sumVal); - sumTsBuilder.appendLong(state.sumTs); - sumTsValBuilder.appendDouble(state.sumTsVal); - sumTsSqBuilder.appendLong(state.sumTsSq); - countBuilder.appendLong(state.count); - } - } - blocks[offset] = sumValBuilder.build(); - blocks[offset + 1] = sumTsBuilder.build(); - blocks[offset + 2] = sumTsValBuilder.build(); - blocks[offset + 3] = sumTsSqBuilder.build(); - blocks[offset + 4] = countBuilder.build(); - } - } - - @Override - public void evaluateFinal(Block[] blocks, int offset, IntVector selected, 
GroupingAggregatorEvaluationContext evaluationContext) { - BlockFactory blockFactory = driverContext.blockFactory(); - int positionCount = selected.getPositionCount(); - try (var resultBuilder = blockFactory.newDoubleBlockBuilder(positionCount)) { - for (int p = 0; p < positionCount; p++) { - int groupId = selected.getInt(p); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - resultBuilder.appendNull(); - } else { - double deriv = state.slope(); - resultBuilder.appendDouble(deriv); - } - } - blocks[offset] = resultBuilder.build(); - } - } - - @Override - public int intermediateBlockCount() { - return INTERMEDIATE_STATE_DESC.size(); - } - - @Override - public void close() { - Releasables.close(states); - } -} From ef79b97362b0277f7c3675f8d29e1ab4a333b974 Mon Sep 17 00:00:00 2001 From: Pablo Date: Sat, 25 Oct 2025 10:35:58 -0700 Subject: [PATCH 07/27] capab --- .../main/resources/k8s-timeseries.csv-spec | 45 ++++++++++++++----- .../xpack/esql/action/EsqlCapabilities.java | 1 + .../expression/function/aggregate/Deriv.java | 4 +- 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec index f2a303fabfaaa..823633d7a4fb0 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec @@ -723,22 +723,43 @@ mx:integer | tbucket:datetime derivative_of_gauge_metric required_capability: ts_command_v0 +required_capability: ts_linreg TS k8s +| WHERE pod == "three" | STATS max_deriv = max(deriv(network.cost)) BY time_bucket = bucket(@timestamp,5minute), pod +| EVAL max_deriv = ROUND(max_deriv,6) +| KEEP max_deriv, time_bucket, pod | SORT pod, time_bucket -| LIMIT 10 +| LIMIT 5 +; + +max_deriv:double | time_bucket:datetime | pod:keyword +9.3E-5 | 2024-05-10T00:00:00.000Z | three +3.8E-5 | 2024-05-10T00:05:00.000Z | three +-2.0E-5 | 2024-05-10T00:10:00.000Z | three +-1.0E-6 | 2024-05-10T00:15:00.000Z | three +0.0 | 2024-05-10T00:20:00.000Z | three + ; -max_deriv:double | time_bucket:datetime | pod:keyword -2.7573529411764707E-5 | 2024-05-10T00:00:00.000Z | one --9.657960185083621E-6 | 2024-05-10T00:05:00.000Z | one -2.6771965095388827E-5 | 2024-05-10T00:10:00.000Z | one -4.405946549223768E-5 | 2024-05-10T00:15:00.000Z | one -8.564814814814814E-5 | 2024-05-10T00:20:00.000Z | one -9.27740599107712E-5 | 2024-05-10T00:00:00.000Z | three -3.80263223304832E-5 | 2024-05-10T00:05:00.000Z | three --1.9699890788329952E-5 | 2024-05-10T00:10:00.000Z | three --1.2087599544937429E-6 | 2024-05-10T00:15:00.000Z | three -NaN | 2024-05-10T00:20:00.000Z | three +derivative_compared_to_rate +required_capability: ts_command_v0 +required_capability: ts_linreg + +TS k8s +| STATS max_deriv = max(deriv(to_long(network.total_bytes_in))), max_rate = max(rate(network.total_bytes_in)) BY time_bucket = bucket(@timestamp,5minute), cluster +| EVAL max_deriv = ROUND(max_deriv,6), max_rate = ROUND(max_rate,6) +| KEEP max_deriv, max_rate, time_bucket, cluster +| SORT cluster, time_bucket +| LIMIT 5 +; + +max_deriv:double | max_rate:double | time_bucket:datetime | cluster:keyword +0.0855 | 8.120833 | 2024-05-10T00:00:00.000Z | prod +0.004933 | 6.451737 | 2024-05-10T00:05:00.000Z | prod +0.008922 | 11.562738 | 2024-05-10T00:10:00.000Z | prod +0.016623 | 11.860806 | 2024-05-10T00:15:00.000Z | prod +0.0 | 6.980661 | 
2024-05-10T00:20:00.000Z | prod + ; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 36f111b370b36..e775210a15608 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1484,6 +1484,7 @@ public enum Cap { */ PERCENTILE_OVER_TIME, VARIANCE_STDDEV_OVER_TIME, + TS_LINREG, /** * INLINE STATS fix incorrect prunning of null filtering * https://github.com/elastic/elasticsearch/pull/135011 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java index 65ed3ef410069..aee2fde9c4881 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java @@ -9,7 +9,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; -import org.elasticsearch.compute.aggregation.DerivGroupingAggregatorFunction; +import org.elasticsearch.compute.aggregation.DerivDoubleGroupingAggregatorFunction; import org.elasticsearch.compute.aggregation.DerivIntGroupingAggregatorFunction; import org.elasticsearch.compute.aggregation.DerivLongGroupingAggregatorFunction; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -102,7 +102,7 @@ public AggregatorFunctionSupplier supplier() { return switch (type) { case INTEGER -> new DerivIntGroupingAggregatorFunction.Supplier(); case LONG -> new DerivLongGroupingAggregatorFunction.Supplier(); - case DOUBLE -> new DerivGroupingAggregatorFunction.Supplier(); + case DOUBLE -> new DerivDoubleGroupingAggregatorFunction.Supplier(); default -> throw new IllegalArgumentException("Unsupported data type for deriv aggregation: " + type); }; } From 7d4a1276af7466eccc2fdfb4ab8d91bcab0f1265 Mon Sep 17 00:00:00 2001 From: Pablo Date: Sat, 25 Oct 2025 10:59:28 -0700 Subject: [PATCH 08/27] fixup --- ...DerivDoubleGroupingAggregatorFunction.java | 6 ++ .../DerivIntGroupingAggregatorFunction.java | 6 ++ .../DerivLongGroupingAggregatorFunction.java | 6 ++ .../SimpleLinearRegressionWithTimeseries.java | 6 ++ ...pleLinearRegressionWithTimeseries.java.new | 73 +++++++++++++++++++ .../X-DerivGroupingAggregatorFunction.java.st | 6 ++ 6 files changed, 103 insertions(+) create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java.new diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java index 3550578673e32..0213cf9476bd8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java @@ -1,3 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ package org.elasticsearch.compute.aggregation; import org.elasticsearch.common.util.ObjectArray; diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java index 2abd6dfcb0463..78c7f7b89f581 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java @@ -1,3 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ package org.elasticsearch.compute.aggregation; import org.elasticsearch.common.util.ObjectArray; diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java index 3fb34babb61e5..9d1ac640c362d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java @@ -1,3 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ package org.elasticsearch.compute.aggregation; import org.elasticsearch.common.util.ObjectArray; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java index f0cef837d5090..11b5a585016b2 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java @@ -1,3 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ package org.elasticsearch.compute.aggregation; class SimpleLinearRegressionWithTimeseries { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java.new b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java.new new file mode 100644 index 0000000000000..c00d6d65c1cb8 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java.new @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.compute.aggregation; + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +class SimpleLinearRegressionWithTimeseries { + long count; + double sumVal; + long sumTs; + double sumTsVal; + long sumTsSq; + + SimpleLinearRegressionWithTimeseries() { + this.count = 0; + this.sumVal = 0.0; + this.sumTs = 0; + this.sumTsVal = 0.0; + this.sumTsSq = 0; + } + + void add(long ts, double val) { + count++; + sumVal += val; + sumTs += ts; + sumTsVal += ts * val; + sumTsSq += ts * ts; + } + + double slope() { + if (count <= 1) { + return Double.NaN; + } + double numerator = count * sumTsVal - sumTs * sumVal; + double denominator = count * sumTsSq - sumTs * sumTs; + if (denominator == 0) { + return Double.NaN; + } + return numerator / denominator; + } + + double intercept() { + if (count == 0) { + return 0.0; // or handle as needed + } + return (sumVal - slope() * sumTs) / count; + } + +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st index 9e975613ce802..dc2cb0a0fa004 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st @@ -1,3 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ package org.elasticsearch.compute.aggregation; import org.elasticsearch.common.util.ObjectArray; From 1bdf985056d5c5cae35d440d60933c1fe96e15e8 Mon Sep 17 00:00:00 2001 From: Pablo Date: Sat, 25 Oct 2025 19:52:27 -0700 Subject: [PATCH 09/27] fixup --- .../aggregation/DerivDoubleGroupingAggregatorFunction.java | 3 +++ .../aggregation/DerivIntGroupingAggregatorFunction.java | 5 +++++ .../aggregation/DerivLongGroupingAggregatorFunction.java | 4 ++++ .../aggregation/X-DerivGroupingAggregatorFunction.java.st | 2 ++ 4 files changed, 14 insertions(+) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java index 0213cf9476bd8..043a4a59f0b5a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java @@ -6,6 +6,7 @@ */ package org.elasticsearch.compute.aggregation; +// begin generated imports import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; @@ -17,11 +18,13 @@ import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; import java.util.List; +// end generated imports @SuppressWarnings("cast") public class DerivDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java index 78c7f7b89f581..90af60dc2930b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java @@ -6,20 +6,25 @@ */ package org.elasticsearch.compute.aggregation; +// begin generated imports import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.IntArrayBlock; import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; import java.util.List; +// end generated imports @SuppressWarnings("cast") public 
class DerivIntGroupingAggregatorFunction implements GroupingAggregatorFunction { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java index 9d1ac640c362d..2745cf2fb323c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java @@ -6,9 +6,12 @@ */ package org.elasticsearch.compute.aggregation; +// begin generated imports import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.IntArrayBlock; import org.elasticsearch.compute.data.IntBigArrayBlock; @@ -21,6 +24,7 @@ import org.elasticsearch.core.Releasables; import java.util.List; +// end generated imports @SuppressWarnings("cast") public class DerivLongGroupingAggregatorFunction implements GroupingAggregatorFunction { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st index dc2cb0a0fa004..af28dd81f498a 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st @@ -6,6 +6,7 @@ */ package org.elasticsearch.compute.aggregation; +// begin generated imports import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; @@ -23,6 +24,7 @@ import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; import java.util.List; +// end generated imports @SuppressWarnings("cast") public class Deriv$Type$GroupingAggregatorFunction implements GroupingAggregatorFunction { From df94605f2f3a6caf9771c1a246d1e9656db2e100 Mon Sep 17 00:00:00 2001 From: Pablo Date: Mon, 27 Oct 2025 10:27:28 -0700 Subject: [PATCH 10/27] try capability rename --- .../testFixtures/src/main/resources/k8s-timeseries.csv-spec | 4 ++-- .../org/elasticsearch/xpack/esql/action/EsqlCapabilities.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec index 823633d7a4fb0..fa455ef524f42 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec @@ -723,7 +723,7 @@ mx:integer | tbucket:datetime derivative_of_gauge_metric required_capability: ts_command_v0 -required_capability: ts_linreg +required_capability: TS_LINREG_DERIVATIVE TS k8s | WHERE pod == "three" @@ -745,7 +745,7 @@ max_deriv:double | time_bucket:datetime | pod:keyword derivative_compared_to_rate required_capability: 
ts_command_v0 -required_capability: ts_linreg +required_capability: TS_LINREG_DERIVATIVE TS k8s | STATS max_deriv = max(deriv(to_long(network.total_bytes_in))), max_rate = max(rate(network.total_bytes_in)) BY time_bucket = bucket(@timestamp,5minute), cluster diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index e775210a15608..5912db51ee046 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1484,7 +1484,7 @@ public enum Cap { */ PERCENTILE_OVER_TIME, VARIANCE_STDDEV_OVER_TIME, - TS_LINREG, + TS_LINREG_DERIVATIVE, /** * INLINE STATS fix incorrect prunning of null filtering * https://github.com/elastic/elasticsearch/pull/135011 From 87eeb7b496b4c7e39f9841cfa1b327f492c39d74 Mon Sep 17 00:00:00 2001 From: Pablo Date: Mon, 3 Nov 2025 13:38:21 -0800 Subject: [PATCH 11/27] rebase --- .../xpack/esql/expression/function/aggregate/Deriv.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java index aee2fde9c4881..867050b1b0b97 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java @@ -43,12 +43,12 @@ public Deriv(Source source, @Param(name = "field", type = { "long", "integer", " } public Deriv(Source source, Expression field, Expression timestamp) { - super(source, field, Literal.TRUE, List.of(timestamp)); + super(source, field, Literal.TRUE, NO_WINDOW, List.of(timestamp)); this.timestamp = timestamp; } public Deriv(Source source, Expression field, Expression filter, Expression timestamp) { - super(source, field, filter, List.of(timestamp)); + super(source, field, filter, NO_WINDOW, List.of(timestamp)); this.timestamp = timestamp; } From 77855e7d28a757b75b1cc773c0b72ca79514cd8c Mon Sep 17 00:00:00 2001 From: Pablo Date: Mon, 3 Nov 2025 15:04:40 -0800 Subject: [PATCH 12/27] fixup --- .../expression/function/aggregate/Deriv.java | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java index 867050b1b0b97..ebdfd1ca511c8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java @@ -43,12 +43,11 @@ public Deriv(Source source, @Param(name = "field", type = { "long", "integer", " } public Deriv(Source source, Expression field, Expression timestamp) { - super(source, field, Literal.TRUE, NO_WINDOW, List.of(timestamp)); - this.timestamp = timestamp; + this(source, field, Literal.TRUE, timestamp, NO_WINDOW); } - public Deriv(Source source, Expression field, Expression filter, Expression timestamp) { - super(source, field, filter, NO_WINDOW, List.of(timestamp)); + public Deriv(Source source, Expression field, Expression filter, 
Expression window, Expression timestamp) { + super(source, field, filter, window, List.of(timestamp)); this.timestamp = timestamp; } @@ -57,6 +56,7 @@ private Deriv(org.elasticsearch.common.io.stream.StreamInput in) throws java.io. Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class) ); } @@ -68,7 +68,7 @@ public AggregateFunction perTimeSeriesAggregation() { @Override public AggregateFunction withFilter(Expression filter) { - return new Deriv(source(), field(), filter, timestamp); + return new Deriv(source(), field(), filter, timestamp, window()); } @Override @@ -78,17 +78,19 @@ public DataType dataType() { @Override public Expression replaceChildren(List newChildren) { - if (newChildren.size() == 3) { - return new Deriv(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2)); + if (newChildren.size() == 4) { + return new Deriv(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), newChildren.get(3)); + } else if (newChildren.size() == 3) { + return new Deriv(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), timestamp); } else { - assert newChildren.size() == 2; + assert newChildren.size() == 2 : "Expected 2, 3, 4 children but got " + newChildren.size(); return new Deriv(source(), newChildren.get(0), newChildren.get(1)); } } @Override protected NodeInfo info() { - return NodeInfo.create(this, Deriv::new, field(), filter(), timestamp); + return NodeInfo.create(this, Deriv::new, field(), filter(), window(), timestamp); } @Override From e8c6280dbafe17532052dbc9d531e61de5c4f0e9 Mon Sep 17 00:00:00 2001 From: Pablo Date: Mon, 10 Nov 2025 14:38:59 -0800 Subject: [PATCH 13/27] comments --- .../GroupingAggregatorFunction.java | 5 +- .../SimpleLinearRegressionWithTimeseries.java | 6 +- ...pleLinearRegressionWithTimeseries.java.new | 73 ------------------- .../expression/function/aggregate/Deriv.java | 11 +-- 4 files changed, 10 insertions(+), 85 deletions(-) delete mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java.new diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java index 2170b587172fb..3f8088be1800a 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java @@ -143,8 +143,9 @@ default void add(int positionOffset, IntBlock groupIds) { * @param selected the groupIds that have been selected to be included in * the results. Always ascending. * - *

This function is called in the coordinator node after all intermediate - * blocks have been gathered from the data nodes.

+ *

This function may be called in the coordinator node after all intermediate + * blocks have been gathered from the data nodes, or on data nodes during + * node-level or cluster-level reduction with intermediate input to intermediate output.

*/ void evaluateIntermediate(Block[] blocks, int offset, IntVector selected); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java index 11b5a585016b2..42ffdd52d6323 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java @@ -45,7 +45,11 @@ void add(long ts, double val) { if (count == 0) { return 0.0; // or handle as needed } - return (sumVal - slope() * sumTs) / count; + var slp = slope(); + if (Double.isNaN(slp)) { + return Double.NaN; + } + return (sumVal - slp * sumTs) / count; } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java.new b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java.new deleted file mode 100644 index c00d6d65c1cb8..0000000000000 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java.new +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ -package org.elasticsearch.compute.aggregation; - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - - -class SimpleLinearRegressionWithTimeseries { - long count; - double sumVal; - long sumTs; - double sumTsVal; - long sumTsSq; - - SimpleLinearRegressionWithTimeseries() { - this.count = 0; - this.sumVal = 0.0; - this.sumTs = 0; - this.sumTsVal = 0.0; - this.sumTsSq = 0; - } - - void add(long ts, double val) { - count++; - sumVal += val; - sumTs += ts; - sumTsVal += ts * val; - sumTsSq += ts * ts; - } - - double slope() { - if (count <= 1) { - return Double.NaN; - } - double numerator = count * sumTsVal - sumTs * sumVal; - double denominator = count * sumTsSq - sumTs * sumTs; - if (denominator == 0) { - return Double.NaN; - } - return numerator / denominator; - } - - double intercept() { - if (count == 0) { - return 0.0; // or handle as needed - } - return (sumVal - slope() * sumTs) / count; - } - -} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java index ebdfd1ca511c8..54f4f9f0c2508 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java @@ -43,7 +43,7 @@ public Deriv(Source source, @Param(name = "field", type = { "long", "integer", " } public Deriv(Source source, Expression field, Expression timestamp) { - this(source, field, Literal.TRUE, timestamp, NO_WINDOW); + this(source, field, Literal.TRUE, NO_WINDOW, timestamp); } public Deriv(Source source, Expression field, Expression filter, Expression window, Expression timestamp) { @@ -78,14 +78,7 @@ public DataType dataType() { @Override public Expression replaceChildren(List newChildren) { - if (newChildren.size() == 4) { - return new Deriv(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), newChildren.get(3)); - } else if (newChildren.size() == 3) { - return new Deriv(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), timestamp); - } else { - assert newChildren.size() == 2 : "Expected 2, 3, 4 children but got " + newChildren.size(); - return new Deriv(source(), newChildren.get(0), newChildren.get(1)); - } + return new Deriv(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), newChildren.get(3)); } @Override From 1a8e3d73a468885a35a8866db1ed52c8008ba42f Mon Sep 17 00:00:00 2001 From: Pablo Date: Mon, 10 Nov 2025 14:42:49 -0800 Subject: [PATCH 14/27] resolvetype --- .../esql/expression/function/aggregate/Deriv.java | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java index 54f4f9f0c2508..4c42ba8974234 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java @@ -14,6 +14,7 @@ import org.elasticsearch.compute.aggregation.DerivLongGroupingAggregatorFunction; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; import 
org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -26,6 +27,9 @@ import java.util.List; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; +import static org.elasticsearch.xpack.esql.core.type.DataType.AGGREGATE_METRIC_DOUBLE; + /** * Calculates the derivative over time of a numeric field using linear regression. */ @@ -76,6 +80,17 @@ public DataType dataType() { return DataType.DOUBLE; } + @Override + public TypeResolution resolveType() { + return TypeResolutions.isType( + field(), + dt -> dt.isNumeric() && dt != AGGREGATE_METRIC_DOUBLE, + sourceText(), + DEFAULT, + "numeric except counter types" + ); + } + @Override public Expression replaceChildren(List newChildren) { return new Deriv(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), newChildren.get(3)); From 3aa6e1d25596ac340b456e276b27f47c02186e39 Mon Sep 17 00:00:00 2001 From: Pablo Date: Mon, 10 Nov 2025 20:53:52 -0800 Subject: [PATCH 15/27] adding docs --- .../_snippets/functions/description/deriv.md | 6 + .../_snippets/functions/examples/deriv.md | 17 +++ .../esql/_snippets/functions/layout/deriv.md | 23 +++ .../_snippets/functions/parameters/deriv.md | 7 + .../esql/_snippets/functions/types/deriv.md | 10 ++ .../time-series-aggregation-functions.md | 1 + .../time-series-aggregation-functions.md | 3 + .../esql/images/functions/deriv.svg | 1 + .../kibana/definition/functions/deriv.json | 49 ++++++ .../esql/kibana/docs/functions/deriv.md | 10 ++ .../main/resources/k8s-timeseries.csv-spec | 5 + .../expression/function/aggregate/Deriv.java | 6 +- .../function/aggregate/DerivTests.java | 140 ++++++++++++++++++ 13 files changed, 276 insertions(+), 2 deletions(-) create mode 100644 docs/reference/query-languages/esql/_snippets/functions/description/deriv.md create mode 100644 docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md create mode 100644 docs/reference/query-languages/esql/_snippets/functions/layout/deriv.md create mode 100644 docs/reference/query-languages/esql/_snippets/functions/parameters/deriv.md create mode 100644 docs/reference/query-languages/esql/_snippets/functions/types/deriv.md create mode 100644 docs/reference/query-languages/esql/images/functions/deriv.svg create mode 100644 docs/reference/query-languages/esql/kibana/definition/functions/deriv.json create mode 100644 docs/reference/query-languages/esql/kibana/docs/functions/deriv.md create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/DerivTests.java diff --git a/docs/reference/query-languages/esql/_snippets/functions/description/deriv.md b/docs/reference/query-languages/esql/_snippets/functions/description/deriv.md new file mode 100644 index 0000000000000..c845d3cbbf141 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/description/deriv.md @@ -0,0 +1,6 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Description** + +Calculates the derivative over time of a numeric field using linear regression. 
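For reference, the derivative that DERIV reports is the ordinary least-squares slope accumulated by SimpleLinearRegressionWithTimeseries from its running sums (count, sumVal, sumTs, sumTsVal, sumTsSq) introduced earlier in this series. Given $n$ observed timestamp/value pairs $(t_i, v_i)$, slope() evaluates

$$
\text{slope} \;=\; \frac{n \sum_i t_i v_i \;-\; \sum_i t_i \sum_i v_i}{n \sum_i t_i^2 \;-\; \left(\sum_i t_i\right)^2}
$$

and returns NaN when fewer than two samples have been added or when the denominator is zero, i.e. when all timestamps are identical. Because the grouping aggregator feeds @timestamp values (epoch milliseconds) straight into these sums and evaluateFinal returns slope() unscaled, the small magnitudes expected in the csv-spec tests above are per-millisecond rates of change.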
+ diff --git a/docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md b/docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md new file mode 100644 index 0000000000000..7c4152da59b9d --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md @@ -0,0 +1,17 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Example** + +```esql +TS k8s +| WHERE pod == "three" +| STATS max_deriv = max(deriv(network.cost)) BY time_bucket = bucket(@timestamp,5minute), pod +``` + +| max_deriv:double | time_bucket:datetime | pod:keyword | +| --- | --- | --- | +| 9.3E-5 | 2024-05-10T00:00:00.000Z | three | +| 3.8E-5 | 2024-05-10T00:05:00.000Z | three | +| -2.0E-5 | 2024-05-10T00:10:00.000Z | three | + + diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/deriv.md b/docs/reference/query-languages/esql/_snippets/functions/layout/deriv.md new file mode 100644 index 0000000000000..09036fe8314d1 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/deriv.md @@ -0,0 +1,23 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +## `DERIV` [esql-deriv] + +**Syntax** + +:::{image} ../../../images/functions/deriv.svg +:alt: Embedded +:class: text-center +::: + + +:::{include} ../parameters/deriv.md +::: + +:::{include} ../description/deriv.md +::: + +:::{include} ../types/deriv.md +::: + +:::{include} ../examples/deriv.md +::: diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/deriv.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/deriv.md new file mode 100644 index 0000000000000..24fedc1dde506 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/deriv.md @@ -0,0 +1,7 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Parameters** + +`field` +: + diff --git a/docs/reference/query-languages/esql/_snippets/functions/types/deriv.md b/docs/reference/query-languages/esql/_snippets/functions/types/deriv.md new file mode 100644 index 0000000000000..566f6b3d786f6 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/types/deriv.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. 
+ +**Supported types** + +| field | result | +| --- | --- | +| double | double | +| integer | double | +| long | double | + diff --git a/docs/reference/query-languages/esql/_snippets/lists/time-series-aggregation-functions.md b/docs/reference/query-languages/esql/_snippets/lists/time-series-aggregation-functions.md index 67d40be53178c..fda0403b7e428 100644 --- a/docs/reference/query-languages/esql/_snippets/lists/time-series-aggregation-functions.md +++ b/docs/reference/query-languages/esql/_snippets/lists/time-series-aggregation-functions.md @@ -15,3 +15,4 @@ * [`STDDEV_OVER_TIME`](../../functions-operators/time-series-aggregation-functions.md#esql-stddev_over_time) {applies_to}`stack: preview 9.3` {applies_to}`serverless: preview` * [`VARIANCE_OVER_TIME`](../../functions-operators/time-series-aggregation-functions.md#esql-variance_over_time) {applies_to}`stack: preview 9.3` {applies_to}`serverless: preview` * [`SUM_OVER_TIME`](../../functions-operators/time-series-aggregation-functions.md#esql-sum_over_time) {applies_to}`stack: preview 9.2` {applies_to}`serverless: preview` +* [`DERIV`](../../functions-operators/time-series-aggregation-functions.md#esql-deriv) {applies_to}`stack: preview 9.3` {applies_to}`serverless: preview` diff --git a/docs/reference/query-languages/esql/functions-operators/time-series-aggregation-functions.md b/docs/reference/query-languages/esql/functions-operators/time-series-aggregation-functions.md index 141ec5706de3f..8fa531b633803 100644 --- a/docs/reference/query-languages/esql/functions-operators/time-series-aggregation-functions.md +++ b/docs/reference/query-languages/esql/functions-operators/time-series-aggregation-functions.md @@ -63,3 +63,6 @@ supports the following time series aggregation functions: :::{include} ../_snippets/functions/layout/sum_over_time.md ::: + +:::{include} ../_snippets/functions/layout/deriv.md +::: diff --git a/docs/reference/query-languages/esql/images/functions/deriv.svg b/docs/reference/query-languages/esql/images/functions/deriv.svg new file mode 100644 index 0000000000000..01a1a6ad2db4a --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/deriv.svg @@ -0,0 +1 @@ +DERIV(field) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/deriv.json b/docs/reference/query-languages/esql/kibana/definition/functions/deriv.json new file mode 100644 index 0000000000000..7305b2707f639 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/deriv.json @@ -0,0 +1,49 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. 
See ../README.md for how to regenerate it.", + "type" : "time_series_agg", + "name" : "deriv", + "description" : "Calculates the derivative over time of a numeric field using linear regression.", + "signatures" : [ + { + "params" : [ + { + "name" : "field", + "type" : "double", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "integer", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "long", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + } + ], + "examples" : [ + "TS k8s\n| WHERE pod == \"three\"\n| STATS max_deriv = max(deriv(network.cost)) BY time_bucket = bucket(@timestamp,5minute), pod" + ], + "preview" : false, + "snapshot_only" : false +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/deriv.md b/docs/reference/query-languages/esql/kibana/docs/functions/deriv.md new file mode 100644 index 0000000000000..2e7005674be89 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/deriv.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +### DERIV +Calculates the derivative over time of a numeric field using linear regression. + +```esql +TS k8s +| WHERE pod == "three" +| STATS max_deriv = max(deriv(network.cost)) BY time_bucket = bucket(@timestamp,5minute), pod +``` diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec index fa455ef524f42..152000c3731cf 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec @@ -725,19 +725,24 @@ derivative_of_gauge_metric required_capability: ts_command_v0 required_capability: TS_LINREG_DERIVATIVE + +// tag::deriv[] TS k8s | WHERE pod == "three" | STATS max_deriv = max(deriv(network.cost)) BY time_bucket = bucket(@timestamp,5minute), pod +// end::deriv[] | EVAL max_deriv = ROUND(max_deriv,6) | KEEP max_deriv, time_bucket, pod | SORT pod, time_bucket | LIMIT 5 ; +// tag::deriv-result[] max_deriv:double | time_bucket:datetime | pod:keyword 9.3E-5 | 2024-05-10T00:00:00.000Z | three 3.8E-5 | 2024-05-10T00:05:00.000Z | three -2.0E-5 | 2024-05-10T00:10:00.000Z | three +// end::deriv-result[] -1.0E-6 | 2024-05-10T00:15:00.000Z | three 0.0 | 2024-05-10T00:20:00.000Z | three diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java index 4c42ba8974234..2bc869940f8bf 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.FunctionType; 
import org.elasticsearch.xpack.esql.expression.function.Param; @@ -40,7 +41,8 @@ public class Deriv extends TimeSeriesAggregateFunction implements ToAggregator { @FunctionInfo( type = FunctionType.TIME_SERIES_AGGREGATE, returnType = { "double" }, - description = "Calculates the derivative over time of a numeric field using linear regression." + description = "Calculates the derivative over time of a numeric field using linear regression.", + examples = { @Example(file = "k8s-timeseries", tag = "deriv") } ) public Deriv(Source source, @Param(name = "field", type = { "long", "integer", "double" }) Expression field) { this(source, field, new UnresolvedAttribute(source, "@timestamp")); @@ -61,7 +63,7 @@ private Deriv(org.elasticsearch.common.io.stream.StreamInput in) throws java.io. in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class), - in.readNamedWriteable(Expression.class) + in.readNamedWriteableCollectionAsList(Expression.class).getFirst() ); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/DerivTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/DerivTests.java new file mode 100644 index 0000000000000..c791b8d4ab8eb --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/DerivTests.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.DocsV3Support; +import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matcher; +import org.hamcrest.Matchers; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class DerivTests extends AbstractFunctionTestCase { + public DerivTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + var suppliers = new ArrayList(); + + var valuesSuppliers = List.of( + MultiRowTestCaseSupplier.longCases(1, 1000, 0, 1000_000_000, true), + MultiRowTestCaseSupplier.intCases(1, 1000, 0, 1000_000_000, true), + MultiRowTestCaseSupplier.doubleCases(1, 1000, 0, 1000_000_000, true) + + ); + for (List valuesSupplier : valuesSuppliers) { + for (TestCaseSupplier.TypedDataSupplier fieldSupplier : valuesSupplier) { + TestCaseSupplier testCaseSupplier = makeSupplier(fieldSupplier); + suppliers.add(testCaseSupplier); + } + } + List parameters = new ArrayList<>(suppliers.size()); + for (TestCaseSupplier supplier : suppliers) 
{ + parameters.add(new Object[] { supplier }); + } + return parameters; + } + + @Override + protected Expression build(Source source, List args) { + return new Deriv(source, args.get(0), args.get(1), args.get(2), args.get(3)); + } + + @SuppressWarnings("unchecked") + private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { + DataType type = fieldSupplier.type(); + return new TestCaseSupplier(fieldSupplier.name(), List.of(type, DataType.DATETIME, DataType.INTEGER, DataType.LONG), () -> { + TestCaseSupplier.TypedData fieldTypedData = fieldSupplier.get(); + List dataRows = fieldTypedData.multiRowData(); + if (randomBoolean()) { + List withNulls = new ArrayList<>(dataRows.size()); + for (Object dataRow : dataRows) { + if (randomBoolean()) { + withNulls.add(null); + } else { + withNulls.add(dataRow); + } + } + dataRows = withNulls; + } + fieldTypedData = TestCaseSupplier.TypedData.multiRow(dataRows, type, fieldTypedData.name()); + List timestamps = new ArrayList<>(); + List slices = new ArrayList<>(); + List maxTimestamps = new ArrayList<>(); + long lastTimestamp = randomLongBetween(0, 1_000_000); + for (int row = 0; row < dataRows.size(); row++) { + lastTimestamp += randomLongBetween(1, 10_000); + timestamps.add(lastTimestamp); + slices.add(0); + maxTimestamps.add(Long.MAX_VALUE); + } + TestCaseSupplier.TypedData timestampsField = TestCaseSupplier.TypedData.multiRow( + timestamps.reversed(), + DataType.DATETIME, + "timestamps" + ); + TestCaseSupplier.TypedData sliceIndexType = TestCaseSupplier.TypedData.multiRow(slices, DataType.INTEGER, "_slice_index"); + TestCaseSupplier.TypedData nextTimestampType = TestCaseSupplier.TypedData.multiRow( + maxTimestamps, + DataType.LONG, + "_max_timestamp" + ); + + List nonNullDataRows = dataRows.stream().filter(Objects::nonNull).toList(); + Matcher matcher; + if (nonNullDataRows.size() < 2) { + matcher = Matchers.nullValue(); + } else { + var lastValue = ((Number) nonNullDataRows.getFirst()).doubleValue(); + var secondLastValue = ((Number) nonNullDataRows.get(1)).doubleValue(); + var increase = lastValue >= secondLastValue ? lastValue - secondLastValue : lastValue; + var largestTimestamp = timestamps.get(0); + var secondLargestTimestamp = timestamps.get(1); + var smallestTimestamp = timestamps.getLast(); + matcher = Matchers.allOf( + Matchers.greaterThanOrEqualTo(increase / (largestTimestamp - smallestTimestamp) * 1000 * 0.9), + Matchers.lessThanOrEqualTo( + increase / (largestTimestamp - secondLargestTimestamp) * (largestTimestamp - smallestTimestamp) * 1000 + ) + ); + } + + return new TestCaseSupplier.TestCase( + List.of(fieldTypedData, timestampsField, sliceIndexType, nextTimestampType), + Matchers.stringContainsInOrder("GroupingAggregator", "Deriv", "GroupingAggregatorFunction"), + // Matchers.any(String.class), + DataType.DOUBLE, + matcher + ); + }); + } + + public static List signatureTypes(List params) { + assertThat(params, hasSize(4)); + assertThat(params.get(1).dataType(), equalTo(DataType.DATETIME)); + assertThat(params.get(2).dataType(), equalTo(DataType.INTEGER)); + assertThat(params.get(3).dataType(), equalTo(DataType.LONG)); + return List.of(params.get(0)); + } +} From e34f14b6e15e936a8cfcff6f5fd8f0eb84499d90 Mon Sep 17 00:00:00 2001 From: Pablo Date: Tue, 11 Nov 2025 09:49:41 -0800 Subject: [PATCH 16/27] fixup. 
comments --- x-pack/plugin/esql/compute/build.gradle | 5 - .../DerivIntGroupingAggregatorFunction.java | 303 ------------------ .../main/resources/k8s-timeseries.csv-spec | 2 +- .../expression/function/aggregate/Deriv.java | 4 +- .../function/aggregate/DerivTests.java | 1 - 5 files changed, 2 insertions(+), 313 deletions(-) delete mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index 77fe309569170..93e91131e1fa5 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -995,11 +995,6 @@ tasks.named('stringTemplates').configure { } File derivAggregatorInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st") - template { - it.properties = intProperties - it.inputFile = derivAggregatorInputFile - it.outputFile = "org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java" - } template { it.properties = doubleProperties it.inputFile = derivAggregatorInputFile diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java deleted file mode 100644 index 90af60dc2930b..0000000000000 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java +++ /dev/null @@ -1,303 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ -package org.elasticsearch.compute.aggregation; - -// begin generated imports -import org.elasticsearch.common.util.ObjectArray; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.DoubleBlock; -import org.elasticsearch.compute.data.DoubleVector; -import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.LongBlock; -import org.elasticsearch.compute.data.LongVector; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasables; - -import java.util.List; -// end generated imports - -@SuppressWarnings("cast") -public class DerivIntGroupingAggregatorFunction implements GroupingAggregatorFunction { - - private static final List INTERMEDIATE_STATE_DESC = List.of( - new IntermediateStateDesc("sumVal", ElementType.INT), - new IntermediateStateDesc("sumTs", ElementType.LONG), - new IntermediateStateDesc("sumTsVal", ElementType.INT), - new IntermediateStateDesc("sumTsSq", ElementType.LONG), - new IntermediateStateDesc("count", ElementType.LONG) - ); - - private final List channels; - private final DriverContext driverContext; - private ObjectArray states; - - public DerivIntGroupingAggregatorFunction(List channels, DriverContext driverContext) { - this.states = driverContext.bigArrays().newObjectArray(256); - this.channels = channels; - this.driverContext = driverContext; - } - - public static class Supplier implements AggregatorFunctionSupplier { - - @Override - public List nonGroupingIntermediateStateDesc() { - throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); - } - - @Override - public List groupingIntermediateStateDesc() { - return INTERMEDIATE_STATE_DESC; - } - - @Override - public AggregatorFunction aggregator(DriverContext driverContext, List channels) { - throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); - } - - @Override - public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { - return new DerivIntGroupingAggregatorFunction(channels, driverContext); - } - - @Override - public String describe() { - return "derivative"; - } - } - - @Override - public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { - final IntBlock valueBlock = page.getBlock(channels.get(0)); - final LongBlock timestampBlock = page.getBlock(channels.get(1)); - final IntVector valueVector = valueBlock.asVector(); - return new AddInput() { - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); - } - - @Override - public void add(int positionOffset, IntVector groupIds) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - int valuesPosition = groupPosition + positionOffset; - int groupId = groupIds.getInt(groupPosition); - int vValue = valueVector.getInt(valuesPosition); - long ts = timestampBlock.getLong(valuesPosition); - 
SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - state.add(ts, (double) vValue); // TODO - value needs to be converted to double - } - } - - @Override - public void close() { - - } - - private void addBlockInput(int positionOffset, IntBlock groupIds, LongBlock timestampBlock, IntVector valueVector) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - if (groupIds.isNull(groupPosition)) { - continue; - } - int valuePosition = groupPosition + positionOffset; - if (valueBlock.isNull(valuePosition)) { - continue; - } - int groupStart = groupIds.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groupIds.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groupIds.getInt(g); - int vStart = valueBlock.getFirstValueIndex(valuePosition); - int vEnd = vStart + valueBlock.getValueCount(valuePosition); - for (int v = vStart; v < vEnd; v++) { - long ts = timestampBlock.getLong(valuePosition); - int val = valueVector.getInt(valuePosition); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - state.add(ts, val); // TODO - value needs to be converted to double - } - } - } - } - }; - } - - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - // No-op - } - - @Override - public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { - addIntermediateBlockInput(positionOffset, groupIdVector, page); - } - - @Override - public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { - addIntermediateBlockInput(positionOffset, groupIdVector, page); - } - - private void addIntermediateBlockInput(int positionOffset, IntBlock groupIdVector, Page page) { - IntBlock sumValBlock = page.getBlock(channels.get(0)); - LongBlock sumTsBlock = page.getBlock(channels.get(1)); - IntBlock sumTsValBlock = page.getBlock(channels.get(2)); - LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); - LongBlock countBlock = page.getBlock(channels.get(4)); - - if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { - throw new IllegalStateException("Mismatched intermediate state block value counts"); - } - - for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { - int valuePos = groupPos + positionOffset; - - int firstGroup = groupIdVector.getFirstValueIndex(groupPos); - int lastGroup = firstGroup + groupIdVector.getValueCount(groupPos); - - for (int g = firstGroup; g < lastGroup; g++) { - int groupId = groupIdVector.getInt(g); - states = driverContext.bigArrays().grow(states, groupId + 1); - var state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); // TODO - what happens for int / long - states.set(groupId, state); - } - long sumTs = sumTsBlock.getLong(valuePos); - int sumVal = sumValBlock.getInt(valuePos); - int sumTsVal = sumTsValBlock.getInt(valuePos); - long sumTsSq = sumTsSqBlock.getLong(valuePos); - long count = countBlock.getLong(valuePos); - state.sumTs += sumTs; - state.sumVal += sumVal; - state.sumTsVal += sumTsVal; - state.sumTsSq 
+= sumTsSq; - state.count += count; - } - } - - } - - @Override - public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { - IntBlock sumValBlock = page.getBlock(channels.get(0)); - LongBlock sumTsBlock = page.getBlock(channels.get(1)); - IntBlock sumTsValBlock = page.getBlock(channels.get(2)); - LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); - LongBlock countBlock = page.getBlock(channels.get(4)); - - if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { - throw new IllegalStateException("Mismatched intermediate state block value counts"); - } - - for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { - int valuePos = groupPos + positionOffset; - int groupId = groupIdVector.getInt(groupPos); - states = driverContext.bigArrays().grow(states, groupId + 1); - var state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); // TODO: what about double conversion - states.set(groupId, state); - } - long sumTs = sumTsBlock.getLong(valuePos); - int sumVal = sumValBlock.getInt(valuePos); - int sumTsVal = sumTsValBlock.getInt(valuePos); - long sumTsSq = sumTsSqBlock.getLong(valuePos); - long count = countBlock.getLong(valuePos); - state.sumTs += sumTs; - state.sumVal += sumVal; - state.sumTsVal += sumTsVal; - state.sumTsSq += sumTsSq; - state.count += count; - } - } - - @Override - public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { - BlockFactory blockFactory = driverContext.blockFactory(); - int positionCount = selected.getPositionCount(); - try ( - var sumValBuilder = blockFactory.newIntBlockBuilder(positionCount); - var sumTsBuilder = blockFactory.newLongBlockBuilder(positionCount); - var sumTsValBuilder = blockFactory.newIntBlockBuilder(positionCount); - var sumTsSqBuilder = blockFactory.newLongBlockBuilder(positionCount); - var countBuilder = blockFactory.newLongBlockBuilder(positionCount) - ) { - for (int p = 0; p < positionCount; p++) { - int groupId = selected.getInt(p); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - sumValBuilder.appendNull(); - sumTsBuilder.appendNull(); - sumTsValBuilder.appendNull(); - sumTsSqBuilder.appendNull(); - countBuilder.appendNull(); - } else { - sumValBuilder.appendInt((int) state.sumVal); - sumTsBuilder.appendLong(state.sumTs); - sumTsValBuilder.appendInt((int) state.sumTsVal); // TODO: fix this actually - sumTsSqBuilder.appendLong(state.sumTsSq); - countBuilder.appendLong(state.count); - } - } - blocks[offset] = sumValBuilder.build(); - blocks[offset + 1] = sumTsBuilder.build(); - blocks[offset + 2] = sumTsValBuilder.build(); - blocks[offset + 3] = sumTsSqBuilder.build(); - blocks[offset + 4] = countBuilder.build(); - } - } - - @Override - public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) { - BlockFactory blockFactory = driverContext.blockFactory(); - int positionCount = selected.getPositionCount(); - try (var resultBuilder = blockFactory.newDoubleBlockBuilder(positionCount)) { - for (int p = 0; p < positionCount; p++) { - int groupId = selected.getInt(p); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - resultBuilder.appendNull(); - } else { - double deriv = state.slope(); - 
                    resultBuilder.appendDouble(deriv);
-                }
-            }
-            blocks[offset] = resultBuilder.build();
-        }
-    }
-
-    @Override
-    public int intermediateBlockCount() {
-        return INTERMEDIATE_STATE_DESC.size();
-    }
-
-    @Override
-    public void close() {
-        Releasables.close(states);
-    }
-}
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec
index 152000c3731cf..93da64ccc60f7 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec
@@ -729,7 +729,7 @@ required_capability: TS_LINREG_DERIVATIVE
 // tag::deriv[]
 TS k8s
 | WHERE pod == "three"
-| STATS max_deriv = max(deriv(network.cost)) BY time_bucket = bucket(@timestamp,5minute), pod
+| STATS max_deriv = MAX(DERIV(network.cost)) BY time_bucket = BUCKET(@timestamp,5minute), pod
 // end::deriv[]
 | EVAL max_deriv = ROUND(max_deriv,6)
 | KEEP max_deriv, time_bucket, pod
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java
index 2bc869940f8bf..9a51b7ba85d3b 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java
@@ -10,7 +10,6 @@
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.DerivDoubleGroupingAggregatorFunction;
-import org.elasticsearch.compute.aggregation.DerivIntGroupingAggregatorFunction;
 import org.elasticsearch.compute.aggregation.DerivLongGroupingAggregatorFunction;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.Literal;
@@ -112,8 +111,7 @@ public String getWriteableName() {
     public AggregatorFunctionSupplier supplier() {
         final DataType type = field().dataType();
         return switch (type) {
-            case INTEGER -> new DerivIntGroupingAggregatorFunction.Supplier();
-            case LONG -> new DerivLongGroupingAggregatorFunction.Supplier();
+            case INTEGER, LONG -> new DerivLongGroupingAggregatorFunction.Supplier();
             case DOUBLE -> new DerivDoubleGroupingAggregatorFunction.Supplier();
             default -> throw new IllegalArgumentException("Unsupported data type for deriv aggregation: " + type);
         };
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/DerivTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/DerivTests.java
index c791b8d4ab8eb..2c20c20a20d8f 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/DerivTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/DerivTests.java
@@ -123,7 +123,6 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
         return new TestCaseSupplier.TestCase(
             List.of(fieldTypedData, timestampsField, sliceIndexType, nextTimestampType),
             Matchers.stringContainsInOrder("GroupingAggregator", "Deriv", "GroupingAggregatorFunction"),
-            // Matchers.any(String.class),
             DataType.DOUBLE,
             matcher
         );

From bf2f6a9dbf5789be651b8fb6af6432841c027b97 Mon Sep 17 00:00:00 2001
From: Pablo
Date: Tue, 11 Nov 2025 13:46:09 -0800
Subject: [PATCH 17/27] fixup

---
 .../query-languages/esql/_snippets/functions/examples/deriv.md | 2 +-
 .../query-languages/esql/kibana/definition/functions/deriv.json | 2 +-
 .../query-languages/esql/kibana/docs/functions/deriv.md | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md b/docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md
index 7c4152da59b9d..328d13f76598a 100644
--- a/docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md
+++ b/docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md
@@ -5,7 +5,7 @@
 ```esql
 TS k8s
 | WHERE pod == "three"
-| STATS max_deriv = max(deriv(network.cost)) BY time_bucket = bucket(@timestamp,5minute), pod
+| STATS max_deriv = MAX(DERIV(network.cost)) BY time_bucket = BUCKET(@timestamp,5minute), pod
 ```
 
 | max_deriv:double | time_bucket:datetime | pod:keyword |
diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/deriv.json b/docs/reference/query-languages/esql/kibana/definition/functions/deriv.json
index 7305b2707f639..f84745f4de37a 100644
--- a/docs/reference/query-languages/esql/kibana/definition/functions/deriv.json
+++ b/docs/reference/query-languages/esql/kibana/definition/functions/deriv.json
@@ -42,7 +42,7 @@
     }
   ],
   "examples" : [
-    "TS k8s\n| WHERE pod == \"three\"\n| STATS max_deriv = max(deriv(network.cost)) BY time_bucket = bucket(@timestamp,5minute), pod"
+    "TS k8s\n| WHERE pod == \"three\"\n| STATS max_deriv = MAX(DERIV(network.cost)) BY time_bucket = BUCKET(@timestamp,5minute), pod"
   ],
   "preview" : false,
   "snapshot_only" : false
diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/deriv.md b/docs/reference/query-languages/esql/kibana/docs/functions/deriv.md
index 2e7005674be89..fb17b716b03ea 100644
--- a/docs/reference/query-languages/esql/kibana/docs/functions/deriv.md
+++ b/docs/reference/query-languages/esql/kibana/docs/functions/deriv.md
@@ -6,5 +6,5 @@ Calculates the derivative over time of a numeric field using linear regression.
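The derivative is the slope of an ordinary least-squares fit of the field's values against `@timestamp`, which is why the aggregator state carries the running sums `count`, `sumVal`, `sumTs`, `sumTsVal`, and `sumTsSq`. A minimal sketch of that accumulator, assuming the textbook least-squares slope (the actual implementation lives in `SimpleLinearRegressionWithTimeseries.java`; the class name, the all-`double` sums, and the `NaN` guard below are simplifying assumptions):

```java
// Illustrative sketch only; field names mirror the aggregator's intermediate state.
class SlopeAccumulatorSketch {
    long count;       // number of (timestamp, value) samples
    double sumTs;     // sum of timestamps
    double sumVal;    // sum of values
    double sumTsVal;  // sum of timestamp * value
    double sumTsSq;   // sum of timestamp^2

    void add(long ts, double val) {
        count++;
        sumTs += ts;
        sumVal += val;
        sumTsVal += ts * val;
        sumTsSq += (double) ts * ts;
    }

    // Least-squares slope: (n * sum(t*v) - sum(t) * sum(v)) / (n * sum(t^2) - sum(t)^2)
    double slope() {
        double denominator = count * sumTsSq - sumTs * sumTs;
        if (count < 2 || denominator == 0) {
            return Double.NaN; // not enough distinct points to fit a line
        }
        return (count * sumTsVal - sumTs * sumVal) / denominator;
    }
}
```

Because every one of these sums is additive, partial states computed by different drivers can be merged with plain addition, which is all the `addIntermediateInput` paths above do before the final `slope()` call.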
 ```esql
 TS k8s
 | WHERE pod == "three"
-| STATS max_deriv = max(deriv(network.cost)) BY time_bucket = bucket(@timestamp,5minute), pod
+| STATS max_deriv = MAX(DERIV(network.cost)) BY time_bucket = BUCKET(@timestamp,5minute), pod
 ```

From 8e2b5efae487a9c05895d5e8de13f01640493371 Mon Sep 17 00:00:00 2001
From: Pablo
Date: Wed, 12 Nov 2025 21:01:48 -0800
Subject: [PATCH 18/27] prototype simpler code from nhat advice

---
 x-pack/plugin/esql/compute/build.gradle | 22 +-
 ...DerivDoubleGroupingAggregatorFunction.java | 303 --------------
 .../DerivLongGroupingAggregatorFunction.java | 303 --------------
 .../DerivDoubleAggregatorFunction.java | 235 +++++++++++
 ...DerivDoubleAggregatorFunctionSupplier.java | 47 +++
 ...DerivDoubleGroupingAggregatorFunction.java | 390 ++++++++++++++++++
 .../aggregation/DerivDoubleAggregator.java | 170 ++++++++
 .../SimpleLinearRegressionWithTimeseries.java | 19 +-
 .../expression/function/aggregate/Deriv.java | 6 +-
 9 files changed, 873 insertions(+), 622 deletions(-)
 delete mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java
 delete mode 100644 x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java
 create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunction.java
 create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunctionSupplier.java
 create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java
 create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java

diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle
index 93e91131e1fa5..cb49389b1f096 100644
--- a/x-pack/plugin/esql/compute/build.gradle
+++ b/x-pack/plugin/esql/compute/build.gradle
@@ -994,15 +994,15 @@ tasks.named('stringTemplates').configure {
     it.outputFile = "org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java"
   }
 
-  File derivAggregatorInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st")
-  template {
-    it.properties = doubleProperties
-    it.inputFile = derivAggregatorInputFile
-    it.outputFile = "org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java"
-  }
-  template {
-    it.properties = longProperties
-    it.inputFile = derivAggregatorInputFile
-    it.outputFile = "org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java"
-  }
+//  File derivAggregatorInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st")
+//  template {
+//    it.properties = doubleProperties
+//    it.inputFile = derivAggregatorInputFile
+//    it.outputFile = "org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java"
+//  }
+//  template {
+//    it.properties = longProperties
+//    it.inputFile = derivAggregatorInputFile
+//    it.outputFile = "org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java"
+//  }
 }
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java
b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java deleted file mode 100644 index 043a4a59f0b5a..0000000000000 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java +++ /dev/null @@ -1,303 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.compute.aggregation; - -// begin generated imports -import org.elasticsearch.common.util.ObjectArray; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.DoubleBlock; -import org.elasticsearch.compute.data.DoubleVector; -import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.LongBlock; -import org.elasticsearch.compute.data.LongVector; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasables; - -import java.util.List; -// end generated imports - -@SuppressWarnings("cast") -public class DerivDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { - - private static final List INTERMEDIATE_STATE_DESC = List.of( - new IntermediateStateDesc("sumVal", ElementType.DOUBLE), - new IntermediateStateDesc("sumTs", ElementType.LONG), - new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), - new IntermediateStateDesc("sumTsSq", ElementType.LONG), - new IntermediateStateDesc("count", ElementType.LONG) - ); - - private final List channels; - private final DriverContext driverContext; - private ObjectArray states; - - public DerivDoubleGroupingAggregatorFunction(List channels, DriverContext driverContext) { - this.states = driverContext.bigArrays().newObjectArray(256); - this.channels = channels; - this.driverContext = driverContext; - } - - public static class Supplier implements AggregatorFunctionSupplier { - - @Override - public List nonGroupingIntermediateStateDesc() { - throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); - } - - @Override - public List groupingIntermediateStateDesc() { - return INTERMEDIATE_STATE_DESC; - } - - @Override - public AggregatorFunction aggregator(DriverContext driverContext, List channels) { - throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); - } - - @Override - public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { - return new DerivDoubleGroupingAggregatorFunction(channels, driverContext); - } - - @Override - public String describe() { - return "derivative"; - } - } - - @Override - public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { - final DoubleBlock valueBlock = page.getBlock(channels.get(0)); - final LongBlock timestampBlock = page.getBlock(channels.get(1)); - final DoubleVector valueVector = valueBlock.asVector(); - return new AddInput() { - @Override - public void add(int positionOffset, 
IntArrayBlock groupIds) { - addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); - } - - @Override - public void add(int positionOffset, IntVector groupIds) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - int valuesPosition = groupPosition + positionOffset; - int groupId = groupIds.getInt(groupPosition); - double vValue = valueVector.getDouble(valuesPosition); - long ts = timestampBlock.getLong(valuesPosition); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - state.add(ts, (double) vValue); // TODO - value needs to be converted to double - } - } - - @Override - public void close() { - - } - - private void addBlockInput(int positionOffset, IntBlock groupIds, LongBlock timestampBlock, DoubleVector valueVector) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - if (groupIds.isNull(groupPosition)) { - continue; - } - int valuePosition = groupPosition + positionOffset; - if (valueBlock.isNull(valuePosition)) { - continue; - } - int groupStart = groupIds.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groupIds.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groupIds.getInt(g); - int vStart = valueBlock.getFirstValueIndex(valuePosition); - int vEnd = vStart + valueBlock.getValueCount(valuePosition); - for (int v = vStart; v < vEnd; v++) { - long ts = timestampBlock.getLong(valuePosition); - double val = valueVector.getDouble(valuePosition); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - state.add(ts, val); // TODO - value needs to be converted to double - } - } - } - } - }; - } - - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - // No-op - } - - @Override - public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { - addIntermediateBlockInput(positionOffset, groupIdVector, page); - } - - @Override - public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { - addIntermediateBlockInput(positionOffset, groupIdVector, page); - } - - private void addIntermediateBlockInput(int positionOffset, IntBlock groupIdVector, Page page) { - DoubleBlock sumValBlock = page.getBlock(channels.get(0)); - LongBlock sumTsBlock = page.getBlock(channels.get(1)); - DoubleBlock sumTsValBlock = page.getBlock(channels.get(2)); - LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); - LongBlock countBlock = page.getBlock(channels.get(4)); - - if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { - throw new IllegalStateException("Mismatched intermediate state block value counts"); - } - - for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { - int valuePos = groupPos + positionOffset; - - int firstGroup = groupIdVector.getFirstValueIndex(groupPos); - int lastGroup = firstGroup + groupIdVector.getValueCount(groupPos); - - for (int 
g = firstGroup; g < lastGroup; g++) { - int groupId = groupIdVector.getInt(g); - states = driverContext.bigArrays().grow(states, groupId + 1); - var state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); // TODO - what happens for int / long - states.set(groupId, state); - } - long sumTs = sumTsBlock.getLong(valuePos); - double sumVal = sumValBlock.getDouble(valuePos); - double sumTsVal = sumTsValBlock.getDouble(valuePos); - long sumTsSq = sumTsSqBlock.getLong(valuePos); - long count = countBlock.getLong(valuePos); - state.sumTs += sumTs; - state.sumVal += sumVal; - state.sumTsVal += sumTsVal; - state.sumTsSq += sumTsSq; - state.count += count; - } - } - - } - - @Override - public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { - DoubleBlock sumValBlock = page.getBlock(channels.get(0)); - LongBlock sumTsBlock = page.getBlock(channels.get(1)); - DoubleBlock sumTsValBlock = page.getBlock(channels.get(2)); - LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); - LongBlock countBlock = page.getBlock(channels.get(4)); - - if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { - throw new IllegalStateException("Mismatched intermediate state block value counts"); - } - - for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { - int valuePos = groupPos + positionOffset; - int groupId = groupIdVector.getInt(groupPos); - states = driverContext.bigArrays().grow(states, groupId + 1); - var state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); // TODO: what about double conversion - states.set(groupId, state); - } - long sumTs = sumTsBlock.getLong(valuePos); - double sumVal = sumValBlock.getDouble(valuePos); - double sumTsVal = sumTsValBlock.getDouble(valuePos); - long sumTsSq = sumTsSqBlock.getLong(valuePos); - long count = countBlock.getLong(valuePos); - state.sumTs += sumTs; - state.sumVal += sumVal; - state.sumTsVal += sumTsVal; - state.sumTsSq += sumTsSq; - state.count += count; - } - } - - @Override - public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { - BlockFactory blockFactory = driverContext.blockFactory(); - int positionCount = selected.getPositionCount(); - try ( - var sumValBuilder = blockFactory.newDoubleBlockBuilder(positionCount); - var sumTsBuilder = blockFactory.newLongBlockBuilder(positionCount); - var sumTsValBuilder = blockFactory.newDoubleBlockBuilder(positionCount); - var sumTsSqBuilder = blockFactory.newLongBlockBuilder(positionCount); - var countBuilder = blockFactory.newLongBlockBuilder(positionCount) - ) { - for (int p = 0; p < positionCount; p++) { - int groupId = selected.getInt(p); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - sumValBuilder.appendNull(); - sumTsBuilder.appendNull(); - sumTsValBuilder.appendNull(); - sumTsSqBuilder.appendNull(); - countBuilder.appendNull(); - } else { - sumValBuilder.appendDouble((double) state.sumVal); - sumTsBuilder.appendLong(state.sumTs); - sumTsValBuilder.appendDouble((double) state.sumTsVal); // TODO: fix this actually - sumTsSqBuilder.appendLong(state.sumTsSq); - countBuilder.appendLong(state.count); - } - } - blocks[offset] = sumValBuilder.build(); - blocks[offset + 1] = sumTsBuilder.build(); - blocks[offset + 2] = 
sumTsValBuilder.build(); - blocks[offset + 3] = sumTsSqBuilder.build(); - blocks[offset + 4] = countBuilder.build(); - } - } - - @Override - public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) { - BlockFactory blockFactory = driverContext.blockFactory(); - int positionCount = selected.getPositionCount(); - try (var resultBuilder = blockFactory.newDoubleBlockBuilder(positionCount)) { - for (int p = 0; p < positionCount; p++) { - int groupId = selected.getInt(p); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - resultBuilder.appendNull(); - } else { - double deriv = state.slope(); - resultBuilder.appendDouble(deriv); - } - } - blocks[offset] = resultBuilder.build(); - } - } - - @Override - public int intermediateBlockCount() { - return INTERMEDIATE_STATE_DESC.size(); - } - - @Override - public void close() { - Releasables.close(states); - } -} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java deleted file mode 100644 index 2745cf2fb323c..0000000000000 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java +++ /dev/null @@ -1,303 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.compute.aggregation; - -// begin generated imports -import org.elasticsearch.common.util.ObjectArray; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.DoubleBlock; -import org.elasticsearch.compute.data.DoubleVector; -import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.LongBlock; -import org.elasticsearch.compute.data.LongVector; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasables; - -import java.util.List; -// end generated imports - -@SuppressWarnings("cast") -public class DerivLongGroupingAggregatorFunction implements GroupingAggregatorFunction { - - private static final List INTERMEDIATE_STATE_DESC = List.of( - new IntermediateStateDesc("sumVal", ElementType.LONG), - new IntermediateStateDesc("sumTs", ElementType.LONG), - new IntermediateStateDesc("sumTsVal", ElementType.LONG), - new IntermediateStateDesc("sumTsSq", ElementType.LONG), - new IntermediateStateDesc("count", ElementType.LONG) - ); - - private final List channels; - private final DriverContext driverContext; - private ObjectArray states; - - public DerivLongGroupingAggregatorFunction(List channels, DriverContext driverContext) { - this.states = driverContext.bigArrays().newObjectArray(256); - this.channels = channels; - this.driverContext = driverContext; - } - - public static class Supplier implements AggregatorFunctionSupplier { - - @Override - public List 
nonGroupingIntermediateStateDesc() { - throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); - } - - @Override - public List groupingIntermediateStateDesc() { - return INTERMEDIATE_STATE_DESC; - } - - @Override - public AggregatorFunction aggregator(DriverContext driverContext, List channels) { - throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); - } - - @Override - public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { - return new DerivLongGroupingAggregatorFunction(channels, driverContext); - } - - @Override - public String describe() { - return "derivative"; - } - } - - @Override - public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { - final LongBlock valueBlock = page.getBlock(channels.get(0)); - final LongBlock timestampBlock = page.getBlock(channels.get(1)); - final LongVector valueVector = valueBlock.asVector(); - return new AddInput() { - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); - } - - @Override - public void add(int positionOffset, IntVector groupIds) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - int valuesPosition = groupPosition + positionOffset; - int groupId = groupIds.getInt(groupPosition); - long vValue = valueVector.getLong(valuesPosition); - long ts = timestampBlock.getLong(valuesPosition); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - state.add(ts, (double) vValue); // TODO - value needs to be converted to double - } - } - - @Override - public void close() { - - } - - private void addBlockInput(int positionOffset, IntBlock groupIds, LongBlock timestampBlock, LongVector valueVector) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - if (groupIds.isNull(groupPosition)) { - continue; - } - int valuePosition = groupPosition + positionOffset; - if (valueBlock.isNull(valuePosition)) { - continue; - } - int groupStart = groupIds.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groupIds.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groupIds.getInt(g); - int vStart = valueBlock.getFirstValueIndex(valuePosition); - int vEnd = vStart + valueBlock.getValueCount(valuePosition); - for (int v = vStart; v < vEnd; v++) { - long ts = timestampBlock.getLong(valuePosition); - long val = valueVector.getLong(valuePosition); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - state.add(ts, val); // TODO - value needs to be converted to double - } - } - } - } - }; - } - - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - // No-op - } - - @Override - public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { - addIntermediateBlockInput(positionOffset, groupIdVector, page); - } - - @Override - public void addIntermediateInput(int positionOffset, 
IntBigArrayBlock groupIdVector, Page page) { - addIntermediateBlockInput(positionOffset, groupIdVector, page); - } - - private void addIntermediateBlockInput(int positionOffset, IntBlock groupIdVector, Page page) { - LongBlock sumValBlock = page.getBlock(channels.get(0)); - LongBlock sumTsBlock = page.getBlock(channels.get(1)); - LongBlock sumTsValBlock = page.getBlock(channels.get(2)); - LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); - LongBlock countBlock = page.getBlock(channels.get(4)); - - if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { - throw new IllegalStateException("Mismatched intermediate state block value counts"); - } - - for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { - int valuePos = groupPos + positionOffset; - - int firstGroup = groupIdVector.getFirstValueIndex(groupPos); - int lastGroup = firstGroup + groupIdVector.getValueCount(groupPos); - - for (int g = firstGroup; g < lastGroup; g++) { - int groupId = groupIdVector.getInt(g); - states = driverContext.bigArrays().grow(states, groupId + 1); - var state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); // TODO - what happens for int / long - states.set(groupId, state); - } - long sumTs = sumTsBlock.getLong(valuePos); - long sumVal = sumValBlock.getLong(valuePos); - long sumTsVal = sumTsValBlock.getLong(valuePos); - long sumTsSq = sumTsSqBlock.getLong(valuePos); - long count = countBlock.getLong(valuePos); - state.sumTs += sumTs; - state.sumVal += sumVal; - state.sumTsVal += sumTsVal; - state.sumTsSq += sumTsSq; - state.count += count; - } - } - - } - - @Override - public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { - LongBlock sumValBlock = page.getBlock(channels.get(0)); - LongBlock sumTsBlock = page.getBlock(channels.get(1)); - LongBlock sumTsValBlock = page.getBlock(channels.get(2)); - LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); - LongBlock countBlock = page.getBlock(channels.get(4)); - - if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { - throw new IllegalStateException("Mismatched intermediate state block value counts"); - } - - for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { - int valuePos = groupPos + positionOffset; - int groupId = groupIdVector.getInt(groupPos); - states = driverContext.bigArrays().grow(states, groupId + 1); - var state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); // TODO: what about double conversion - states.set(groupId, state); - } - long sumTs = sumTsBlock.getLong(valuePos); - long sumVal = sumValBlock.getLong(valuePos); - long sumTsVal = sumTsValBlock.getLong(valuePos); - long sumTsSq = sumTsSqBlock.getLong(valuePos); - long count = countBlock.getLong(valuePos); - state.sumTs += sumTs; - state.sumVal += sumVal; - state.sumTsVal += sumTsVal; - state.sumTsSq += sumTsSq; - state.count += count; - } - } - - @Override - public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { - BlockFactory blockFactory = driverContext.blockFactory(); - int positionCount = selected.getPositionCount(); - try ( - var 
sumValBuilder = blockFactory.newLongBlockBuilder(positionCount); - var sumTsBuilder = blockFactory.newLongBlockBuilder(positionCount); - var sumTsValBuilder = blockFactory.newLongBlockBuilder(positionCount); - var sumTsSqBuilder = blockFactory.newLongBlockBuilder(positionCount); - var countBuilder = blockFactory.newLongBlockBuilder(positionCount) - ) { - for (int p = 0; p < positionCount; p++) { - int groupId = selected.getInt(p); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - sumValBuilder.appendNull(); - sumTsBuilder.appendNull(); - sumTsValBuilder.appendNull(); - sumTsSqBuilder.appendNull(); - countBuilder.appendNull(); - } else { - sumValBuilder.appendLong((long) state.sumVal); - sumTsBuilder.appendLong(state.sumTs); - sumTsValBuilder.appendLong((long) state.sumTsVal); // TODO: fix this actually - sumTsSqBuilder.appendLong(state.sumTsSq); - countBuilder.appendLong(state.count); - } - } - blocks[offset] = sumValBuilder.build(); - blocks[offset + 1] = sumTsBuilder.build(); - blocks[offset + 2] = sumTsValBuilder.build(); - blocks[offset + 3] = sumTsSqBuilder.build(); - blocks[offset + 4] = countBuilder.build(); - } - } - - @Override - public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) { - BlockFactory blockFactory = driverContext.blockFactory(); - int positionCount = selected.getPositionCount(); - try (var resultBuilder = blockFactory.newDoubleBlockBuilder(positionCount)) { - for (int p = 0; p < positionCount; p++) { - int groupId = selected.getInt(p); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - resultBuilder.appendNull(); - } else { - double deriv = state.slope(); - resultBuilder.appendDouble(deriv); - } - } - blocks[offset] = resultBuilder.build(); - } - } - - @Override - public int intermediateBlockCount() { - return INTERMEDIATE_STATE_DESC.size(); - } - - @Override - public void close() { - Releasables.close(states); - } -} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunction.java new file mode 100644 index 0000000000000..6698c7284d352 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunction.java @@ -0,0 +1,235 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link DerivDoubleAggregator}. + * This class is generated. 
Edit {@code AggregatorImplementer} instead. + */ +public final class DerivDoubleAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("count", ElementType.LONG), + new IntermediateStateDesc("sumVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final SimpleLinearRegressionWithTimeseries state; + + private final List channels; + + public DerivDoubleAggregatorFunction(DriverContext driverContext, List channels, + SimpleLinearRegressionWithTimeseries state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static DerivDoubleAggregatorFunction create(DriverContext driverContext, + List channels) { + return new DerivDoubleAggregatorFunction(driverContext, channels, DerivDoubleAggregator.initSingle(driverContext)); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + DoubleBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + DoubleVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + addRawBlock(valueBlock, timestampBlock, mask); + return; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + addRawBlock(valueBlock, timestampBlock, mask); + return; + } + addRawVector(valueVector, timestampVector, mask); + } + + private void addRawInputNotMasked(Page page) { + DoubleBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + DoubleVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + addRawBlock(valueBlock, timestampBlock); + return; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + addRawBlock(valueBlock, timestampBlock); + return; + } + addRawVector(valueVector, timestampVector); + } + + private void addRawVector(DoubleVector valueVector, LongVector timestampVector) { + for (int valuesPosition = 0; valuesPosition < valueVector.getPositionCount(); valuesPosition++) { + double valueValue = valueVector.getDouble(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivDoubleAggregator.combine(state, valueValue, timestampValue); + } + } + + private void addRawVector(DoubleVector valueVector, LongVector timestampVector, + BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < valueVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + double valueValue = valueVector.getDouble(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivDoubleAggregator.combine(state, valueValue, timestampValue); + } + } + + private void addRawBlock(DoubleBlock valueBlock, LongBlock timestampBlock) { + for (int p = 0; p < valueBlock.getPositionCount(); 
p++) { + int valueValueCount = valueBlock.getValueCount(p); + if (valueValueCount == 0) { + continue; + } + int timestampValueCount = timestampBlock.getValueCount(p); + if (timestampValueCount == 0) { + continue; + } + int valueStart = valueBlock.getFirstValueIndex(p); + int valueEnd = valueStart + valueValueCount; + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + double valueValue = valueBlock.getDouble(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(p); + int timestampEnd = timestampStart + timestampValueCount; + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivDoubleAggregator.combine(state, valueValue, timestampValue); + } + } + } + } + + private void addRawBlock(DoubleBlock valueBlock, LongBlock timestampBlock, BooleanVector mask) { + for (int p = 0; p < valueBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int valueValueCount = valueBlock.getValueCount(p); + if (valueValueCount == 0) { + continue; + } + int timestampValueCount = timestampBlock.getValueCount(p); + if (timestampValueCount == 0) { + continue; + } + int valueStart = valueBlock.getFirstValueIndex(p); + int valueEnd = valueStart + valueValueCount; + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + double valueValue = valueBlock.getDouble(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(p); + int timestampEnd = timestampStart + timestampValueCount; + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivDoubleAggregator.combine(state, valueValue, timestampValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert count.getPositionCount() == 1; + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + assert sumVal.getPositionCount() == 1; + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + assert sumTs.getPositionCount() == 1; + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + assert sumTsVal.getPositionCount() == 1; + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert sumTsSq.getPositionCount() == 1; + DerivDoubleAggregator.combineIntermediate(state, count.getLong(0), sumVal.getDouble(0), sumTs.getLong(0), sumTsVal.getDouble(0), sumTsSq.getLong(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + 
blocks[offset] = DerivDoubleAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..440c03cd403f8 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunctionSupplier.java @@ -0,0 +1,47 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link DerivDoubleAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class DerivDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + public DerivDoubleAggregatorFunctionSupplier() { + } + + @Override + public List nonGroupingIntermediateStateDesc() { + return DerivDoubleAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List groupingIntermediateStateDesc() { + return DerivDoubleGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public DerivDoubleAggregatorFunction aggregator(DriverContext driverContext, + List channels) { + return DerivDoubleAggregatorFunction.create(driverContext, channels); + } + + @Override + public DerivDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List channels) { + return DerivDoubleGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "deriv of doubles"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..f7de242c2b950 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java @@ -0,0 +1,390 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link DerivDoubleAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class DerivDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("timestamps", ElementType.LONG), + new IntermediateStateDesc("values", ElementType.DOUBLE) ); + + private final DerivDoubleAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + public DerivDoubleGroupingAggregatorFunction(List channels, + DerivDoubleAggregator.GroupingState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static DerivDoubleGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new DerivDoubleGroupingAggregatorFunction(channels, DerivDoubleAggregator.initGrouping(driverContext), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + LongBlock timestampBlock = page.getBlock(channels.get(0)); + DoubleBlock valueBlock = page.getBlock(channels.get(1)); + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, timestampBlock, valueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, timestampBlock, valueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, timestampBlock, valueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, timestampBlock, valueBlock); + } + + @Override + public void close() { + } + }; + } + DoubleVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, timestampBlock, valueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, timestampBlock, valueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, timestampBlock, valueBlock); + } + + @Override + public void add(int positionOffset, 
IntVector groupIds) { + addRawInput(positionOffset, groupIds, timestampBlock, valueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, timestampVector, valueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, timestampVector, valueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, timestampVector, valueVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock timestampBlock, + DoubleBlock valueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + if (valueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + double valueValue = valueBlock.getDouble(valueOffset); + DerivDoubleAggregator.combine(state, groupId, timestampValue, valueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector timestampVector, + DoubleVector valueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + long timestampValue = timestampVector.getLong(valuesPosition); + double valueValue = valueVector.getDouble(valuesPosition); + DerivDoubleAggregator.combine(state, groupId, timestampValue, valueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if 
(groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + DerivDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock timestampBlock, + DoubleBlock valueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + if (valueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + double valueValue = valueBlock.getDouble(valueOffset); + DerivDoubleAggregator.combine(state, groupId, timestampValue, valueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector timestampVector, + DoubleVector valueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + long timestampValue = timestampVector.getLong(valuesPosition); + double valueValue = valueVector.getDouble(valuesPosition); + DerivDoubleAggregator.combine(state, groupId, timestampValue, valueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + 
positionOffset; + DerivDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongBlock timestampBlock, + DoubleBlock valueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + if (valueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + double valueValue = valueBlock.getDouble(valueOffset); + DerivDoubleAggregator.combine(state, groupId, timestampValue, valueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongVector timestampVector, + DoubleVector valueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + double valueValue = valueVector.getDouble(valuesPosition); + DerivDoubleAggregator.combine(state, groupId, timestampValue, valueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + DerivDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, LongBlock timestampBlock, + DoubleBlock valueBlock) { + if (timestampBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (valueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = DerivDoubleAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { 
+ StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java new file mode 100644 index 0000000000000..be04eef30b19b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java @@ -0,0 +1,170 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.common.util.LongArray; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +@Aggregator( + { + @IntermediateState(name = "count", type = "LONG"), + @IntermediateState(name = "sumVal", type = "DOUBLE"), + @IntermediateState(name = "sumTs", type = "LONG"), + @IntermediateState(name = "sumTsVal", type = "DOUBLE"), + @IntermediateState(name = "sumTsSq", type = "LONG") } +) +@GroupingAggregator( + { @IntermediateState(name = "timestamps", type = "LONG_BLOCK"), @IntermediateState(name = "values", type = "DOUBLE_BLOCK") } +) +class DerivDoubleAggregator { + + public static SimpleLinearRegressionWithTimeseries initSingle(DriverContext driverContext) { + return new SimpleLinearRegressionWithTimeseries(); + } + + public static void combine(SimpleLinearRegressionWithTimeseries current, double value, long timestamp) { + current.add(timestamp, value); + } + + public static void combineIntermediate( + SimpleLinearRegressionWithTimeseries state, + long count, + double sumVal, + long sumTs, + double sumTsVal, + long sumTsSq + ) { + state.count += count; + state.sumVal += sumVal; + state.sumTs += sumTs; + state.sumTsVal += sumTsVal; + state.sumTsSq += sumTsSq; + } + + public static Block evaluateFinal(SimpleLinearRegressionWithTimeseries state, DriverContext driverContext) { + BlockFactory blockFactory = driverContext.blockFactory(); + var slope = state.slope(); + if (Double.isNaN(slope)) { + return blockFactory.newConstantNullBlock(1); + } + return blockFactory.newConstantDoubleBlockWith(slope, 1); + } + + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext.bigArrays()); + } + + public static void combine(GroupingState state, int groupId, long timestamp, double value) { + // TODO + } + + public static void combineIntermediate( + GroupingState state, + int groupId, + LongBlock timestamps, // stylecheck + DoubleBlock values, + int otherPosition + ) { + // TODO use groupId + 
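// Each intermediate position packs one (timestamp, value) sample for this group;
+        // replay it into the grouping state just as a raw input row would be.
+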
state.collectValue(groupId, timestamps.getLong(otherPosition), values.getDouble(otherPosition)); + } + + public static Block evaluateFinal(GroupingState state, IntVector selectedGroups, GroupingAggregatorEvaluationContext ctx) { + // Block evaluatePercentile(IntVector selected, DriverContext driverContext) { + // try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { + // for (int i = 0; i < selected.getPositionCount(); i++) { + // int si = selected.getInt(i); + // if (si >= digests.size()) { + // builder.appendNull(); + // continue; + // } + // final TDigestState digest = digests.get(si); + // if (percentile != null && digest != null && digest.size() > 0) { + // builder.appendDouble(digest.quantile(percentile / 100)); + // } else { + // builder.appendNull(); + // } + // } + // return builder.build(); + // } + // } + try (DoubleBlock.Builder builder = ctx.driverContext().blockFactory().newDoubleBlockBuilder(selectedGroups.getPositionCount())) { + for (int i = 0; i < selectedGroups.getPositionCount(); i++) { + int groupId = selectedGroups.getInt(i); + // TODO must use groupId + double result = 1.0; + if (Double.isNaN(result)) { + builder.appendNull(); + continue; + } + builder.appendDouble(result); + } + return builder.build(); + } + } + + public static final class GroupingState extends AbstractArrayState { + private final BigArrays bigArrays; + private LongArray timestamps; + private DoubleArray values; + + GroupingState(BigArrays bigArrays) { + super(bigArrays); + this.bigArrays = bigArrays; + this.timestamps = bigArrays.newLongArray(1L); + this.values = bigArrays.newDoubleArray(1L); + } + + void collectValue(int groupId, long timestamp, double value) { + if (groupId < timestamps.size()) { + timestamps.set(groupId, timestamp); + } else { + timestamps = bigArrays.grow(timestamps, groupId + 1); + } + timestamps.set(groupId, timestamp); + if (groupId < values.size()) { + values.set(groupId, value); + } else { + values = bigArrays.grow(values, groupId + 1); + } + values.set(groupId, value); + } + + @Override + public void close() { + Releasables.close(timestamps, values, super::close); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + try ( + LongBlock.Builder timestampBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + DoubleBlock.Builder valueBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int groupId = selected.getInt(i); + timestampBuilder.appendLong(timestamps.get(groupId)); + valueBuilder.appendDouble(values.get(groupId)); + } + blocks[offset] = timestampBuilder.build(); + blocks[offset + 1] = valueBuilder.build(); + } + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java index 42ffdd52d6323..c98340f2e4750 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java @@ -6,7 +6,24 @@ */ package org.elasticsearch.compute.aggregation; -class SimpleLinearRegressionWithTimeseries { +import 
org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.operator.DriverContext; + +class SimpleLinearRegressionWithTimeseries implements AggregatorState { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset + 0] = driverContext.blockFactory().newConstantLongBlockWith(count, 1); + blocks[offset + 1] = driverContext.blockFactory().newConstantDoubleBlockWith(sumVal, 1); + blocks[offset + 2] = driverContext.blockFactory().newConstantLongBlockWith(sumTs, 1); + blocks[offset + 3] = driverContext.blockFactory().newConstantDoubleBlockWith(sumTsVal, 1); + blocks[offset + 4] = driverContext.blockFactory().newConstantLongBlockWith(sumTsSq, 1); + } + + @Override + public void close() { + + } + long count; double sumVal; long sumTs; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java index 9a51b7ba85d3b..e7f32e3a2a154 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java @@ -9,8 +9,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; -import org.elasticsearch.compute.aggregation.DerivDoubleGroupingAggregatorFunction; -import org.elasticsearch.compute.aggregation.DerivLongGroupingAggregatorFunction; +import org.elasticsearch.compute.aggregation.DerivDoubleAggregatorFunctionSupplier; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; @@ -111,8 +110,7 @@ public String getWriteableName() { public AggregatorFunctionSupplier supplier() { final DataType type = field().dataType(); return switch (type) { - case INTEGER, LONG -> new DerivLongGroupingAggregatorFunction.Supplier(); - case DOUBLE -> new DerivDoubleGroupingAggregatorFunction.Supplier(); + case DOUBLE -> new DerivDoubleAggregatorFunctionSupplier(); default -> throw new IllegalArgumentException("Unsupported data type for deriv aggregation: " + type); }; } From 6368797ed03a358b988ab445f8bc5f1c024c46ad Mon Sep 17 00:00:00 2001 From: Pablo Date: Thu, 13 Nov 2025 14:54:09 -0800 Subject: [PATCH 19/27] remove unneeded stuff --- x-pack/plugin/esql/compute/build.gradle | 12 - .../X-DerivGroupingAggregatorFunction.java.st | 303 ------------------ 2 files changed, 315 deletions(-) delete mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index cb49389b1f096..e792bad34f67a 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -993,16 +993,4 @@ tasks.named('stringTemplates').configure { it.inputFile = rateAggregatorInputFile it.outputFile = "org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java" } - -// File derivAggregatorInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st") -// template { -// it.properties = doubleProperties -// it.inputFile = derivAggregatorInputFile -// it.outputFile = 
"org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java" -// } -// template { -// it.properties = longProperties -// it.inputFile = derivAggregatorInputFile -// it.outputFile = "org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java" -// } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st deleted file mode 100644 index af28dd81f498a..0000000000000 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st +++ /dev/null @@ -1,303 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.compute.aggregation; - -// begin generated imports -import org.elasticsearch.common.util.ObjectArray; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.DoubleBlock; -import org.elasticsearch.compute.data.DoubleVector; -import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.LongBlock; -import org.elasticsearch.compute.data.LongVector; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasables; - -import java.util.List; -// end generated imports - -@SuppressWarnings("cast") -public class Deriv$Type$GroupingAggregatorFunction implements GroupingAggregatorFunction { - - private static final List INTERMEDIATE_STATE_DESC = List.of( - new IntermediateStateDesc("sumVal", ElementType.$TYPE$), - new IntermediateStateDesc("sumTs", ElementType.LONG), - new IntermediateStateDesc("sumTsVal", ElementType.$TYPE$), - new IntermediateStateDesc("sumTsSq", ElementType.LONG), - new IntermediateStateDesc("count", ElementType.LONG) - ); - - private final List channels; - private final DriverContext driverContext; - private ObjectArray states; - - public Deriv$Type$GroupingAggregatorFunction(List channels, DriverContext driverContext) { - this.states = driverContext.bigArrays().newObjectArray(256); - this.channels = channels; - this.driverContext = driverContext; - } - - public static class Supplier implements AggregatorFunctionSupplier { - - @Override - public List nonGroupingIntermediateStateDesc() { - throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); - } - - @Override - public List groupingIntermediateStateDesc() { - return INTERMEDIATE_STATE_DESC; - } - - @Override - public AggregatorFunction aggregator(DriverContext driverContext, List channels) { - throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation"); - } - - @Override - public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { - return new Deriv$Type$GroupingAggregatorFunction(channels, driverContext); - } - - @Override - public String 
describe() { - return "derivative"; - } - } - - @Override - public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { - final $Type$Block valueBlock = page.getBlock(channels.get(0)); - final LongBlock timestampBlock = page.getBlock(channels.get(1)); - final $Type$Vector valueVector = valueBlock.asVector(); - return new AddInput() { - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlockInput(positionOffset, groupIds, timestampBlock, valueVector); - } - - @Override - public void add(int positionOffset, IntVector groupIds) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - int valuesPosition = groupPosition + positionOffset; - int groupId = groupIds.getInt(groupPosition); - $type$ vValue = valueVector.get$Type$(valuesPosition); - long ts = timestampBlock.getLong(valuesPosition); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - state.add(ts, (double) vValue); // TODO - value needs to be converted to double - } - } - - @Override - public void close() { - - } - - private void addBlockInput(int positionOffset, IntBlock groupIds, LongBlock timestampBlock, $Type$Vector valueVector) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - if (groupIds.isNull(groupPosition)) { - continue; - } - int valuePosition = groupPosition + positionOffset; - if (valueBlock.isNull(valuePosition)) { - continue; - } - int groupStart = groupIds.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groupIds.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groupIds.getInt(g); - int vStart = valueBlock.getFirstValueIndex(valuePosition); - int vEnd = vStart + valueBlock.getValueCount(valuePosition); - for (int v = vStart; v < vEnd; v++) { - long ts = timestampBlock.getLong(valuePosition); - $type$ val = valueVector.get$Type$(valuePosition); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); - states.set(groupId, state); - } - state.add(ts, val); // TODO - value needs to be converted to double - } - } - } - } - }; - } - - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - // No-op - } - - @Override - public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { - addIntermediateBlockInput(positionOffset, groupIdVector, page); - } - - @Override - public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { - addIntermediateBlockInput(positionOffset, groupIdVector, page); - } - - private void addIntermediateBlockInput(int positionOffset, IntBlock groupIdVector, Page page) { - $Type$Block sumValBlock = page.getBlock(channels.get(0)); - LongBlock sumTsBlock = page.getBlock(channels.get(1)); - $Type$Block sumTsValBlock = page.getBlock(channels.get(2)); - LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); - LongBlock countBlock = page.getBlock(channels.get(4)); - - if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() - || 
sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { - throw new IllegalStateException("Mismatched intermediate state block value counts"); - } - - for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { - int valuePos = groupPos + positionOffset; - - int firstGroup = groupIdVector.getFirstValueIndex(groupPos); - int lastGroup = firstGroup + groupIdVector.getValueCount(groupPos); - - for (int g = firstGroup; g < lastGroup; g++) { - int groupId = groupIdVector.getInt(g); - states = driverContext.bigArrays().grow(states, groupId + 1); - var state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); // TODO - what happens for int / long - states.set(groupId, state); - } - long sumTs = sumTsBlock.getLong(valuePos); - $type$ sumVal = sumValBlock.get$Type$(valuePos); - $type$ sumTsVal = sumTsValBlock.get$Type$(valuePos); - long sumTsSq = sumTsSqBlock.getLong(valuePos); - long count = countBlock.getLong(valuePos); - state.sumTs += sumTs; - state.sumVal += sumVal; - state.sumTsVal += sumTsVal; - state.sumTsSq += sumTsSq; - state.count += count; - } - } - - } - - @Override - public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { - $Type$Block sumValBlock = page.getBlock(channels.get(0)); - LongBlock sumTsBlock = page.getBlock(channels.get(1)); - $Type$Block sumTsValBlock = page.getBlock(channels.get(2)); - LongBlock sumTsSqBlock = page.getBlock(channels.get(3)); - LongBlock countBlock = page.getBlock(channels.get(4)); - - if (sumTsBlock.getTotalValueCount() != sumValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != sumTsValBlock.getTotalValueCount() - || sumTsBlock.getTotalValueCount() != countBlock.getTotalValueCount()) { - throw new IllegalStateException("Mismatched intermediate state block value counts"); - } - - for (int groupPos = 0; groupPos < groupIdVector.getPositionCount(); groupPos++) { - int valuePos = groupPos + positionOffset; - int groupId = groupIdVector.getInt(groupPos); - states = driverContext.bigArrays().grow(states, groupId + 1); - var state = states.get(groupId); - if (state == null) { - state = new SimpleLinearRegressionWithTimeseries(); // TODO: what about double conversion - states.set(groupId, state); - } - long sumTs = sumTsBlock.getLong(valuePos); - $type$ sumVal = sumValBlock.get$Type$(valuePos); - $type$ sumTsVal = sumTsValBlock.get$Type$(valuePos); - long sumTsSq = sumTsSqBlock.getLong(valuePos); - long count = countBlock.getLong(valuePos); - state.sumTs += sumTs; - state.sumVal += sumVal; - state.sumTsVal += sumTsVal; - state.sumTsSq += sumTsSq; - state.count += count; - } - } - - @Override - public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { - BlockFactory blockFactory = driverContext.blockFactory(); - int positionCount = selected.getPositionCount(); - try ( - var sumValBuilder = blockFactory.new$Type$BlockBuilder(positionCount); - var sumTsBuilder = blockFactory.newLongBlockBuilder(positionCount); - var sumTsValBuilder = blockFactory.new$Type$BlockBuilder(positionCount); - var sumTsSqBuilder = blockFactory.newLongBlockBuilder(positionCount); - var countBuilder = blockFactory.newLongBlockBuilder(positionCount) - ) { - for (int p = 0; p < positionCount; p++) { - int groupId = selected.getInt(p); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - sumValBuilder.appendNull(); - sumTsBuilder.appendNull(); - sumTsValBuilder.appendNull(); - 
sumTsSqBuilder.appendNull(); - countBuilder.appendNull(); - } else { - sumValBuilder.append$Type$(($type$) state.sumVal); - sumTsBuilder.appendLong(state.sumTs); - sumTsValBuilder.append$Type$(($type$) state.sumTsVal); // TODO: fix this actually - sumTsSqBuilder.appendLong(state.sumTsSq); - countBuilder.appendLong(state.count); - } - } - blocks[offset] = sumValBuilder.build(); - blocks[offset + 1] = sumTsBuilder.build(); - blocks[offset + 2] = sumTsValBuilder.build(); - blocks[offset + 3] = sumTsSqBuilder.build(); - blocks[offset + 4] = countBuilder.build(); - } - } - - @Override - public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) { - BlockFactory blockFactory = driverContext.blockFactory(); - int positionCount = selected.getPositionCount(); - try (var resultBuilder = blockFactory.newDoubleBlockBuilder(positionCount)) { - for (int p = 0; p < positionCount; p++) { - int groupId = selected.getInt(p); - SimpleLinearRegressionWithTimeseries state = states.get(groupId); - if (state == null) { - resultBuilder.appendNull(); - } else { - double deriv = state.slope(); - resultBuilder.appendDouble(deriv); - } - } - blocks[offset] = resultBuilder.build(); - } - } - - @Override - public int intermediateBlockCount() { - return INTERMEDIATE_STATE_DESC.size(); - } - - @Override - public void close() { - Releasables.close(states); - } -} From 7420ad7fa166aee341a2f607068296778ebff449 Mon Sep 17 00:00:00 2001 From: Pablo Date: Mon, 17 Nov 2025 10:57:14 -0800 Subject: [PATCH 20/27] rewrite of deriv --- ...DerivDoubleGroupingAggregatorFunction.java | 244 ++++++---- .../DerivLongAggregatorFunction.java | 235 ++++++++++ .../DerivLongAggregatorFunctionSupplier.java | 47 ++ .../DerivLongGroupingAggregatorFunction.java | 438 ++++++++++++++++++ .../aggregation/DerivDoubleAggregator.java | 114 ++--- .../aggregation/DerivLongAggregator.java | 78 ++++ .../main/resources/k8s-timeseries.csv-spec | 10 +- .../expression/function/aggregate/Deriv.java | 2 + 8 files changed, 1008 insertions(+), 160 deletions(-) create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunctionSupplier.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivLongAggregator.java diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java index f7de242c2b950..adf77eff37fa8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java @@ -27,8 +27,11 @@ */ public final class DerivDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { private static final List INTERMEDIATE_STATE_DESC = List.of( - new IntermediateStateDesc("timestamps", ElementType.LONG), - new IntermediateStateDesc("values", ElementType.DOUBLE) ); + new IntermediateStateDesc("count", 
ElementType.LONG), + new IntermediateStateDesc("sumVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); private final DerivDoubleAggregator.GroupingState state; @@ -60,25 +63,25 @@ public int intermediateBlockCount() { @Override public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { - LongBlock timestampBlock = page.getBlock(channels.get(0)); - DoubleBlock valueBlock = page.getBlock(channels.get(1)); - LongVector timestampVector = timestampBlock.asVector(); - if (timestampVector == null) { - maybeEnableGroupIdTracking(seenGroupIds, timestampBlock, valueBlock); + DoubleBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + DoubleVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, valueBlock, timestampBlock); return new GroupingAggregatorFunction.AddInput() { @Override public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, timestampBlock, valueBlock); + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); } @Override public void add(int positionOffset, IntBigArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, timestampBlock, valueBlock); + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); } @Override public void add(int positionOffset, IntVector groupIds) { - addRawInput(positionOffset, groupIds, timestampBlock, valueBlock); + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); } @Override @@ -86,23 +89,23 @@ public void close() { } }; } - DoubleVector valueVector = valueBlock.asVector(); - if (valueVector == null) { - maybeEnableGroupIdTracking(seenGroupIds, timestampBlock, valueBlock); + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, valueBlock, timestampBlock); return new GroupingAggregatorFunction.AddInput() { @Override public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, timestampBlock, valueBlock); + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); } @Override public void add(int positionOffset, IntBigArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, timestampBlock, valueBlock); + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); } @Override public void add(int positionOffset, IntVector groupIds) { - addRawInput(positionOffset, groupIds, timestampBlock, valueBlock); + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); } @Override @@ -113,17 +116,17 @@ public void close() { return new GroupingAggregatorFunction.AddInput() { @Override public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, timestampVector, valueVector); + addRawInput(positionOffset, groupIds, valueVector, timestampVector); } @Override public void add(int positionOffset, IntBigArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, timestampVector, valueVector); + addRawInput(positionOffset, groupIds, valueVector, timestampVector); } @Override public void add(int positionOffset, IntVector groupIds) { - addRawInput(positionOffset, groupIds, timestampVector, valueVector); + addRawInput(positionOffset, groupIds, valueVector, timestampVector); } @Override 
@@ -132,40 +135,40 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock timestampBlock, - DoubleBlock valueBlock) { + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock valueBlock, + LongBlock timestampBlock) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; } int valuesPosition = groupPosition + positionOffset; - if (timestampBlock.isNull(valuesPosition)) { + if (valueBlock.isNull(valuesPosition)) { continue; } - if (valueBlock.isNull(valuesPosition)) { + if (timestampBlock.isNull(valuesPosition)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); - int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); - for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { - long timestampValue = timestampBlock.getLong(timestampOffset); - int valueStart = valueBlock.getFirstValueIndex(valuesPosition); - int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); - for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { - double valueValue = valueBlock.getDouble(valueOffset); - DerivDoubleAggregator.combine(state, groupId, timestampValue, valueValue); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + double valueValue = valueBlock.getDouble(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivDoubleAggregator.combine(state, groupId, valueValue, timestampValue); } } } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector timestampVector, - DoubleVector valueVector) { + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector valueVector, + LongVector timestampVector) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -175,9 +178,9 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector ti int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - long timestampValue = timestampVector.getLong(valuesPosition); double valueValue = valueVector.getDouble(valuesPosition); - DerivDoubleAggregator.combine(state, groupId, timestampValue, valueValue); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivDoubleAggregator.combine(state, groupId, valueValue, timestampValue); } } } @@ -186,17 +189,32 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector ti public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); assert channels.size() == intermediateBlockCount(); - Block timestampsUncast = 
page.getBlock(channels.get(0)); - if (timestampsUncast.areAllValuesNull()) { + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { return; } - LongBlock timestamps = (LongBlock) timestampsUncast; - Block valuesUncast = page.getBlock(channels.get(1)); - if (valuesUncast.areAllValuesNull()) { + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { return; } - DoubleBlock values = (DoubleBlock) valuesUncast; - assert timestamps.getPositionCount() == values.getPositionCount(); + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -206,45 +224,45 @@ public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); int valuesPosition = groupPosition + positionOffset; - DerivDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, valuesPosition); + DerivDoubleAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); } } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock timestampBlock, - DoubleBlock valueBlock) { + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock valueBlock, + LongBlock timestampBlock) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; } int valuesPosition = groupPosition + positionOffset; - if (timestampBlock.isNull(valuesPosition)) { + if (valueBlock.isNull(valuesPosition)) { continue; } - if (valueBlock.isNull(valuesPosition)) { + if (timestampBlock.isNull(valuesPosition)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); - int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); - for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { - long timestampValue = timestampBlock.getLong(timestampOffset); - int valueStart = valueBlock.getFirstValueIndex(valuesPosition); - int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); - for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { - double valueValue = 
valueBlock.getDouble(valueOffset); - DerivDoubleAggregator.combine(state, groupId, timestampValue, valueValue); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + double valueValue = valueBlock.getDouble(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivDoubleAggregator.combine(state, groupId, valueValue, timestampValue); } } } } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector timestampVector, - DoubleVector valueVector) { + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector valueVector, + LongVector timestampVector) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -254,9 +272,9 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - long timestampValue = timestampVector.getLong(valuesPosition); double valueValue = valueVector.getDouble(valuesPosition); - DerivDoubleAggregator.combine(state, groupId, timestampValue, valueValue); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivDoubleAggregator.combine(state, groupId, valueValue, timestampValue); } } } @@ -265,17 +283,32 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); assert channels.size() == intermediateBlockCount(); - Block timestampsUncast = page.getBlock(channels.get(0)); - if (timestampsUncast.areAllValuesNull()) { + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { return; } - LongBlock timestamps = (LongBlock) timestampsUncast; - Block valuesUncast = page.getBlock(channels.get(1)); - if (valuesUncast.areAllValuesNull()) { + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { return; } - DoubleBlock values = (DoubleBlock) valuesUncast; - assert timestamps.getPositionCount() == values.getPositionCount(); + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); 
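+        // Each selected position folds its five partial regression sums (count, sum(value), sum(ts),
+        // sum(ts*value), sum(ts^2)) into the group's running state; the final derivative is then
+        // presumably the least-squares slope, roughly
+        //   (count * sumTsVal - sumTs * sumVal) / (count * sumTsSq - sumTs * sumTs),
+        // which SimpleLinearRegressionWithTimeseries.slope() is expected to compute.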
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -285,44 +318,44 @@ public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Pa for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); int valuesPosition = groupPosition + positionOffset; - DerivDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, valuesPosition); + DerivDoubleAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); } } } - private void addRawInput(int positionOffset, IntVector groups, LongBlock timestampBlock, - DoubleBlock valueBlock) { + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock valueBlock, + LongBlock timestampBlock) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int valuesPosition = groupPosition + positionOffset; - if (timestampBlock.isNull(valuesPosition)) { + if (valueBlock.isNull(valuesPosition)) { continue; } - if (valueBlock.isNull(valuesPosition)) { + if (timestampBlock.isNull(valuesPosition)) { continue; } int groupId = groups.getInt(groupPosition); - int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); - int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); - for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { - long timestampValue = timestampBlock.getLong(timestampOffset); - int valueStart = valueBlock.getFirstValueIndex(valuesPosition); - int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); - for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { - double valueValue = valueBlock.getDouble(valueOffset); - DerivDoubleAggregator.combine(state, groupId, timestampValue, valueValue); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + double valueValue = valueBlock.getDouble(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivDoubleAggregator.combine(state, groupId, valueValue, timestampValue); } } } } - private void addRawInput(int positionOffset, IntVector groups, LongVector timestampVector, - DoubleVector valueVector) { + private void addRawInput(int positionOffset, IntVector groups, DoubleVector valueVector, + LongVector timestampVector) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int valuesPosition = groupPosition + positionOffset; int groupId = groups.getInt(groupPosition); - long timestampValue = timestampVector.getLong(valuesPosition); double valueValue = valueVector.getDouble(valuesPosition); - DerivDoubleAggregator.combine(state, groupId, timestampValue, valueValue); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivDoubleAggregator.combine(state, groupId, valueValue, timestampValue); } } @@ -330,30 +363,45 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector 
timest public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); assert channels.size() == intermediateBlockCount(); - Block timestampsUncast = page.getBlock(channels.get(0)); - if (timestampsUncast.areAllValuesNull()) { + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { return; } - LongBlock timestamps = (LongBlock) timestampsUncast; - Block valuesUncast = page.getBlock(channels.get(1)); - if (valuesUncast.areAllValuesNull()) { + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { return; } - DoubleBlock values = (DoubleBlock) valuesUncast; - assert timestamps.getPositionCount() == values.getPositionCount(); + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); int valuesPosition = groupPosition + positionOffset; - DerivDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, valuesPosition); + DerivDoubleAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); } } - private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, LongBlock timestampBlock, - DoubleBlock valueBlock) { - if (timestampBlock.mayHaveNulls()) { + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, DoubleBlock valueBlock, + LongBlock timestampBlock) { + if (valueBlock.mayHaveNulls()) { state.enableGroupIdTracking(seenGroupIds); } - if (valueBlock.mayHaveNulls()) { + if (timestampBlock.mayHaveNulls()) { state.enableGroupIdTracking(seenGroupIds); } } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunction.java new file mode 100644 index 0000000000000..b18118d21c08a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunction.java @@ -0,0 +1,235 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link DerivLongAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class DerivLongAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("count", ElementType.LONG), + new IntermediateStateDesc("sumVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final SimpleLinearRegressionWithTimeseries state; + + private final List channels; + + public DerivLongAggregatorFunction(DriverContext driverContext, List channels, + SimpleLinearRegressionWithTimeseries state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static DerivLongAggregatorFunction create(DriverContext driverContext, + List channels) { + return new DerivLongAggregatorFunction(driverContext, channels, DerivLongAggregator.initSingle(driverContext)); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + LongBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + LongVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + addRawBlock(valueBlock, timestampBlock, mask); + return; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + addRawBlock(valueBlock, timestampBlock, mask); + return; + } + addRawVector(valueVector, timestampVector, mask); + } + + private void addRawInputNotMasked(Page page) { + LongBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + LongVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + addRawBlock(valueBlock, timestampBlock); + return; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + addRawBlock(valueBlock, timestampBlock); + return; + } + addRawVector(valueVector, timestampVector); + } + + private void addRawVector(LongVector valueVector, LongVector timestampVector) { + for (int valuesPosition = 0; valuesPosition < valueVector.getPositionCount(); valuesPosition++) { + long valueValue = valueVector.getLong(valuesPosition); + long 
timestampValue = timestampVector.getLong(valuesPosition); + DerivLongAggregator.combine(state, valueValue, timestampValue); + } + } + + private void addRawVector(LongVector valueVector, LongVector timestampVector, + BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < valueVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + long valueValue = valueVector.getLong(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivLongAggregator.combine(state, valueValue, timestampValue); + } + } + + private void addRawBlock(LongBlock valueBlock, LongBlock timestampBlock) { + for (int p = 0; p < valueBlock.getPositionCount(); p++) { + int valueValueCount = valueBlock.getValueCount(p); + if (valueValueCount == 0) { + continue; + } + int timestampValueCount = timestampBlock.getValueCount(p); + if (timestampValueCount == 0) { + continue; + } + int valueStart = valueBlock.getFirstValueIndex(p); + int valueEnd = valueStart + valueValueCount; + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + long valueValue = valueBlock.getLong(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(p); + int timestampEnd = timestampStart + timestampValueCount; + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivLongAggregator.combine(state, valueValue, timestampValue); + } + } + } + } + + private void addRawBlock(LongBlock valueBlock, LongBlock timestampBlock, BooleanVector mask) { + for (int p = 0; p < valueBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int valueValueCount = valueBlock.getValueCount(p); + if (valueValueCount == 0) { + continue; + } + int timestampValueCount = timestampBlock.getValueCount(p); + if (timestampValueCount == 0) { + continue; + } + int valueStart = valueBlock.getFirstValueIndex(p); + int valueEnd = valueStart + valueValueCount; + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + long valueValue = valueBlock.getLong(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(p); + int timestampEnd = timestampStart + timestampValueCount; + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivLongAggregator.combine(state, valueValue, timestampValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert count.getPositionCount() == 1; + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + assert sumVal.getPositionCount() == 1; + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + assert sumTs.getPositionCount() == 1; + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) 
sumTsValUncast).asVector(); + assert sumTsVal.getPositionCount() == 1; + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert sumTsSq.getPositionCount() == 1; + DerivLongAggregator.combineIntermediate(state, count.getLong(0), sumVal.getDouble(0), sumTs.getLong(0), sumTsVal.getDouble(0), sumTsSq.getLong(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = DerivLongAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..259eb756cb1b2 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunctionSupplier.java @@ -0,0 +1,47 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link DerivLongAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. 
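+ * Builds {@link DerivLongAggregatorFunction} for the non-grouping case and
+ * {@link DerivLongGroupingAggregatorFunction} for the grouping case.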
+ */ +public final class DerivLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + public DerivLongAggregatorFunctionSupplier() { + } + + @Override + public List nonGroupingIntermediateStateDesc() { + return DerivLongAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List groupingIntermediateStateDesc() { + return DerivLongGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public DerivLongAggregatorFunction aggregator(DriverContext driverContext, + List channels) { + return DerivLongAggregatorFunction.create(driverContext, channels); + } + + @Override + public DerivLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List channels) { + return DerivLongGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "deriv of longs"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..cb97638df4ea4 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java @@ -0,0 +1,438 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link DerivLongAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. 
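+ * <p>Each group keeps a running simple linear regression of value against timestamp. The five
+ * intermediate columns are the sufficient statistics for that fit: the sample count, the sum of
+ * values, the sum of timestamps, the sum of timestamp*value products and the sum of squared
+ * timestamps. Partial results are merged by adding those sums; the slope is only computed in
+ * {@code evaluateFinal} via {@code SimpleLinearRegressionWithTimeseries#slope()}, which, given these
+ * sums, should be the standard least-squares estimate (count*sumTsVal - sumTs*sumVal) / (count*sumTsSq - sumTs*sumTs),
+ * scaled by 1000 to report a per-second derivative (assuming millisecond timestamps).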
+ */ +public final class DerivLongGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("count", ElementType.LONG), + new IntermediateStateDesc("sumVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + + private final DerivDoubleAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + public DerivLongGroupingAggregatorFunction(List channels, + DerivDoubleAggregator.GroupingState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static DerivLongGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new DerivLongGroupingAggregatorFunction(channels, DerivLongAggregator.initGrouping(driverContext), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + LongBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + LongVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, valueBlock, timestampBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void close() { + } + }; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, valueBlock, timestampBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, 
LongBlock valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + long valueValue = valueBlock.getLong(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivLongAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + long valueValue = valueVector.getLong(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivLongAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int 
groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + DerivLongAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + long valueValue = valueBlock.getLong(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivLongAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + long valueValue = valueVector.getLong(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivLongAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + 
LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + DerivLongAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongBlock valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + long valueValue = valueBlock.getLong(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivLongAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + long valueValue = valueVector.getLong(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivLongAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if 
(sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + DerivLongAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, LongBlock valueBlock, + LongBlock timestampBlock) { + if (valueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (timestampBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = DerivLongAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java index be04eef30b19b..e44df4d56236b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java @@ -8,8 +8,7 @@ package org.elasticsearch.compute.aggregation; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.DoubleArray; -import org.elasticsearch.common.util.LongArray; +import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; @@ -29,9 +28,7 @@ @IntermediateState(name = "sumTsVal", type = "DOUBLE"), @IntermediateState(name = "sumTsSq", type = "LONG") } ) -@GroupingAggregator( - { @IntermediateState(name = "timestamps", type = "LONG_BLOCK"), @IntermediateState(name = "values", type = "DOUBLE_BLOCK") } -) +@GroupingAggregator class DerivDoubleAggregator { public static SimpleLinearRegressionWithTimeseries initSingle(DriverContext driverContext) { @@ -70,45 +67,32 @@ public static GroupingState initGrouping(DriverContext driverContext) { return new GroupingState(driverContext.bigArrays()); } - public static void combine(GroupingState state, int groupId, long timestamp, double value) { - // TODO + public static void combine(GroupingState state, int groupId, double 
value, long timestamp) { + state.getAndGrow(groupId).add(timestamp, value); } public static void combineIntermediate( GroupingState state, int groupId, - LongBlock timestamps, // stylecheck - DoubleBlock values, - int otherPosition + long count, + double sumVal, + long sumTs, + double sumTsVal, + long sumTsSq ) { - // TODO use groupId - state.collectValue(groupId, timestamps.getLong(otherPosition), values.getDouble(otherPosition)); + combineIntermediate(state.getAndGrow(groupId), count, sumVal, sumTs, sumTsVal, sumTsSq); } public static Block evaluateFinal(GroupingState state, IntVector selectedGroups, GroupingAggregatorEvaluationContext ctx) { - // Block evaluatePercentile(IntVector selected, DriverContext driverContext) { - // try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { - // for (int i = 0; i < selected.getPositionCount(); i++) { - // int si = selected.getInt(i); - // if (si >= digests.size()) { - // builder.appendNull(); - // continue; - // } - // final TDigestState digest = digests.get(si); - // if (percentile != null && digest != null && digest.size() > 0) { - // builder.appendDouble(digest.quantile(percentile / 100)); - // } else { - // builder.appendNull(); - // } - // } - // return builder.build(); - // } - // } try (DoubleBlock.Builder builder = ctx.driverContext().blockFactory().newDoubleBlockBuilder(selectedGroups.getPositionCount())) { for (int i = 0; i < selectedGroups.getPositionCount(); i++) { int groupId = selectedGroups.getInt(i); - // TODO must use groupId - double result = 1.0; + SimpleLinearRegressionWithTimeseries slr = state.get(groupId); + if (slr == null) { + builder.appendNull(); + continue; + } + double result = slr.slope(); if (Double.isNaN(result)) { builder.appendNull(); continue; @@ -120,50 +104,68 @@ public static Block evaluateFinal(GroupingState state, IntVector selectedGroups, } public static final class GroupingState extends AbstractArrayState { - private final BigArrays bigArrays; - private LongArray timestamps; - private DoubleArray values; + private ObjectArray states; GroupingState(BigArrays bigArrays) { super(bigArrays); - this.bigArrays = bigArrays; - this.timestamps = bigArrays.newLongArray(1L); - this.values = bigArrays.newDoubleArray(1L); + states = bigArrays.newObjectArray(1); + } + + SimpleLinearRegressionWithTimeseries get(int groupId) { + if (groupId >= states.size()) { + return null; + } + return states.get(groupId); } - void collectValue(int groupId, long timestamp, double value) { - if (groupId < timestamps.size()) { - timestamps.set(groupId, timestamp); - } else { - timestamps = bigArrays.grow(timestamps, groupId + 1); + SimpleLinearRegressionWithTimeseries getAndGrow(int groupId) { + if (groupId >= states.size()) { + states = bigArrays.grow(states, groupId + 1); } - timestamps.set(groupId, timestamp); - if (groupId < values.size()) { - values.set(groupId, value); - } else { - values = bigArrays.grow(values, groupId + 1); + SimpleLinearRegressionWithTimeseries slr = states.get(groupId); + if (slr == null) { + slr = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, slr); } - values.set(groupId, value); + return slr; } @Override public void close() { - Releasables.close(timestamps, values, super::close); + Releasables.close(states, super::close); } @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { try ( - LongBlock.Builder timestampBuilder = 
driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); - DoubleBlock.Builder valueBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + LongBlock.Builder countBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + DoubleBlock.Builder sumValBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + LongBlock.Builder sumTsBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + DoubleBlock.Builder sumTsValBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + LongBlock.Builder sumTsSqBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()) ) { for (int i = 0; i < selected.getPositionCount(); i++) { int groupId = selected.getInt(i); - timestampBuilder.appendLong(timestamps.get(groupId)); - valueBuilder.appendDouble(values.get(groupId)); + SimpleLinearRegressionWithTimeseries slr = get(groupId); + if (slr == null) { + countBuilder.appendNull(); + sumValBuilder.appendNull(); + sumTsBuilder.appendNull(); + sumTsValBuilder.appendNull(); + sumTsSqBuilder.appendNull(); + } else { + countBuilder.appendLong(slr.count); + sumValBuilder.appendDouble(slr.sumVal); + sumTsBuilder.appendLong(slr.sumTs); + sumTsValBuilder.appendDouble(slr.sumTsVal); + sumTsSqBuilder.appendLong(slr.sumTsSq); + } } - blocks[offset] = timestampBuilder.build(); - blocks[offset + 1] = valueBuilder.build(); + blocks[offset] = countBuilder.build(); + blocks[offset + 1] = sumValBuilder.build(); + blocks[offset + 2] = sumTsBuilder.build(); + blocks[offset + 3] = sumTsValBuilder.build(); + blocks[offset + 4] = sumTsSqBuilder.build(); } } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivLongAggregator.java new file mode 100644 index 0000000000000..b1572be80618c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivLongAggregator.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; + +@Aggregator( + { + @IntermediateState(name = "count", type = "LONG"), + @IntermediateState(name = "sumVal", type = "DOUBLE"), + @IntermediateState(name = "sumTs", type = "LONG"), + @IntermediateState(name = "sumTsVal", type = "DOUBLE"), + @IntermediateState(name = "sumTsSq", type = "LONG") } +) +@GroupingAggregator +class DerivLongAggregator { + + public static SimpleLinearRegressionWithTimeseries initSingle(DriverContext driverContext) { + return new SimpleLinearRegressionWithTimeseries(); + } + + public static void combine(SimpleLinearRegressionWithTimeseries current, long value, long timestamp) { + DerivDoubleAggregator.combine(current, (double) value, timestamp); + } + + public static void combineIntermediate( + SimpleLinearRegressionWithTimeseries state, + long count, + double sumVal, + long sumTs, + double sumTsVal, + long sumTsSq + ) { + DerivDoubleAggregator.combineIntermediate(state, count, sumVal, sumTs, sumTsVal, sumTsSq); + } + + public static Block evaluateFinal(SimpleLinearRegressionWithTimeseries state, DriverContext driverContext) { + return DerivDoubleAggregator.evaluateFinal(state, driverContext); + } + + public static DerivDoubleAggregator.GroupingState initGrouping(DriverContext driverContext) { + return new DerivDoubleAggregator.GroupingState(driverContext.bigArrays()); + } + + public static void combine(DerivDoubleAggregator.GroupingState state, int groupId, long value, long timestamp) { + DerivDoubleAggregator.combine(state.getAndGrow(groupId), (double) value, timestamp); + } + + public static void combineIntermediate( + DerivDoubleAggregator.GroupingState state, + int groupId, + long count, + double sumVal, + long sumTs, + double sumTsVal, + long sumTsSq + ) { + combineIntermediate(state.getAndGrow(groupId), count, sumVal, sumTs, sumTsVal, sumTsSq); + } + + public static Block evaluateFinal( + DerivDoubleAggregator.GroupingState state, + IntVector selectedGroups, + GroupingAggregatorEvaluationContext ctx + ) { + return DerivDoubleAggregator.evaluateFinal(state, selectedGroups, ctx); + } +} diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec index 93da64ccc60f7..afc0ad6681028 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec @@ -723,8 +723,6 @@ mx:integer | tbucket:datetime derivative_of_gauge_metric required_capability: ts_command_v0 -required_capability: TS_LINREG_DERIVATIVE - // tag::deriv[] TS k8s @@ -744,7 +742,7 @@ max_deriv:double | time_bucket:datetime | pod:keyword -2.0E-5 | 2024-05-10T00:10:00.000Z | three // end::deriv-result[] -1.0E-6 | 2024-05-10T00:15:00.000Z | three -0.0 | 2024-05-10T00:20:00.000Z | three +1.9E-5 | 2024-05-10T00:20:00.000Z | three ; @@ -763,8 +761,8 @@ TS k8s max_deriv:double | max_rate:double | time_bucket:datetime | cluster:keyword 0.0855 | 8.120833 | 2024-05-10T00:00:00.000Z | prod 0.004933 | 6.451737 | 2024-05-10T00:05:00.000Z | prod -0.008922 | 11.562738 | 2024-05-10T00:10:00.000Z | prod -0.016623 | 11.860806 | 
2024-05-10T00:15:00.000Z | prod -0.0 | 6.980661 | 2024-05-10T00:20:00.000Z | prod +0.008922 | 11.56274 | 2024-05-10T00:10:00.000Z | prod +0.016623 | 11.86081 | 2024-05-10T00:15:00.000Z | prod +0.009026 | 6.980661 | 2024-05-10T00:20:00.000Z | prod ; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java index e7f32e3a2a154..20723465fc3a3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.DerivDoubleAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.DerivLongAggregatorFunctionSupplier; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; @@ -111,6 +112,7 @@ public AggregatorFunctionSupplier supplier() { final DataType type = field().dataType(); return switch (type) { case DOUBLE -> new DerivDoubleAggregatorFunctionSupplier(); + case LONG, INTEGER -> new DerivLongAggregatorFunctionSupplier(); default -> throw new IllegalArgumentException("Unsupported data type for deriv aggregation: " + type); }; } From f43ddd200bf3abe50332c60ae08ce30be8cc4e2c Mon Sep 17 00:00:00 2001 From: Pablo Date: Mon, 17 Nov 2025 13:22:32 -0800 Subject: [PATCH 21/27] fixup test --- .../qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec index afc0ad6681028..42759729f3845 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec @@ -723,6 +723,7 @@ mx:integer | tbucket:datetime derivative_of_gauge_metric required_capability: ts_command_v0 +required_capability: TS_LINREG_DERIVATIVE // tag::deriv[] TS k8s From 8facb2078e9d22af43d331d6e8c9c33376060107 Mon Sep 17 00:00:00 2001 From: Pablo Date: Mon, 17 Nov 2025 20:09:49 -0800 Subject: [PATCH 22/27] fixup --- .../SimpleLinearRegressionWithTimeseries.java | 2 +- .../main/resources/k8s-timeseries.csv-spec | 22 +++++++++---------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java index c98340f2e4750..9ed3a5bf2b081 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java @@ -55,7 +55,7 @@ void add(long ts, double val) { if (denominator == 0) { return Double.NaN; } - return numerator / denominator; + return numerator / denominator * 1000.0; // per second } double intercept() { diff --git 
a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec index 42759729f3845..81e0779730da0 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec @@ -723,7 +723,6 @@ mx:integer | tbucket:datetime derivative_of_gauge_metric required_capability: ts_command_v0 -required_capability: TS_LINREG_DERIVATIVE // tag::deriv[] TS k8s @@ -738,18 +737,17 @@ TS k8s // tag::deriv-result[] max_deriv:double | time_bucket:datetime | pod:keyword -9.3E-5 | 2024-05-10T00:00:00.000Z | three -3.8E-5 | 2024-05-10T00:05:00.000Z | three --2.0E-5 | 2024-05-10T00:10:00.000Z | three +0.092774 | 2024-05-10T00:00:00.000Z | three +0.038026 | 2024-05-10T00:05:00.000Z | three +-0.0197 | 2024-05-10T00:10:00.000Z | three // end::deriv-result[] --1.0E-6 | 2024-05-10T00:15:00.000Z | three -1.9E-5 | 2024-05-10T00:20:00.000Z | three +-0.001209 | 2024-05-10T00:15:00.000Z | three +0.019397 | 2024-05-10T00:20:00.000Z | three ; derivative_compared_to_rate required_capability: ts_command_v0 -required_capability: TS_LINREG_DERIVATIVE TS k8s | STATS max_deriv = max(deriv(to_long(network.total_bytes_in))), max_rate = max(rate(network.total_bytes_in)) BY time_bucket = bucket(@timestamp,5minute), cluster @@ -760,10 +758,10 @@ TS k8s ; max_deriv:double | max_rate:double | time_bucket:datetime | cluster:keyword -0.0855 | 8.120833 | 2024-05-10T00:00:00.000Z | prod -0.004933 | 6.451737 | 2024-05-10T00:05:00.000Z | prod -0.008922 | 11.56274 | 2024-05-10T00:10:00.000Z | prod -0.016623 | 11.86081 | 2024-05-10T00:15:00.000Z | prod -0.009026 | 6.980661 | 2024-05-10T00:20:00.000Z | prod +85.5 | 8.120833 | 2024-05-10T00:00:00.000Z | prod +4.933168 | 6.451737 | 2024-05-10T00:05:00.000Z | prod +8.922491 | 11.56274 | 2024-05-10T00:10:00.000Z | prod +16.62316 | 11.86081 | 2024-05-10T00:15:00.000Z | prod +9.026268 | 6.980661 | 2024-05-10T00:20:00.000Z | prod ; From b8f5f361540924a14a29afe4db96d14fffaecf6f Mon Sep 17 00:00:00 2001 From: Pablo Date: Tue, 18 Nov 2025 10:20:58 -0800 Subject: [PATCH 23/27] fixcup --- .../qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec | 2 ++ 1 file changed, 2 insertions(+) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec index 81e0779730da0..a09c1602aa101 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec @@ -723,6 +723,7 @@ mx:integer | tbucket:datetime derivative_of_gauge_metric required_capability: ts_command_v0 +required_capability: TS_LINREG_DERIVATIVE // tag::deriv[] TS k8s @@ -748,6 +749,7 @@ max_deriv:double | time_bucket:datetime | pod:keyword derivative_compared_to_rate required_capability: ts_command_v0 +required_capability: TS_LINREG_DERIVATIVE TS k8s | STATS max_deriv = max(deriv(to_long(network.total_bytes_in))), max_rate = max(rate(network.total_bytes_in)) BY time_bucket = bucket(@timestamp,5minute), cluster From 7a5bf026e97d569c892099d90b92422b6ae84668 Mon Sep 17 00:00:00 2001 From: Pablo Date: Tue, 18 Nov 2025 11:23:15 -0800 Subject: [PATCH 24/27] fix docs --- .../esql/_snippets/functions/examples/deriv.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md b/docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md index 328d13f76598a..dc90ed8d0ce2b 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md +++ b/docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md @@ -10,8 +10,8 @@ TS k8s | max_deriv:double | time_bucket:datetime | pod:keyword | | --- | --- | --- | -| 9.3E-5 | 2024-05-10T00:00:00.000Z | three | -| 3.8E-5 | 2024-05-10T00:05:00.000Z | three | -| -2.0E-5 | 2024-05-10T00:10:00.000Z | three | +| 0.092774 | 2024-05-10T00:00:00.000Z | three | +| 0.038026 | 2024-05-10T00:05:00.000Z | three | +| -0.0197 | 2024-05-10T00:10:00.000Z | three | From 2d2ed8986f4d9a396931afca1facd08c5385377b Mon Sep 17 00:00:00 2001 From: Pablo Date: Wed, 19 Nov 2025 19:54:00 -0800 Subject: [PATCH 25/27] comments --- .../DerivIntAggregatorFunction.java | 236 ++++++++++ .../DerivIntAggregatorFunctionSupplier.java | 47 ++ .../DerivIntGroupingAggregatorFunction.java | 439 ++++++++++++++++++ .../aggregation/DerivIntAggregator.java | 78 ++++ .../GroupingAggregatorFunction.java | 4 - .../function/EsqlFunctionRegistry.java | 2 +- .../expression/function/aggregate/Deriv.java | 12 +- 7 files changed, 805 insertions(+), 13 deletions(-) create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunctionSupplier.java create mode 100644 x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivIntAggregator.java diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunction.java new file mode 100644 index 0000000000000..1c936d5de9135 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunction.java @@ -0,0 +1,236 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link DerivIntAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. 
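+ * <p>This mirrors {@code DerivLongAggregatorFunction}: int values are fed into the same
+ * {@code SimpleLinearRegressionWithTimeseries} state (which accumulates doubles), and the
+ * intermediate layout (count, sumVal, sumTs, sumTsVal, sumTsSq) is identical; only the raw-input
+ * readers work on {@code IntBlock}/{@code IntVector}.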
+ */ +public final class DerivIntAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("count", ElementType.LONG), + new IntermediateStateDesc("sumVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final SimpleLinearRegressionWithTimeseries state; + + private final List channels; + + public DerivIntAggregatorFunction(DriverContext driverContext, List channels, + SimpleLinearRegressionWithTimeseries state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static DerivIntAggregatorFunction create(DriverContext driverContext, + List channels) { + return new DerivIntAggregatorFunction(driverContext, channels, DerivIntAggregator.initSingle(driverContext)); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + IntBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + IntVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + addRawBlock(valueBlock, timestampBlock, mask); + return; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + addRawBlock(valueBlock, timestampBlock, mask); + return; + } + addRawVector(valueVector, timestampVector, mask); + } + + private void addRawInputNotMasked(Page page) { + IntBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + IntVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + addRawBlock(valueBlock, timestampBlock); + return; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + addRawBlock(valueBlock, timestampBlock); + return; + } + addRawVector(valueVector, timestampVector); + } + + private void addRawVector(IntVector valueVector, LongVector timestampVector) { + for (int valuesPosition = 0; valuesPosition < valueVector.getPositionCount(); valuesPosition++) { + int valueValue = valueVector.getInt(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivIntAggregator.combine(state, valueValue, timestampValue); + } + } + + private void addRawVector(IntVector valueVector, LongVector timestampVector, BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < valueVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + int valueValue = valueVector.getInt(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivIntAggregator.combine(state, valueValue, timestampValue); + } + } + + private void addRawBlock(IntBlock valueBlock, LongBlock timestampBlock) { + for (int p = 0; p < valueBlock.getPositionCount(); p++) { + int valueValueCount = valueBlock.getValueCount(p); + if (valueValueCount == 0) { + continue; + 
} + int timestampValueCount = timestampBlock.getValueCount(p); + if (timestampValueCount == 0) { + continue; + } + int valueStart = valueBlock.getFirstValueIndex(p); + int valueEnd = valueStart + valueValueCount; + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + int valueValue = valueBlock.getInt(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(p); + int timestampEnd = timestampStart + timestampValueCount; + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivIntAggregator.combine(state, valueValue, timestampValue); + } + } + } + } + + private void addRawBlock(IntBlock valueBlock, LongBlock timestampBlock, BooleanVector mask) { + for (int p = 0; p < valueBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int valueValueCount = valueBlock.getValueCount(p); + if (valueValueCount == 0) { + continue; + } + int timestampValueCount = timestampBlock.getValueCount(p); + if (timestampValueCount == 0) { + continue; + } + int valueStart = valueBlock.getFirstValueIndex(p); + int valueEnd = valueStart + valueValueCount; + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + int valueValue = valueBlock.getInt(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(p); + int timestampEnd = timestampStart + timestampValueCount; + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivIntAggregator.combine(state, valueValue, timestampValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert count.getPositionCount() == 1; + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + assert sumVal.getPositionCount() == 1; + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + assert sumTs.getPositionCount() == 1; + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + assert sumTsVal.getPositionCount() == 1; + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert sumTsSq.getPositionCount() == 1; + DerivIntAggregator.combineIntermediate(state, count.getLong(0), sumVal.getDouble(0), sumTs.getLong(0), sumTsVal.getDouble(0), sumTsSq.getLong(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = DerivIntAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb 
= new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..ecd4e4bf8dbd8 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunctionSupplier.java @@ -0,0 +1,47 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link DerivIntAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class DerivIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + public DerivIntAggregatorFunctionSupplier() { + } + + @Override + public List nonGroupingIntermediateStateDesc() { + return DerivIntAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List groupingIntermediateStateDesc() { + return DerivIntGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public DerivIntAggregatorFunction aggregator(DriverContext driverContext, + List channels) { + return DerivIntAggregatorFunction.create(driverContext, channels); + } + + @Override + public DerivIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List channels) { + return DerivIntGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "deriv of ints"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..ab1e3d602be25 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java @@ -0,0 +1,439 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
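+// Apart from reading IntBlock/IntVector raw input, this generated class follows the same shape as
+// DerivLongGroupingAggregatorFunction: both accumulate into DerivDoubleAggregator.GroupingState and
+// merge the same five intermediate sums.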
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link DerivIntAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class DerivIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("count", ElementType.LONG), + new IntermediateStateDesc("sumVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + + private final DerivDoubleAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + public DerivIntGroupingAggregatorFunction(List channels, + DerivDoubleAggregator.GroupingState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static DerivIntGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new DerivIntGroupingAggregatorFunction(channels, DerivIntAggregator.initGrouping(driverContext), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + IntBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + IntVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, valueBlock, timestampBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void close() { + } + }; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, valueBlock, timestampBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + 
public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + int valueValue = valueBlock.getInt(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivIntAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valueValue = valueVector.getInt(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivIntAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = 
page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + DerivIntAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + int valueValue = valueBlock.getInt(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivIntAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valueValue = valueVector.getInt(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivIntAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, 
IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + DerivIntAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntBlock valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + int valueValue = valueBlock.getInt(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivIntAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + int valueValue = valueVector.getInt(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivIntAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + + @Override + public void 
addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + DerivIntAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, IntBlock valueBlock, + LongBlock timestampBlock) { + if (valueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (timestampBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = DerivIntAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivIntAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivIntAggregator.java new file mode 100644 index 0000000000000..0b1dcc912ce45 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivIntAggregator.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; + +@Aggregator( + { + @IntermediateState(name = "count", type = "LONG"), + @IntermediateState(name = "sumVal", type = "DOUBLE"), + @IntermediateState(name = "sumTs", type = "LONG"), + @IntermediateState(name = "sumTsVal", type = "DOUBLE"), + @IntermediateState(name = "sumTsSq", type = "LONG") } +) +@GroupingAggregator +class DerivIntAggregator { + + public static SimpleLinearRegressionWithTimeseries initSingle(DriverContext driverContext) { + return new SimpleLinearRegressionWithTimeseries(); + } + + public static void combine(SimpleLinearRegressionWithTimeseries current, int value, long timestamp) { + DerivDoubleAggregator.combine(current, value, timestamp); + } + + public static void combineIntermediate( + SimpleLinearRegressionWithTimeseries state, + long count, + double sumVal, + long sumTs, + double sumTsVal, + long sumTsSq + ) { + DerivDoubleAggregator.combineIntermediate(state, count, sumVal, sumTs, sumTsVal, sumTsSq); + } + + public static Block evaluateFinal(SimpleLinearRegressionWithTimeseries state, DriverContext driverContext) { + return DerivDoubleAggregator.evaluateFinal(state, driverContext); + } + + public static DerivDoubleAggregator.GroupingState initGrouping(DriverContext driverContext) { + return new DerivDoubleAggregator.GroupingState(driverContext.bigArrays()); + } + + public static void combine(DerivDoubleAggregator.GroupingState state, int groupId, int value, long timestamp) { + DerivDoubleAggregator.combine(state.getAndGrow(groupId), value, timestamp); + } + + public static void combineIntermediate( + DerivDoubleAggregator.GroupingState state, + int groupId, + long count, + double sumVal, + long sumTs, + double sumTsVal, + long sumTsSq + ) { + combineIntermediate(state.getAndGrow(groupId), count, sumVal, sumTs, sumTsVal, sumTsSq); + } + + public static Block evaluateFinal( + DerivDoubleAggregator.GroupingState state, + IntVector selectedGroups, + GroupingAggregatorEvaluationContext ctx + ) { + return DerivDoubleAggregator.evaluateFinal(state, selectedGroups, ctx); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java index 3f8088be1800a..319b32c6a1fc8 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java @@ -142,10 +142,6 @@ default void add(int positionOffset, IntBlock groupIds) { * Build the intermediate results for this aggregation. * @param selected the groupIds that have been selected to be included in * the results. Always ascending. - * - *
- * This function may be called in the coordinator node after all intermediate
- * blocks have been gathered from the data nodes, or on data nodes during
- * node-level or cluster-level reduction with intermediate input to intermediate output.
*/ void evaluateIntermediate(Block[] blocks, int offset, IntVector selected); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 7f8b02cf5972b..6ab85f7fe4b37 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -537,7 +537,7 @@ private static FunctionDefinition[][] functions() { defTS(Idelta.class, bi(Idelta::new), "idelta"), defTS(Delta.class, bi(Delta::new), "delta"), defTS(Increase.class, bi(Increase::new), "increase"), - def(Deriv.class, uni(Deriv::new), "deriv"), + defTS(Deriv.class, bi(Deriv::new), "deriv"), def(MaxOverTime.class, uni(MaxOverTime::new), "max_over_time"), def(MinOverTime.class, uni(MinOverTime::new), "min_over_time"), def(SumOverTime.class, uni(SumOverTime::new), "sum_over_time"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java index 20723465fc3a3..dca951ccd5446 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java @@ -14,7 +14,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; -import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -43,11 +42,7 @@ public class Deriv extends TimeSeriesAggregateFunction implements ToAggregator { description = "Calculates the derivative over time of a numeric field using linear regression.", examples = { @Example(file = "k8s-timeseries", tag = "deriv") } ) - public Deriv(Source source, @Param(name = "field", type = { "long", "integer", "double" }) Expression field) { - this(source, field, new UnresolvedAttribute(source, "@timestamp")); - } - - public Deriv(Source source, Expression field, Expression timestamp) { + public Deriv(Source source, @Param(name = "field", type = { "long", "integer", "double" }) Expression field, Expression timestamp) { this(source, field, Literal.TRUE, NO_WINDOW, timestamp); } @@ -73,7 +68,7 @@ public AggregateFunction perTimeSeriesAggregation() { @Override public AggregateFunction withFilter(Expression filter) { - return new Deriv(source(), field(), filter, timestamp, window()); + return new Deriv(source(), field(), filter, window(), timestamp); } @Override @@ -112,7 +107,8 @@ public AggregatorFunctionSupplier supplier() { final DataType type = field().dataType(); return switch (type) { case DOUBLE -> new DerivDoubleAggregatorFunctionSupplier(); - case LONG, INTEGER -> new DerivLongAggregatorFunctionSupplier(); + case LONG -> new DerivLongAggregatorFunctionSupplier(); + case INTEGER -> new DerivLongAggregatorFunctionSupplier(); default -> throw new IllegalArgumentException("Unsupported data type for deriv aggregation: " + type); }; } From 837a1b63214f06884ad0ccbcc051b93c8b3be81d 
Mon Sep 17 00:00:00 2001 From: Pablo Date: Wed, 19 Nov 2025 20:02:33 -0800 Subject: [PATCH 26/27] fix --- .../xpack/esql/expression/function/aggregate/Deriv.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java index dca951ccd5446..2974a92781d96 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.DerivDoubleAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.DerivIntAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.DerivLongAggregatorFunctionSupplier; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; @@ -108,7 +109,7 @@ public AggregatorFunctionSupplier supplier() { return switch (type) { case DOUBLE -> new DerivDoubleAggregatorFunctionSupplier(); case LONG -> new DerivLongAggregatorFunctionSupplier(); - case INTEGER -> new DerivLongAggregatorFunctionSupplier(); + case INTEGER -> new DerivIntAggregatorFunctionSupplier(); default -> throw new IllegalArgumentException("Unsupported data type for deriv aggregation: " + type); }; } From 8b5328df6411c89b4bc3eb092bd3ff40af14148a Mon Sep 17 00:00:00 2001 From: Pablo Date: Wed, 19 Nov 2025 20:16:58 -0800 Subject: [PATCH 27/27] fixup --- .../xpack/esql/expression/function/aggregate/Deriv.java | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java index 2974a92781d96..431f9e1c4a056 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.FunctionType; import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.function.TimestampAware; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.planner.ToAggregator; @@ -33,7 +34,7 @@ /** * Calculates the derivative over time of a numeric field using linear regression. */ -public class Deriv extends TimeSeriesAggregateFunction implements ToAggregator { +public class Deriv extends TimeSeriesAggregateFunction implements ToAggregator, TimestampAware { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Deriv", Deriv::new); private final Expression timestamp; @@ -62,6 +63,11 @@ private Deriv(org.elasticsearch.common.io.stream.StreamInput in) throws java.io. ); } + @Override + public Expression timestamp() { + return timestamp; + } + @Override public AggregateFunction perTimeSeriesAggregation() { return this;