diff --git a/docs/reference/query-languages/esql/_snippets/functions/description/deriv.md b/docs/reference/query-languages/esql/_snippets/functions/description/deriv.md new file mode 100644 index 0000000000000..c845d3cbbf141 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/description/deriv.md @@ -0,0 +1,6 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Description** + +Calculates the derivative over time of a numeric field using linear regression. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md b/docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md new file mode 100644 index 0000000000000..dc90ed8d0ce2b --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/examples/deriv.md @@ -0,0 +1,17 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Example** + +```esql +TS k8s +| WHERE pod == "three" +| STATS max_deriv = MAX(DERIV(network.cost)) BY time_bucket = BUCKET(@timestamp,5minute), pod +``` + +| max_deriv:double | time_bucket:datetime | pod:keyword | +| --- | --- | --- | +| 0.092774 | 2024-05-10T00:00:00.000Z | three | +| 0.038026 | 2024-05-10T00:05:00.000Z | three | +| -0.0197 | 2024-05-10T00:10:00.000Z | three | + + diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/deriv.md b/docs/reference/query-languages/esql/_snippets/functions/layout/deriv.md new file mode 100644 index 0000000000000..09036fe8314d1 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/deriv.md @@ -0,0 +1,23 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +## `DERIV` [esql-deriv] + +**Syntax** + +:::{image} ../../../images/functions/deriv.svg +:alt: Embedded +:class: text-center +::: + + +:::{include} ../parameters/deriv.md +::: + +:::{include} ../description/deriv.md +::: + +:::{include} ../types/deriv.md +::: + +:::{include} ../examples/deriv.md +::: diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/deriv.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/deriv.md new file mode 100644 index 0000000000000..24fedc1dde506 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/deriv.md @@ -0,0 +1,7 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Parameters** + +`field` +: + diff --git a/docs/reference/query-languages/esql/_snippets/functions/types/deriv.md b/docs/reference/query-languages/esql/_snippets/functions/types/deriv.md new file mode 100644 index 0000000000000..566f6b3d786f6 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/types/deriv.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. 
+ +**Supported types** + +| field | result | +| --- | --- | +| double | double | +| integer | double | +| long | double | + diff --git a/docs/reference/query-languages/esql/_snippets/lists/time-series-aggregation-functions.md b/docs/reference/query-languages/esql/_snippets/lists/time-series-aggregation-functions.md index 67d40be53178c..fda0403b7e428 100644 --- a/docs/reference/query-languages/esql/_snippets/lists/time-series-aggregation-functions.md +++ b/docs/reference/query-languages/esql/_snippets/lists/time-series-aggregation-functions.md @@ -15,3 +15,4 @@ * [`STDDEV_OVER_TIME`](../../functions-operators/time-series-aggregation-functions.md#esql-stddev_over_time) {applies_to}`stack: preview 9.3` {applies_to}`serverless: preview` * [`VARIANCE_OVER_TIME`](../../functions-operators/time-series-aggregation-functions.md#esql-variance_over_time) {applies_to}`stack: preview 9.3` {applies_to}`serverless: preview` * [`SUM_OVER_TIME`](../../functions-operators/time-series-aggregation-functions.md#esql-sum_over_time) {applies_to}`stack: preview 9.2` {applies_to}`serverless: preview` +* [`DERIV`](../../functions-operators/time-series-aggregation-functions.md#esql-deriv) {applies_to}`stack: preview 9.3` {applies_to}`serverless: preview` diff --git a/docs/reference/query-languages/esql/functions-operators/time-series-aggregation-functions.md b/docs/reference/query-languages/esql/functions-operators/time-series-aggregation-functions.md index 141ec5706de3f..8fa531b633803 100644 --- a/docs/reference/query-languages/esql/functions-operators/time-series-aggregation-functions.md +++ b/docs/reference/query-languages/esql/functions-operators/time-series-aggregation-functions.md @@ -63,3 +63,6 @@ supports the following time series aggregation functions: :::{include} ../_snippets/functions/layout/sum_over_time.md ::: + +:::{include} ../_snippets/functions/layout/deriv.md +::: diff --git a/docs/reference/query-languages/esql/images/functions/deriv.svg b/docs/reference/query-languages/esql/images/functions/deriv.svg new file mode 100644 index 0000000000000..01a1a6ad2db4a --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/deriv.svg @@ -0,0 +1 @@ +DERIV(field) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/deriv.json b/docs/reference/query-languages/esql/kibana/definition/functions/deriv.json new file mode 100644 index 0000000000000..f84745f4de37a --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/deriv.json @@ -0,0 +1,49 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. 
See ../README.md for how to regenerate it.", + "type" : "time_series_agg", + "name" : "deriv", + "description" : "Calculates the derivative over time of a numeric field using linear regression.", + "signatures" : [ + { + "params" : [ + { + "name" : "field", + "type" : "double", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "integer", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "long", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + } + ], + "examples" : [ + "TS k8s\n| WHERE pod == \"three\"\n| STATS max_deriv = MAX(DERIV(network.cost)) BY time_bucket = BUCKET(@timestamp,5minute), pod" + ], + "preview" : false, + "snapshot_only" : false +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/deriv.md b/docs/reference/query-languages/esql/kibana/docs/functions/deriv.md new file mode 100644 index 0000000000000..fb17b716b03ea --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/deriv.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +### DERIV +Calculates the derivative over time of a numeric field using linear regression. + +```esql +TS k8s +| WHERE pod == "three" +| STATS max_deriv = MAX(DERIV(network.cost)) BY time_bucket = BUCKET(@timestamp,5minute), pod +``` diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunction.java new file mode 100644 index 0000000000000..6698c7284d352 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunction.java @@ -0,0 +1,235 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link DerivDoubleAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. 
+ */ +public final class DerivDoubleAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("count", ElementType.LONG), + new IntermediateStateDesc("sumVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final SimpleLinearRegressionWithTimeseries state; + + private final List channels; + + public DerivDoubleAggregatorFunction(DriverContext driverContext, List channels, + SimpleLinearRegressionWithTimeseries state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static DerivDoubleAggregatorFunction create(DriverContext driverContext, + List channels) { + return new DerivDoubleAggregatorFunction(driverContext, channels, DerivDoubleAggregator.initSingle(driverContext)); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + DoubleBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + DoubleVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + addRawBlock(valueBlock, timestampBlock, mask); + return; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + addRawBlock(valueBlock, timestampBlock, mask); + return; + } + addRawVector(valueVector, timestampVector, mask); + } + + private void addRawInputNotMasked(Page page) { + DoubleBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + DoubleVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + addRawBlock(valueBlock, timestampBlock); + return; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + addRawBlock(valueBlock, timestampBlock); + return; + } + addRawVector(valueVector, timestampVector); + } + + private void addRawVector(DoubleVector valueVector, LongVector timestampVector) { + for (int valuesPosition = 0; valuesPosition < valueVector.getPositionCount(); valuesPosition++) { + double valueValue = valueVector.getDouble(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivDoubleAggregator.combine(state, valueValue, timestampValue); + } + } + + private void addRawVector(DoubleVector valueVector, LongVector timestampVector, + BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < valueVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + double valueValue = valueVector.getDouble(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivDoubleAggregator.combine(state, valueValue, timestampValue); + } + } + + private void addRawBlock(DoubleBlock valueBlock, LongBlock timestampBlock) { + for (int p = 0; p < valueBlock.getPositionCount(); p++) { + int valueValueCount = 
valueBlock.getValueCount(p); + if (valueValueCount == 0) { + continue; + } + int timestampValueCount = timestampBlock.getValueCount(p); + if (timestampValueCount == 0) { + continue; + } + int valueStart = valueBlock.getFirstValueIndex(p); + int valueEnd = valueStart + valueValueCount; + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + double valueValue = valueBlock.getDouble(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(p); + int timestampEnd = timestampStart + timestampValueCount; + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivDoubleAggregator.combine(state, valueValue, timestampValue); + } + } + } + } + + private void addRawBlock(DoubleBlock valueBlock, LongBlock timestampBlock, BooleanVector mask) { + for (int p = 0; p < valueBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int valueValueCount = valueBlock.getValueCount(p); + if (valueValueCount == 0) { + continue; + } + int timestampValueCount = timestampBlock.getValueCount(p); + if (timestampValueCount == 0) { + continue; + } + int valueStart = valueBlock.getFirstValueIndex(p); + int valueEnd = valueStart + valueValueCount; + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + double valueValue = valueBlock.getDouble(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(p); + int timestampEnd = timestampStart + timestampValueCount; + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivDoubleAggregator.combine(state, valueValue, timestampValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert count.getPositionCount() == 1; + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + assert sumVal.getPositionCount() == 1; + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + assert sumTs.getPositionCount() == 1; + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + assert sumTsVal.getPositionCount() == 1; + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert sumTsSq.getPositionCount() == 1; + DerivDoubleAggregator.combineIntermediate(state, count.getLong(0), sumVal.getDouble(0), sumTs.getLong(0), sumTsVal.getDouble(0), sumTsSq.getLong(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = 
DerivDoubleAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..440c03cd403f8 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunctionSupplier.java @@ -0,0 +1,47 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link DerivDoubleAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class DerivDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + public DerivDoubleAggregatorFunctionSupplier() { + } + + @Override + public List nonGroupingIntermediateStateDesc() { + return DerivDoubleAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List groupingIntermediateStateDesc() { + return DerivDoubleGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public DerivDoubleAggregatorFunction aggregator(DriverContext driverContext, + List channels) { + return DerivDoubleAggregatorFunction.create(driverContext, channels); + } + + @Override + public DerivDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List channels) { + return DerivDoubleGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "deriv of doubles"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..adf77eff37fa8 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java @@ -0,0 +1,438 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link DerivDoubleAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class DerivDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("count", ElementType.LONG), + new IntermediateStateDesc("sumVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + + private final DerivDoubleAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + public DerivDoubleGroupingAggregatorFunction(List channels, + DerivDoubleAggregator.GroupingState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static DerivDoubleGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new DerivDoubleGroupingAggregatorFunction(channels, DerivDoubleAggregator.initGrouping(driverContext), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + DoubleBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + DoubleVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, valueBlock, timestampBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void close() { + } + }; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, valueBlock, timestampBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int 
positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + double valueValue = valueBlock.getDouble(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivDoubleAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + double valueValue = valueVector.getDouble(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivDoubleAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block 
sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + DerivDoubleAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + double valueValue = valueBlock.getDouble(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivDoubleAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + double valueValue = valueVector.getDouble(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivDoubleAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + + @Override + public void 
addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + DerivDoubleAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + double valueValue = valueBlock.getDouble(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivDoubleAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + double valueValue = valueVector.getDouble(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivDoubleAggregator.combine(state, 
groupId, valueValue, timestampValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + DerivDoubleAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, DoubleBlock valueBlock, + LongBlock timestampBlock) { + if (valueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (timestampBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = DerivDoubleAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunction.java new file mode 100644 index 0000000000000..1c936d5de9135 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunction.java @@ -0,0 +1,236 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. 
Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link DerivIntAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class DerivIntAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("count", ElementType.LONG), + new IntermediateStateDesc("sumVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final SimpleLinearRegressionWithTimeseries state; + + private final List channels; + + public DerivIntAggregatorFunction(DriverContext driverContext, List channels, + SimpleLinearRegressionWithTimeseries state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static DerivIntAggregatorFunction create(DriverContext driverContext, + List channels) { + return new DerivIntAggregatorFunction(driverContext, channels, DerivIntAggregator.initSingle(driverContext)); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + IntBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + IntVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + addRawBlock(valueBlock, timestampBlock, mask); + return; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + addRawBlock(valueBlock, timestampBlock, mask); + return; + } + addRawVector(valueVector, timestampVector, mask); + } + + private void addRawInputNotMasked(Page page) { + IntBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + IntVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + addRawBlock(valueBlock, timestampBlock); + return; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + addRawBlock(valueBlock, timestampBlock); + return; + } + addRawVector(valueVector, timestampVector); + } + + private void addRawVector(IntVector 
valueVector, LongVector timestampVector) { + for (int valuesPosition = 0; valuesPosition < valueVector.getPositionCount(); valuesPosition++) { + int valueValue = valueVector.getInt(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivIntAggregator.combine(state, valueValue, timestampValue); + } + } + + private void addRawVector(IntVector valueVector, LongVector timestampVector, BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < valueVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + int valueValue = valueVector.getInt(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivIntAggregator.combine(state, valueValue, timestampValue); + } + } + + private void addRawBlock(IntBlock valueBlock, LongBlock timestampBlock) { + for (int p = 0; p < valueBlock.getPositionCount(); p++) { + int valueValueCount = valueBlock.getValueCount(p); + if (valueValueCount == 0) { + continue; + } + int timestampValueCount = timestampBlock.getValueCount(p); + if (timestampValueCount == 0) { + continue; + } + int valueStart = valueBlock.getFirstValueIndex(p); + int valueEnd = valueStart + valueValueCount; + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + int valueValue = valueBlock.getInt(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(p); + int timestampEnd = timestampStart + timestampValueCount; + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivIntAggregator.combine(state, valueValue, timestampValue); + } + } + } + } + + private void addRawBlock(IntBlock valueBlock, LongBlock timestampBlock, BooleanVector mask) { + for (int p = 0; p < valueBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int valueValueCount = valueBlock.getValueCount(p); + if (valueValueCount == 0) { + continue; + } + int timestampValueCount = timestampBlock.getValueCount(p); + if (timestampValueCount == 0) { + continue; + } + int valueStart = valueBlock.getFirstValueIndex(p); + int valueEnd = valueStart + valueValueCount; + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + int valueValue = valueBlock.getInt(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(p); + int timestampEnd = timestampStart + timestampValueCount; + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivIntAggregator.combine(state, valueValue, timestampValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert count.getPositionCount() == 1; + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + assert sumVal.getPositionCount() == 1; + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + assert 
sumTs.getPositionCount() == 1; + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + assert sumTsVal.getPositionCount() == 1; + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert sumTsSq.getPositionCount() == 1; + DerivIntAggregator.combineIntermediate(state, count.getLong(0), sumVal.getDouble(0), sumTs.getLong(0), sumTsVal.getDouble(0), sumTsSq.getLong(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = DerivIntAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..ecd4e4bf8dbd8 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunctionSupplier.java @@ -0,0 +1,47 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link DerivIntAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. 
+ */ +public final class DerivIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + public DerivIntAggregatorFunctionSupplier() { + } + + @Override + public List nonGroupingIntermediateStateDesc() { + return DerivIntAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List groupingIntermediateStateDesc() { + return DerivIntGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public DerivIntAggregatorFunction aggregator(DriverContext driverContext, + List channels) { + return DerivIntAggregatorFunction.create(driverContext, channels); + } + + @Override + public DerivIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List channels) { + return DerivIntGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "deriv of ints"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..ab1e3d602be25 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java @@ -0,0 +1,439 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link DerivIntAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. 
+ */ +public final class DerivIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("count", ElementType.LONG), + new IntermediateStateDesc("sumVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + + private final DerivDoubleAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + public DerivIntGroupingAggregatorFunction(List channels, + DerivDoubleAggregator.GroupingState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static DerivIntGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new DerivIntGroupingAggregatorFunction(channels, DerivIntAggregator.initGrouping(driverContext), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + IntBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + IntVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, valueBlock, timestampBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void close() { + } + }; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, valueBlock, timestampBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock 
valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + int valueValue = valueBlock.getInt(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivIntAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valueValue = valueVector.getInt(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivIntAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart 
+ groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + DerivIntAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + int valueValue = valueBlock.getInt(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivIntAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valueValue = valueVector.getInt(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivIntAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) 
sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + DerivIntAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntBlock valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + int valueValue = valueBlock.getInt(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivIntAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + int valueValue = valueVector.getInt(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivIntAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + 
LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + DerivIntAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, IntBlock valueBlock, + LongBlock timestampBlock) { + if (valueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (timestampBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = DerivIntAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunction.java new file mode 100644 index 0000000000000..b18118d21c08a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunction.java @@ -0,0 +1,235 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link DerivLongAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. 
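+ * The intermediate state holds the running sums needed for a least-squares fit over time:
+ * {@code count}, {@code sumVal}, {@code sumTs}, {@code sumTsVal} and {@code sumTsSq}.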
+ */ +public final class DerivLongAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("count", ElementType.LONG), + new IntermediateStateDesc("sumVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final SimpleLinearRegressionWithTimeseries state; + + private final List channels; + + public DerivLongAggregatorFunction(DriverContext driverContext, List channels, + SimpleLinearRegressionWithTimeseries state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static DerivLongAggregatorFunction create(DriverContext driverContext, + List channels) { + return new DerivLongAggregatorFunction(driverContext, channels, DerivLongAggregator.initSingle(driverContext)); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + LongBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + LongVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + addRawBlock(valueBlock, timestampBlock, mask); + return; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + addRawBlock(valueBlock, timestampBlock, mask); + return; + } + addRawVector(valueVector, timestampVector, mask); + } + + private void addRawInputNotMasked(Page page) { + LongBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + LongVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + addRawBlock(valueBlock, timestampBlock); + return; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + addRawBlock(valueBlock, timestampBlock); + return; + } + addRawVector(valueVector, timestampVector); + } + + private void addRawVector(LongVector valueVector, LongVector timestampVector) { + for (int valuesPosition = 0; valuesPosition < valueVector.getPositionCount(); valuesPosition++) { + long valueValue = valueVector.getLong(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivLongAggregator.combine(state, valueValue, timestampValue); + } + } + + private void addRawVector(LongVector valueVector, LongVector timestampVector, + BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < valueVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + long valueValue = valueVector.getLong(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivLongAggregator.combine(state, valueValue, timestampValue); + } + } + + private void addRawBlock(LongBlock valueBlock, LongBlock timestampBlock) { + for (int p = 0; p < valueBlock.getPositionCount(); p++) { + int valueValueCount = valueBlock.getValueCount(p); + if (valueValueCount 
== 0) { + continue; + } + int timestampValueCount = timestampBlock.getValueCount(p); + if (timestampValueCount == 0) { + continue; + } + int valueStart = valueBlock.getFirstValueIndex(p); + int valueEnd = valueStart + valueValueCount; + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + long valueValue = valueBlock.getLong(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(p); + int timestampEnd = timestampStart + timestampValueCount; + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivLongAggregator.combine(state, valueValue, timestampValue); + } + } + } + } + + private void addRawBlock(LongBlock valueBlock, LongBlock timestampBlock, BooleanVector mask) { + for (int p = 0; p < valueBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int valueValueCount = valueBlock.getValueCount(p); + if (valueValueCount == 0) { + continue; + } + int timestampValueCount = timestampBlock.getValueCount(p); + if (timestampValueCount == 0) { + continue; + } + int valueStart = valueBlock.getFirstValueIndex(p); + int valueEnd = valueStart + valueValueCount; + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + long valueValue = valueBlock.getLong(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(p); + int timestampEnd = timestampStart + timestampValueCount; + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivLongAggregator.combine(state, valueValue, timestampValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert count.getPositionCount() == 1; + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + assert sumVal.getPositionCount() == 1; + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + assert sumTs.getPositionCount() == 1; + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + assert sumTsVal.getPositionCount() == 1; + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert sumTsSq.getPositionCount() == 1; + DerivLongAggregator.combineIntermediate(state, count.getLong(0), sumVal.getDouble(0), sumTs.getLong(0), sumTsVal.getDouble(0), sumTsSq.getLong(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = DerivLongAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String 
toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..259eb756cb1b2 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunctionSupplier.java @@ -0,0 +1,47 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link DerivLongAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class DerivLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + public DerivLongAggregatorFunctionSupplier() { + } + + @Override + public List nonGroupingIntermediateStateDesc() { + return DerivLongAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List groupingIntermediateStateDesc() { + return DerivLongGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public DerivLongAggregatorFunction aggregator(DriverContext driverContext, + List channels) { + return DerivLongAggregatorFunction.create(driverContext, channels); + } + + @Override + public DerivLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List channels) { + return DerivLongGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "deriv of longs"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..cb97638df4ea4 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java @@ -0,0 +1,438 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link DerivLongAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class DerivLongGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("count", ElementType.LONG), + new IntermediateStateDesc("sumVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTs", ElementType.LONG), + new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), + new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + + private final DerivDoubleAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + public DerivLongGroupingAggregatorFunction(List channels, + DerivDoubleAggregator.GroupingState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static DerivLongGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new DerivLongGroupingAggregatorFunction(channels, DerivLongAggregator.initGrouping(driverContext), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + LongBlock valueBlock = page.getBlock(channels.get(0)); + LongBlock timestampBlock = page.getBlock(channels.get(1)); + LongVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, valueBlock, timestampBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void close() { + } + }; + } + LongVector timestampVector = timestampBlock.asVector(); + if (timestampVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, valueBlock, timestampBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, 
IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueBlock, timestampBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valueVector, timestampVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + long valueValue = valueBlock.getLong(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivLongAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + long valueValue = valueVector.getLong(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivLongAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = 
page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + DerivLongAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + long valueValue = valueBlock.getLong(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivLongAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + long valueValue = valueVector.getLong(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivLongAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, 
IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + DerivLongAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongBlock valueBlock, + LongBlock timestampBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (valueBlock.isNull(valuesPosition)) { + continue; + } + if (timestampBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int valueStart = valueBlock.getFirstValueIndex(valuesPosition); + int valueEnd = valueStart + valueBlock.getValueCount(valuesPosition); + for (int valueOffset = valueStart; valueOffset < valueEnd; valueOffset++) { + long valueValue = valueBlock.getLong(valueOffset); + int timestampStart = timestampBlock.getFirstValueIndex(valuesPosition); + int timestampEnd = timestampStart + timestampBlock.getValueCount(valuesPosition); + for (int timestampOffset = timestampStart; timestampOffset < timestampEnd; timestampOffset++) { + long timestampValue = timestampBlock.getLong(timestampOffset); + DerivLongAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongVector valueVector, + LongVector timestampVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + long valueValue = valueVector.getLong(valuesPosition); + long timestampValue = timestampVector.getLong(valuesPosition); + DerivLongAggregator.combine(state, groupId, valueValue, timestampValue); + } + } + + @Override + 
public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block countUncast = page.getBlock(channels.get(0)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + Block sumValUncast = page.getBlock(channels.get(1)); + if (sumValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumVal = ((DoubleBlock) sumValUncast).asVector(); + Block sumTsUncast = page.getBlock(channels.get(2)); + if (sumTsUncast.areAllValuesNull()) { + return; + } + LongVector sumTs = ((LongBlock) sumTsUncast).asVector(); + Block sumTsValUncast = page.getBlock(channels.get(3)); + if (sumTsValUncast.areAllValuesNull()) { + return; + } + DoubleVector sumTsVal = ((DoubleBlock) sumTsValUncast).asVector(); + Block sumTsSqUncast = page.getBlock(channels.get(4)); + if (sumTsSqUncast.areAllValuesNull()) { + return; + } + LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + DerivLongAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, LongBlock valueBlock, + LongBlock timestampBlock) { + if (valueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (timestampBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = DerivLongAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java new file mode 100644 index 0000000000000..e44df4d56236b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java @@ -0,0 +1,172 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +@Aggregator( + { + @IntermediateState(name = "count", type = "LONG"), + @IntermediateState(name = "sumVal", type = "DOUBLE"), + @IntermediateState(name = "sumTs", type = "LONG"), + @IntermediateState(name = "sumTsVal", type = "DOUBLE"), + @IntermediateState(name = "sumTsSq", type = "LONG") } +) +@GroupingAggregator +class DerivDoubleAggregator { + + public static SimpleLinearRegressionWithTimeseries initSingle(DriverContext driverContext) { + return new SimpleLinearRegressionWithTimeseries(); + } + + public static void combine(SimpleLinearRegressionWithTimeseries current, double value, long timestamp) { + current.add(timestamp, value); + } + + public static void combineIntermediate( + SimpleLinearRegressionWithTimeseries state, + long count, + double sumVal, + long sumTs, + double sumTsVal, + long sumTsSq + ) { + state.count += count; + state.sumVal += sumVal; + state.sumTs += sumTs; + state.sumTsVal += sumTsVal; + state.sumTsSq += sumTsSq; + } + + public static Block evaluateFinal(SimpleLinearRegressionWithTimeseries state, DriverContext driverContext) { + BlockFactory blockFactory = driverContext.blockFactory(); + var slope = state.slope(); + if (Double.isNaN(slope)) { + return blockFactory.newConstantNullBlock(1); + } + return blockFactory.newConstantDoubleBlockWith(slope, 1); + } + + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext.bigArrays()); + } + + public static void combine(GroupingState state, int groupId, double value, long timestamp) { + state.getAndGrow(groupId).add(timestamp, value); + } + + public static void combineIntermediate( + GroupingState state, + int groupId, + long count, + double sumVal, + long sumTs, + double sumTsVal, + long sumTsSq + ) { + combineIntermediate(state.getAndGrow(groupId), count, sumVal, sumTs, sumTsVal, sumTsSq); + } + + public static Block evaluateFinal(GroupingState state, IntVector selectedGroups, GroupingAggregatorEvaluationContext ctx) { + try (DoubleBlock.Builder builder = ctx.driverContext().blockFactory().newDoubleBlockBuilder(selectedGroups.getPositionCount())) { + for (int i = 0; i < selectedGroups.getPositionCount(); i++) { + int groupId = selectedGroups.getInt(i); + SimpleLinearRegressionWithTimeseries slr = state.get(groupId); + if (slr == null) { + builder.appendNull(); + continue; + } + double result = slr.slope(); + if (Double.isNaN(result)) { + builder.appendNull(); + continue; + } + builder.appendDouble(result); + } + return builder.build(); + } + } + + public static final class GroupingState extends AbstractArrayState { + private ObjectArray states; + + GroupingState(BigArrays bigArrays) { + super(bigArrays); + states = bigArrays.newObjectArray(1); + } + + SimpleLinearRegressionWithTimeseries get(int groupId) { + if (groupId >= states.size()) { + return null; + } + return states.get(groupId); + } + + 
SimpleLinearRegressionWithTimeseries getAndGrow(int groupId) { + if (groupId >= states.size()) { + states = bigArrays.grow(states, groupId + 1); + } + SimpleLinearRegressionWithTimeseries slr = states.get(groupId); + if (slr == null) { + slr = new SimpleLinearRegressionWithTimeseries(); + states.set(groupId, slr); + } + return slr; + } + + @Override + public void close() { + Releasables.close(states, super::close); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + try ( + LongBlock.Builder countBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + DoubleBlock.Builder sumValBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + LongBlock.Builder sumTsBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + DoubleBlock.Builder sumTsValBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); + LongBlock.Builder sumTsSqBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()) + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int groupId = selected.getInt(i); + SimpleLinearRegressionWithTimeseries slr = get(groupId); + if (slr == null) { + countBuilder.appendNull(); + sumValBuilder.appendNull(); + sumTsBuilder.appendNull(); + sumTsValBuilder.appendNull(); + sumTsSqBuilder.appendNull(); + } else { + countBuilder.appendLong(slr.count); + sumValBuilder.appendDouble(slr.sumVal); + sumTsBuilder.appendLong(slr.sumTs); + sumTsValBuilder.appendDouble(slr.sumTsVal); + sumTsSqBuilder.appendLong(slr.sumTsSq); + } + } + blocks[offset] = countBuilder.build(); + blocks[offset + 1] = sumValBuilder.build(); + blocks[offset + 2] = sumTsBuilder.build(); + blocks[offset + 3] = sumTsValBuilder.build(); + blocks[offset + 4] = sumTsSqBuilder.build(); + } + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivIntAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivIntAggregator.java new file mode 100644 index 0000000000000..0b1dcc912ce45 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivIntAggregator.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; + +@Aggregator( + { + @IntermediateState(name = "count", type = "LONG"), + @IntermediateState(name = "sumVal", type = "DOUBLE"), + @IntermediateState(name = "sumTs", type = "LONG"), + @IntermediateState(name = "sumTsVal", type = "DOUBLE"), + @IntermediateState(name = "sumTsSq", type = "LONG") } +) +@GroupingAggregator +class DerivIntAggregator { + + public static SimpleLinearRegressionWithTimeseries initSingle(DriverContext driverContext) { + return new SimpleLinearRegressionWithTimeseries(); + } + + public static void combine(SimpleLinearRegressionWithTimeseries current, int value, long timestamp) { + DerivDoubleAggregator.combine(current, value, timestamp); + } + + public static void combineIntermediate( + SimpleLinearRegressionWithTimeseries state, + long count, + double sumVal, + long sumTs, + double sumTsVal, + long sumTsSq + ) { + DerivDoubleAggregator.combineIntermediate(state, count, sumVal, sumTs, sumTsVal, sumTsSq); + } + + public static Block evaluateFinal(SimpleLinearRegressionWithTimeseries state, DriverContext driverContext) { + return DerivDoubleAggregator.evaluateFinal(state, driverContext); + } + + public static DerivDoubleAggregator.GroupingState initGrouping(DriverContext driverContext) { + return new DerivDoubleAggregator.GroupingState(driverContext.bigArrays()); + } + + public static void combine(DerivDoubleAggregator.GroupingState state, int groupId, int value, long timestamp) { + DerivDoubleAggregator.combine(state.getAndGrow(groupId), value, timestamp); + } + + public static void combineIntermediate( + DerivDoubleAggregator.GroupingState state, + int groupId, + long count, + double sumVal, + long sumTs, + double sumTsVal, + long sumTsSq + ) { + combineIntermediate(state.getAndGrow(groupId), count, sumVal, sumTs, sumTsVal, sumTsSq); + } + + public static Block evaluateFinal( + DerivDoubleAggregator.GroupingState state, + IntVector selectedGroups, + GroupingAggregatorEvaluationContext ctx + ) { + return DerivDoubleAggregator.evaluateFinal(state, selectedGroups, ctx); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivLongAggregator.java new file mode 100644 index 0000000000000..b1572be80618c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivLongAggregator.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; + +@Aggregator( + { + @IntermediateState(name = "count", type = "LONG"), + @IntermediateState(name = "sumVal", type = "DOUBLE"), + @IntermediateState(name = "sumTs", type = "LONG"), + @IntermediateState(name = "sumTsVal", type = "DOUBLE"), + @IntermediateState(name = "sumTsSq", type = "LONG") } +) +@GroupingAggregator +class DerivLongAggregator { + + public static SimpleLinearRegressionWithTimeseries initSingle(DriverContext driverContext) { + return new SimpleLinearRegressionWithTimeseries(); + } + + public static void combine(SimpleLinearRegressionWithTimeseries current, long value, long timestamp) { + DerivDoubleAggregator.combine(current, (double) value, timestamp); + } + + public static void combineIntermediate( + SimpleLinearRegressionWithTimeseries state, + long count, + double sumVal, + long sumTs, + double sumTsVal, + long sumTsSq + ) { + DerivDoubleAggregator.combineIntermediate(state, count, sumVal, sumTs, sumTsVal, sumTsSq); + } + + public static Block evaluateFinal(SimpleLinearRegressionWithTimeseries state, DriverContext driverContext) { + return DerivDoubleAggregator.evaluateFinal(state, driverContext); + } + + public static DerivDoubleAggregator.GroupingState initGrouping(DriverContext driverContext) { + return new DerivDoubleAggregator.GroupingState(driverContext.bigArrays()); + } + + public static void combine(DerivDoubleAggregator.GroupingState state, int groupId, long value, long timestamp) { + DerivDoubleAggregator.combine(state.getAndGrow(groupId), (double) value, timestamp); + } + + public static void combineIntermediate( + DerivDoubleAggregator.GroupingState state, + int groupId, + long count, + double sumVal, + long sumTs, + double sumTsVal, + long sumTsSq + ) { + combineIntermediate(state.getAndGrow(groupId), count, sumVal, sumTs, sumTsVal, sumTsSq); + } + + public static Block evaluateFinal( + DerivDoubleAggregator.GroupingState state, + IntVector selectedGroups, + GroupingAggregatorEvaluationContext ctx + ) { + return DerivDoubleAggregator.evaluateFinal(state, selectedGroups, ctx); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java index a60bcb1523ffc..319b32c6a1fc8 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java @@ -149,6 +149,10 @@ default void add(int positionOffset, IntBlock groupIds) { * Build the final results for this aggregation. * @param selected the groupIds that have been selected to be included in * the results. Always ascending. + * + *

This function is called on the coordinator node after all intermediate + * results have been gathered from the worker nodes and aggregated into + * intermediate blocks.

*/ void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java new file mode 100644 index 0000000000000..9ed3a5bf2b081 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.operator.DriverContext; + +class SimpleLinearRegressionWithTimeseries implements AggregatorState { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset + 0] = driverContext.blockFactory().newConstantLongBlockWith(count, 1); + blocks[offset + 1] = driverContext.blockFactory().newConstantDoubleBlockWith(sumVal, 1); + blocks[offset + 2] = driverContext.blockFactory().newConstantLongBlockWith(sumTs, 1); + blocks[offset + 3] = driverContext.blockFactory().newConstantDoubleBlockWith(sumTsVal, 1); + blocks[offset + 4] = driverContext.blockFactory().newConstantLongBlockWith(sumTsSq, 1); + } + + @Override + public void close() { + + } + + long count; + double sumVal; + long sumTs; + double sumTsVal; + long sumTsSq; + + SimpleLinearRegressionWithTimeseries() { + this.count = 0; + this.sumVal = 0.0; + this.sumTs = 0; + this.sumTsVal = 0.0; + this.sumTsSq = 0; + } + + void add(long ts, double val) { + count++; + sumVal += val; + sumTs += ts; + sumTsVal += ts * val; + sumTsSq += ts * ts; + } + + double slope() { + if (count <= 1) { + return Double.NaN; + } + double numerator = count * sumTsVal - sumTs * sumVal; + double denominator = count * sumTsSq - sumTs * sumTs; + if (denominator == 0) { + return Double.NaN; + } + return numerator / denominator * 1000.0; // per second + } + + double intercept() { + if (count == 0) { + return 0.0; // or handle as needed + } + var slp = slope(); + if (Double.isNaN(slp)) { + return Double.NaN; + } + return (sumVal - slp * sumTs) / count; + } + +} diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec index 64cdae20e5635..a09c1602aa101 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec @@ -594,6 +594,7 @@ max_cost:integer | pod:keyword | time_bucket:datetime | max_cluster_cost:int 973 | two | 2024-05-10T00:00:00.000Z | 1209 | staging ; + max_of_stddev_over_time required_capability: ts_command_v0 required_capability: variance_stddev_over_time @@ -720,4 +721,49 @@ mx:integer | tbucket:datetime 1716 | 2024-05-10T00:00:00.000Z ; +derivative_of_gauge_metric +required_capability: ts_command_v0 +required_capability: TS_LINREG_DERIVATIVE + +// tag::deriv[] +TS k8s +| WHERE pod == "three" +| STATS max_deriv = MAX(DERIV(network.cost)) BY time_bucket = BUCKET(@timestamp,5minute), pod +// end::deriv[] +| EVAL 
max_deriv = ROUND(max_deriv,6) +| KEEP max_deriv, time_bucket, pod +| SORT pod, time_bucket +| LIMIT 5 +; + +// tag::deriv-result[] +max_deriv:double | time_bucket:datetime | pod:keyword +0.092774 | 2024-05-10T00:00:00.000Z | three +0.038026 | 2024-05-10T00:05:00.000Z | three +-0.0197 | 2024-05-10T00:10:00.000Z | three +// end::deriv-result[] +-0.001209 | 2024-05-10T00:15:00.000Z | three +0.019397 | 2024-05-10T00:20:00.000Z | three + +; + +derivative_compared_to_rate +required_capability: ts_command_v0 +required_capability: TS_LINREG_DERIVATIVE + +TS k8s +| STATS max_deriv = max(deriv(to_long(network.total_bytes_in))), max_rate = max(rate(network.total_bytes_in)) BY time_bucket = bucket(@timestamp,5minute), cluster +| EVAL max_deriv = ROUND(max_deriv,6), max_rate = ROUND(max_rate,6) +| KEEP max_deriv, max_rate, time_bucket, cluster +| SORT cluster, time_bucket +| LIMIT 5 +; + +max_deriv:double | max_rate:double | time_bucket:datetime | cluster:keyword +85.5 | 8.120833 | 2024-05-10T00:00:00.000Z | prod +4.933168 | 6.451737 | 2024-05-10T00:05:00.000Z | prod +8.922491 | 11.56274 | 2024-05-10T00:10:00.000Z | prod +16.62316 | 11.86081 | 2024-05-10T00:15:00.000Z | prod +9.026268 | 6.980661 | 2024-05-10T00:20:00.000Z | prod +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 6497ddfc6afbf..c7851f8b723c9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1498,6 +1498,7 @@ public enum Cap { */ PERCENTILE_OVER_TIME, VARIANCE_STDDEV_OVER_TIME, + TS_LINREG_DERIVATIVE, /** * INLINE STATS fix incorrect prunning of null filtering * https://github.com/elastic/elasticsearch/pull/135011 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index f79c3b1e154ce..efe8e17497a38 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -30,6 +30,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinctOverTime; import org.elasticsearch.xpack.esql.expression.function.aggregate.CountOverTime; import org.elasticsearch.xpack.esql.expression.function.aggregate.Delta; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Deriv; import org.elasticsearch.xpack.esql.expression.function.aggregate.First; import org.elasticsearch.xpack.esql.expression.function.aggregate.FirstOverTime; import org.elasticsearch.xpack.esql.expression.function.aggregate.Idelta; @@ -537,6 +538,7 @@ private static FunctionDefinition[][] functions() { defTS(Idelta.class, bi(Idelta::new), "idelta"), defTS(Delta.class, bi(Delta::new), "delta"), defTS(Increase.class, bi(Increase::new), "increase"), + defTS(Deriv.class, bi(Deriv::new), "deriv"), def(MaxOverTime.class, uni(MaxOverTime::new), "max_over_time"), def(MinOverTime.class, uni(MinOverTime::new), "min_over_time"), def(SumOverTime.class, uni(SumOverTime::new), "sum_over_time"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java index 9867539dd88a4..4827ae08ff7ab 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java @@ -30,6 +30,7 @@ public static List getNamedWriteables() { Idelta.ENTRY, Increase.ENTRY, Delta.ENTRY, + Deriv.ENTRY, Sample.ENTRY, SpatialCentroid.ENTRY, SpatialExtent.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java new file mode 100644 index 0000000000000..431f9e1c4a056 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java @@ -0,0 +1,122 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.DerivDoubleAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.DerivIntAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.DerivLongAggregatorFunctionSupplier; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.FunctionType; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.function.TimestampAware; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.planner.ToAggregator; + +import java.util.List; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; +import static org.elasticsearch.xpack.esql.core.type.DataType.AGGREGATE_METRIC_DOUBLE; + +/** + * Calculates the derivative over time of a numeric field using linear regression. 
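+ * The derivative is estimated as the slope of a simple least-squares linear regression of the
+ * field's values against their timestamps, scaled to a per-second rate.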
+ */
+public class Deriv extends TimeSeriesAggregateFunction implements ToAggregator, TimestampAware {
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Deriv", Deriv::new);
+    private final Expression timestamp;
+
+    @FunctionInfo(
+        type = FunctionType.TIME_SERIES_AGGREGATE,
+        returnType = { "double" },
+        description = "Calculates the derivative over time of a numeric field using linear regression.",
+        examples = { @Example(file = "k8s-timeseries", tag = "deriv") }
+    )
+    public Deriv(Source source, @Param(name = "field", type = { "long", "integer", "double" }) Expression field, Expression timestamp) {
+        this(source, field, Literal.TRUE, NO_WINDOW, timestamp);
+    }
+
+    public Deriv(Source source, Expression field, Expression filter, Expression window, Expression timestamp) {
+        super(source, field, filter, window, List.of(timestamp));
+        this.timestamp = timestamp;
+    }
+
+    private Deriv(org.elasticsearch.common.io.stream.StreamInput in) throws java.io.IOException {
+        this(
+            Source.readFrom((PlanStreamInput) in),
+            in.readNamedWriteable(Expression.class),
+            in.readNamedWriteable(Expression.class),
+            in.readNamedWriteable(Expression.class),
+            in.readNamedWriteableCollectionAsList(Expression.class).getFirst()
+        );
+    }
+
+    @Override
+    public Expression timestamp() {
+        return timestamp;
+    }
+
+    @Override
+    public AggregateFunction perTimeSeriesAggregation() {
+        return this;
+    }
+
+    @Override
+    public AggregateFunction withFilter(Expression filter) {
+        return new Deriv(source(), field(), filter, window(), timestamp);
+    }
+
+    @Override
+    public DataType dataType() {
+        return DataType.DOUBLE;
+    }
+
+    @Override
+    public TypeResolution resolveType() {
+        return TypeResolutions.isType(
+            field(),
+            dt -> dt.isNumeric() && dt != AGGREGATE_METRIC_DOUBLE,
+            sourceText(),
+            DEFAULT,
+            "numeric except counter types"
+        );
+    }
+
+    @Override
+    public Expression replaceChildren(List<Expression> newChildren) {
+        return new Deriv(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), newChildren.get(3));
+    }
+
+    @Override
+    protected NodeInfo<Deriv> info() {
+        return NodeInfo.create(this, Deriv::new, field(), filter(), window(), timestamp);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
+    @Override
+    public AggregatorFunctionSupplier supplier() {
+        final DataType type = field().dataType();
+        return switch (type) {
+            case DOUBLE -> new DerivDoubleAggregatorFunctionSupplier();
+            case LONG -> new DerivLongAggregatorFunctionSupplier();
+            case INTEGER -> new DerivIntAggregatorFunctionSupplier();
+            default -> throw new IllegalArgumentException("Unsupported data type for deriv aggregation: " + type);
+        };
+    }
+}
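The class above only routes each input type to a per-type aggregator supplier; the math lives in the `Deriv*AggregatorFunctionSupplier` implementations, which are not part of this hunk. As a rough, non-authoritative sketch of what "derivative over time using linear regression" ordinarily means, the snippet below computes an ordinary least-squares slope of value against `@timestamp`. The class and method names are hypothetical (not from this change), and the per-second scaling of millisecond timestamps is an assumption suggested by the `* 1000` factor in the test expectations below, not taken from the aggregator code.

```java
// Hypothetical illustration only -- not the aggregator's actual implementation.
final class DerivSketch {
    /** Ordinary least-squares slope of value against @timestamp, scaled from per-millisecond to per-second. */
    static Double slopePerSecond(long[] timestampsMillis, double[] values) {
        int n = values.length;
        if (n < 2) {
            return null; // mirrors the tests below, which expect a null result for fewer than two samples
        }
        double meanT = 0, meanV = 0;
        for (int i = 0; i < n; i++) {
            meanT += timestampsMillis[i];
            meanV += values[i];
        }
        meanT /= n;
        meanV /= n;
        double covTV = 0, varT = 0;
        for (int i = 0; i < n; i++) {
            double dt = timestampsMillis[i] - meanT;
            covTV += dt * (values[i] - meanV);
            varT += dt * dt;
        }
        if (varT == 0) {
            return null; // all samples share one timestamp, so the slope is undefined
        }
        return covTV / varT * 1000.0; // per-second scaling is an assumption, see note above
    }
}
```

A least-squares slope weights every sample in the window, so it reacts differently to noisy series than a first-to-last rate, which is presumably why the `derivative_compared_to_rate` spec above shows different values for `max_deriv` and `max_rate` over the same data.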
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/DerivTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/DerivTests.java
new file mode 100644
index 0000000000000..2c20c20a20d8f
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/DerivTests.java
@@ -0,0 +1,139 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.aggregate;
+
+import com.carrotsearch.randomizedtesting.annotations.Name;
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
+import org.elasticsearch.xpack.esql.expression.function.DocsV3Support;
+import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+import org.hamcrest.Matcher;
+import org.hamcrest.Matchers;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.function.Supplier;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+
+public class DerivTests extends AbstractFunctionTestCase {
+    public DerivTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
+        this.testCase = testCaseSupplier.get();
+    }
+
+    @ParametersFactory
+    public static Iterable<Object[]> parameters() {
+        var suppliers = new ArrayList<TestCaseSupplier>();
+
+        var valuesSuppliers = List.of(
+            MultiRowTestCaseSupplier.longCases(1, 1000, 0, 1000_000_000, true),
+            MultiRowTestCaseSupplier.intCases(1, 1000, 0, 1000_000_000, true),
+            MultiRowTestCaseSupplier.doubleCases(1, 1000, 0, 1000_000_000, true)
+
+        );
+        for (List<TestCaseSupplier.TypedDataSupplier> valuesSupplier : valuesSuppliers) {
+            for (TestCaseSupplier.TypedDataSupplier fieldSupplier : valuesSupplier) {
+                TestCaseSupplier testCaseSupplier = makeSupplier(fieldSupplier);
+                suppliers.add(testCaseSupplier);
+            }
+        }
+        List<Object[]> parameters = new ArrayList<>(suppliers.size());
+        for (TestCaseSupplier supplier : suppliers) {
+            parameters.add(new Object[] { supplier });
+        }
+        return parameters;
+    }
+
+    @Override
+    protected Expression build(Source source, List<Expression> args) {
+        return new Deriv(source, args.get(0), args.get(1), args.get(2), args.get(3));
+    }
+
+    @SuppressWarnings("unchecked")
+    private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) {
+        DataType type = fieldSupplier.type();
+        return new TestCaseSupplier(fieldSupplier.name(), List.of(type, DataType.DATETIME, DataType.INTEGER, DataType.LONG), () -> {
+            TestCaseSupplier.TypedData fieldTypedData = fieldSupplier.get();
+            List<Object> dataRows = fieldTypedData.multiRowData();
+            if (randomBoolean()) {
+                List<Object> withNulls = new ArrayList<>(dataRows.size());
+                for (Object dataRow : dataRows) {
+                    if (randomBoolean()) {
+                        withNulls.add(null);
+                    } else {
+                        withNulls.add(dataRow);
+                    }
+                }
+                dataRows = withNulls;
+            }
+            fieldTypedData = TestCaseSupplier.TypedData.multiRow(dataRows, type, fieldTypedData.name());
+            List<Long> timestamps = new ArrayList<>();
+            List<Integer> slices = new ArrayList<>();
+            List<Long> maxTimestamps = new ArrayList<>();
+            long lastTimestamp = randomLongBetween(0, 1_000_000);
+            for (int row = 0; row < dataRows.size(); row++) {
+                lastTimestamp += randomLongBetween(1, 10_000);
+                timestamps.add(lastTimestamp);
+                slices.add(0);
+                maxTimestamps.add(Long.MAX_VALUE);
+            }
+            TestCaseSupplier.TypedData timestampsField = TestCaseSupplier.TypedData.multiRow(
+                timestamps.reversed(),
+                DataType.DATETIME,
+                "timestamps"
+            );
+            TestCaseSupplier.TypedData sliceIndexType = TestCaseSupplier.TypedData.multiRow(slices, DataType.INTEGER, "_slice_index");
+            TestCaseSupplier.TypedData nextTimestampType = TestCaseSupplier.TypedData.multiRow(
+                maxTimestamps,
+                DataType.LONG,
+                "_max_timestamp"
+            );
+
+            List<Object> nonNullDataRows = dataRows.stream().filter(Objects::nonNull).toList();
+            Matcher<?> matcher;
+            if (nonNullDataRows.size() < 2) {
+                matcher = Matchers.nullValue();
+            } else {
+                var lastValue = ((Number) nonNullDataRows.getFirst()).doubleValue();
+                var secondLastValue = ((Number) nonNullDataRows.get(1)).doubleValue();
+                var increase = lastValue >= secondLastValue ? lastValue - secondLastValue : lastValue;
+                var largestTimestamp = timestamps.get(0);
+                var secondLargestTimestamp = timestamps.get(1);
+                var smallestTimestamp = timestamps.getLast();
+                matcher = Matchers.allOf(
+                    Matchers.greaterThanOrEqualTo(increase / (largestTimestamp - smallestTimestamp) * 1000 * 0.9),
+                    Matchers.lessThanOrEqualTo(
+                        increase / (largestTimestamp - secondLargestTimestamp) * (largestTimestamp - smallestTimestamp) * 1000
+                    )
+                );
+            }
+
+            return new TestCaseSupplier.TestCase(
+                List.of(fieldTypedData, timestampsField, sliceIndexType, nextTimestampType),
+                Matchers.stringContainsInOrder("GroupingAggregator", "Deriv", "GroupingAggregatorFunction"),
+                DataType.DOUBLE,
+                matcher
+            );
+        });
+    }
+
+    public static List<DocsV3Support.Param> signatureTypes(List<DocsV3Support.Param> params) {
+        assertThat(params, hasSize(4));
+        assertThat(params.get(1).dataType(), equalTo(DataType.DATETIME));
+        assertThat(params.get(2).dataType(), equalTo(DataType.INTEGER));
+        assertThat(params.get(3).dataType(), equalTo(DataType.LONG));
+        return List.of(params.get(0));
+    }
+}