diff --git a/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java b/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java index 794baf1759204..8096153459003 100644 --- a/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java +++ b/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java @@ -58,4 +58,8 @@ */ Class[] warnExceptions() default {}; + /** + * If {@code true} then the @timestamp LongVector will be appended to the input blocks of the aggregation function. + */ + boolean includeTimestamps() default false; } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java index a07796882d46e..8a02f8bc4c69a 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java @@ -54,6 +54,8 @@ import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC; import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC; import static org.elasticsearch.compute.gen.Types.LIST_INTEGER; +import static org.elasticsearch.compute.gen.Types.LONG_BLOCK; +import static org.elasticsearch.compute.gen.Types.LONG_VECTOR; import static org.elasticsearch.compute.gen.Types.PAGE; import static org.elasticsearch.compute.gen.Types.WARNINGS; import static org.elasticsearch.compute.gen.Types.blockType; @@ -73,9 +75,10 @@ public class AggregatorImplementer { private final List warnExceptions; private final ExecutableElement init; private final ExecutableElement combine; + private final List createParameters; private final ClassName implementation; private final List intermediateState; - private final List createParameters; + private final boolean includeTimestampVector; private final AggregationState aggState; private final AggregationParameter aggParam; @@ -84,7 +87,8 @@ public AggregatorImplementer( Elements elements, TypeElement declarationType, IntermediateState[] interStateAnno, - List warnExceptions + List warnExceptions, + boolean includeTimestampVector ) { this.declarationType = declarationType; this.warnExceptions = warnExceptions; @@ -102,10 +106,10 @@ public AggregatorImplementer( declarationType, aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(), requireName("combine"), - requireArgs(requireType(aggState.declaredType()), requireAnyType("")) + combineArgs(aggState, includeTimestampVector) ); // TODO support multiple parameters - this.aggParam = AggregationParameter.create(combine.getParameters().get(1).asType()); + this.aggParam = AggregationParameter.create(combine.getParameters().getLast().asType()); this.createParameters = init.getParameters() .stream() @@ -117,7 +121,20 @@ public AggregatorImplementer( elements.getPackageOf(declarationType).toString(), (declarationType.getSimpleName() + "AggregatorFunction").replace("AggregatorAggregator", "Aggregator") ); - intermediateState = Arrays.stream(interStateAnno).map(IntermediateStateDesc::newIntermediateStateDesc).toList(); + this.intermediateState = Arrays.stream(interStateAnno).map(IntermediateStateDesc::newIntermediateStateDesc).toList(); + this.includeTimestampVector = includeTimestampVector; + } + + private static Methods.ArgumentMatcher combineArgs(AggregationState aggState, boolean includeTimestampVector) { + if (includeTimestampVector) { + return requireArgs( + requireType(aggState.declaredType()), + requireType(TypeName.LONG), // @timestamp + requireAnyType("") + ); + } else { + return requireArgs(requireType(aggState.declaredType()), requireAnyType("")); + } } ClassName implementation() { @@ -295,10 +312,18 @@ private MethodSpec addRawInput() { builder.addComment("No masking"); builder.addStatement("$T block = page.getBlock(channels.get(0))", blockType(aggParam.type())); builder.addStatement("$T vector = block.asVector()", vectorType(aggParam.type())); + if (includeTimestampVector) { + builder.addStatement("$T timestampsBlock = page.getBlock(channels.get(1))", LONG_BLOCK); + builder.addStatement("$T timestampsVector = timestampsBlock.asVector()", LONG_VECTOR); + + builder.beginControlFlow("if (timestampsVector == null) "); + builder.addStatement("throw new IllegalStateException($S)", "expected @timestamp vector; but got a block"); + builder.endControlFlow(); + } builder.beginControlFlow("if (vector != null)"); - builder.addStatement("addRawVector(vector)"); + builder.addStatement(includeTimestampVector ? "addRawVector(vector, timestampsVector)" : "addRawVector(vector)"); builder.nextControlFlow("else"); - builder.addStatement("addRawBlock(block)"); + builder.addStatement(includeTimestampVector ? "addRawBlock(block, timestampsVector)" : "addRawBlock(block)"); builder.endControlFlow(); builder.addStatement("return"); } @@ -307,10 +332,18 @@ private MethodSpec addRawInput() { builder.addComment("Some positions masked away, others kept"); builder.addStatement("$T block = page.getBlock(channels.get(0))", blockType(aggParam.type())); builder.addStatement("$T vector = block.asVector()", vectorType(aggParam.type())); + if (includeTimestampVector) { + builder.addStatement("$T timestampsBlock = page.getBlock(channels.get(1))", LONG_BLOCK); + builder.addStatement("$T timestampsVector = timestampsBlock.asVector()", LONG_VECTOR); + + builder.beginControlFlow("if (timestampsVector == null) "); + builder.addStatement("throw new IllegalStateException($S)", "expected @timestamp vector; but got a block"); + builder.endControlFlow(); + } builder.beginControlFlow("if (vector != null)"); - builder.addStatement("addRawVector(vector, mask)"); + builder.addStatement(includeTimestampVector ? "addRawVector(vector, timestampsVector, mask)" : "addRawVector(vector, mask)"); builder.nextControlFlow("else"); - builder.addStatement("addRawBlock(block, mask)"); + builder.addStatement(includeTimestampVector ? "addRawBlock(block, timestampsVector, mask)" : "addRawBlock(block, mask)"); builder.endControlFlow(); return builder.build(); } @@ -318,6 +351,9 @@ private MethodSpec addRawInput() { private MethodSpec addRawVector(boolean masked) { MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawVector"); builder.addModifiers(Modifier.PRIVATE).addParameter(vectorType(aggParam.type()), "vector"); + if (includeTimestampVector) { + builder.addParameter(LONG_VECTOR, "timestamps"); + } if (masked) { builder.addParameter(BOOLEAN_VECTOR, "mask"); } @@ -348,6 +384,9 @@ private MethodSpec addRawVector(boolean masked) { private MethodSpec addRawBlock(boolean masked) { MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawBlock"); builder.addModifiers(Modifier.PRIVATE).addParameter(blockType(aggParam.type()), "block"); + if (includeTimestampVector) { + builder.addParameter(LONG_VECTOR, "timestamps"); + } if (masked) { builder.addParameter(BOOLEAN_VECTOR, "mask"); } @@ -401,33 +440,57 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable) { }); } - private void combineRawInputForPrimitive(TypeName returnType, MethodSpec.Builder builder, String blockVariable) { - builder.addStatement( - "state.$TValue($T.combine(state.$TValue(), $L.get$L(i)))", - returnType, - declarationType, - returnType, - blockVariable, - capitalize(combine.getParameters().get(1).asType().toString()) - ); + private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable) { + // scratch is a BytesRef var that must have been defined before the iteration starts + if (includeTimestampVector) { + builder.addStatement("$T.combine(state, timestamps.getLong(i), $L.getBytesRef(i, scratch))", declarationType, blockVariable); + } else { + builder.addStatement("$T.combine(state, $L.getBytesRef(i, scratch))", declarationType, blockVariable); + } } - private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) { - warningsBlock(builder, () -> builder.addStatement("$T.combine(state, $L)", declarationType, arrayVariable)); + private void combineRawInputForPrimitive(TypeName returnType, MethodSpec.Builder builder, String blockVariable) { + if (includeTimestampVector) { + builder.addStatement( + "state.$TValue($T.combine(state.$TValue(), timestamps.getLong(i), $L.get$L(i)))", + returnType, + declarationType, + returnType, + blockVariable, + capitalize(combine.getParameters().get(1).asType().toString()) + ); + } else { + builder.addStatement( + "state.$TValue($T.combine(state.$TValue(), $L.get$L(i)))", + returnType, + declarationType, + returnType, + blockVariable, + capitalize(combine.getParameters().get(1).asType().toString()) + ); + } } private void combineRawInputForVoid(MethodSpec.Builder builder, String blockVariable) { - builder.addStatement( - "$T.combine(state, $L.get$L(i))", - declarationType, - blockVariable, - capitalize(combine.getParameters().get(1).asType().toString()) - ); + if (includeTimestampVector) { + builder.addStatement( + "$T.combine(state, timestamps.getLong(i), $L.get$L(i))", + declarationType, + blockVariable, + capitalize(combine.getParameters().get(1).asType().toString()) + ); + } else { + builder.addStatement( + "$T.combine(state, $L.get$L(i))", + declarationType, + blockVariable, + capitalize(combine.getParameters().get(1).asType().toString()) + ); + } } - private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable) { - // scratch is a BytesRef var that must have been defined before the iteration starts - builder.addStatement("$T.combine(state, $L.getBytesRef(i, scratch))", declarationType, blockVariable); + private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) { + warningsBlock(builder, () -> builder.addStatement("$T.combine(state, $L)", declarationType, arrayVariable)); } private void warningsBlock(MethodSpec.Builder builder, Runnable block) { diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java index 863db86eb934a..3ad2343ad1658 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java @@ -87,7 +87,13 @@ public boolean process(Set set, RoundEnvironment roundEnv ); if (aggClass.getAnnotation(Aggregator.class) != null) { IntermediateState[] intermediateState = aggClass.getAnnotation(Aggregator.class).value(); - implementer = new AggregatorImplementer(env.getElementUtils(), aggClass, intermediateState, warnExceptionsTypes); + implementer = new AggregatorImplementer( + env.getElementUtils(), + aggClass, + intermediateState, + warnExceptionsTypes, + aggClass.getAnnotation(Aggregator.class).includeTimestamps() + ); write(aggClass, "aggregator", implementer.sourceFile(), env); } GroupingAggregatorImplementer groupingAggregatorImplementer = null; @@ -96,13 +102,12 @@ public boolean process(Set set, RoundEnvironment roundEnv if (intermediateState.length == 0 && aggClass.getAnnotation(Aggregator.class) != null) { intermediateState = aggClass.getAnnotation(Aggregator.class).value(); } - boolean includeTimestamps = aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps(); groupingAggregatorImplementer = new GroupingAggregatorImplementer( env.getElementUtils(), aggClass, intermediateState, warnExceptionsTypes, - includeTimestamps + aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps() ); write(aggClass, "grouping aggregator", groupingAggregatorImplementer.sourceFile(), env); } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index 180589fc64373..fc377f22bbbc6 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -112,7 +112,8 @@ public GroupingAggregatorImplementer( requireName("combine"), combineArgs(aggState, includeTimestampVector) ); - this.aggParam = AggregationParameter.create(combine.getParameters().get(combine.getParameters().size() - 1).asType()); + // TODO support multiple parameters + this.aggParam = AggregationParameter.create(combine.getParameters().getLast().asType()); this.createParameters = init.getParameters() .stream() @@ -125,7 +126,7 @@ public GroupingAggregatorImplementer( (declarationType.getSimpleName() + "GroupingAggregatorFunction").replace("AggregatorGroupingAggregator", "GroupingAggregator") ); - intermediateState = Arrays.stream(interStateAnno) + this.intermediateState = Arrays.stream(interStateAnno) .map(AggregatorImplementer.IntermediateStateDesc::newIntermediateStateDesc) .toList(); this.includeTimestampVector = includeTimestampVector; @@ -370,8 +371,7 @@ private TypeSpec addInput(Consumer addBlock) { private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) { boolean groupsIsBlock = groupsType.toString().endsWith("Block"); boolean valuesIsBlock = valuesType.toString().endsWith("Block"); - String methodName = "addRawInput"; - MethodSpec.Builder builder = MethodSpec.methodBuilder(methodName); + MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawInput"); builder.addModifiers(Modifier.PRIVATE); builder.addParameter(TypeName.INT, "positionOffset").addParameter(groupsType, "groups").addParameter(valuesType, "values"); if (includeTimestampVector) { @@ -443,8 +443,6 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable, S warningsBlock(builder, () -> { if (aggParam.isBytesRef()) { combineRawInputForBytesRef(builder, blockVariable, offsetVariable); - } else if (includeTimestampVector) { - combineRawInputWithTimestamp(builder, offsetVariable); } else if (valueType.isPrimitive() == false) { throw new IllegalArgumentException("second parameter to combine must be a primitive, array or BytesRef: " + valueType); } else if (returnType.isPrimitive()) { @@ -457,48 +455,75 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable, S }); } - private void combineRawInputForPrimitive(MethodSpec.Builder builder, String blockVariable, String offsetVariable) { - builder.addStatement( - "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.get$L($L)))", - declarationType, - blockVariable, - capitalize(aggParam.type().toString()), - offsetVariable - ); + private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable, String offsetVariable) { + // scratch is a BytesRef var that must have been defined before the iteration starts + if (includeTimestampVector) { + if (offsetVariable.contains(" + ")) { + builder.addStatement("var valuePosition = $L", offsetVariable); + offsetVariable = "valuePosition"; + } + builder.addStatement( + "$T.combine(state, groupId, timestamps.getLong($L), $L.getBytesRef($L, scratch))", + declarationType, + offsetVariable, + blockVariable, + offsetVariable + ); + } else { + builder.addStatement("$T.combine(state, groupId, $L.getBytesRef($L, scratch))", declarationType, blockVariable, offsetVariable); + } } - private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) { - warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable)); + private void combineRawInputForPrimitive(MethodSpec.Builder builder, String blockVariable, String offsetVariable) { + if (includeTimestampVector) { + if (offsetVariable.contains(" + ")) { + builder.addStatement("var valuePosition = $L", offsetVariable); + offsetVariable = "valuePosition"; + } + builder.addStatement( + "$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))", + declarationType, + offsetVariable, + capitalize(aggParam.type().toString()), + offsetVariable + ); + } else { + builder.addStatement( + "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.get$L($L)))", + declarationType, + blockVariable, + capitalize(aggParam.type().toString()), + offsetVariable + ); + } } private void combineRawInputForVoid(MethodSpec.Builder builder, String blockVariable, String offsetVariable) { - builder.addStatement( - "$T.combine(state, groupId, $L.get$L($L))", - declarationType, - blockVariable, - capitalize(aggParam.type().toString()), - offsetVariable - ); - } - - private void combineRawInputWithTimestamp(MethodSpec.Builder builder, String offsetVariable) { - String blockType = capitalize(aggParam.type().toString()); - if (offsetVariable.contains(" + ")) { - builder.addStatement("var valuePosition = $L", offsetVariable); - offsetVariable = "valuePosition"; + if (includeTimestampVector) { + if (offsetVariable.contains(" + ")) { + builder.addStatement("var valuePosition = $L", offsetVariable); + offsetVariable = "valuePosition"; + } + builder.addStatement( + "$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))", + declarationType, + offsetVariable, + capitalize(aggParam.type().toString()), + offsetVariable + ); + } else { + builder.addStatement( + "$T.combine(state, groupId, $L.get$L($L))", + declarationType, + blockVariable, + capitalize(aggParam.type().toString()), + offsetVariable + ); } - builder.addStatement( - "$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))", - declarationType, - offsetVariable, - blockType, - offsetVariable - ); } - private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable, String offsetVariable) { - // scratch is a BytesRef var that must have been defined before the iteration starts - builder.addStatement("$T.combine(state, groupId, $L.getBytesRef($L, scratch))", declarationType, blockVariable, offsetVariable); + private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) { + warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable)); } private void warningsBlock(MethodSpec.Builder builder, Runnable block) {