Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,8 @@
*/
Class<? extends Exception>[] warnExceptions() default {};

/**
* If {@code true} then the @timestamp LongVector will be appended to the input blocks of the aggregation function.
*/
boolean includeTimestamps() default false;
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC;
import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC;
import static org.elasticsearch.compute.gen.Types.LIST_INTEGER;
import static org.elasticsearch.compute.gen.Types.LONG_BLOCK;
import static org.elasticsearch.compute.gen.Types.LONG_VECTOR;
import static org.elasticsearch.compute.gen.Types.PAGE;
import static org.elasticsearch.compute.gen.Types.WARNINGS;
import static org.elasticsearch.compute.gen.Types.blockType;
Expand All @@ -73,9 +75,10 @@ public class AggregatorImplementer {
private final List<TypeMirror> warnExceptions;
private final ExecutableElement init;
private final ExecutableElement combine;
private final List<Parameter> createParameters;
private final ClassName implementation;
private final List<IntermediateStateDesc> intermediateState;
private final List<Parameter> createParameters;
private final boolean includeTimestampVector;

private final AggregationState aggState;
private final AggregationParameter aggParam;
Expand All @@ -84,7 +87,8 @@ public AggregatorImplementer(
Elements elements,
TypeElement declarationType,
IntermediateState[] interStateAnno,
List<TypeMirror> warnExceptions
List<TypeMirror> warnExceptions,
boolean includeTimestampVector
) {
this.declarationType = declarationType;
this.warnExceptions = warnExceptions;
Expand All @@ -102,10 +106,10 @@ public AggregatorImplementer(
declarationType,
aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(),
requireName("combine"),
requireArgs(requireType(aggState.declaredType()), requireAnyType("<aggregation input column type>"))
combineArgs(aggState, includeTimestampVector)
);
// TODO support multiple parameters
this.aggParam = AggregationParameter.create(combine.getParameters().get(1).asType());
this.aggParam = AggregationParameter.create(combine.getParameters().getLast().asType());

this.createParameters = init.getParameters()
.stream()
Expand All @@ -117,7 +121,20 @@ public AggregatorImplementer(
elements.getPackageOf(declarationType).toString(),
(declarationType.getSimpleName() + "AggregatorFunction").replace("AggregatorAggregator", "Aggregator")
);
intermediateState = Arrays.stream(interStateAnno).map(IntermediateStateDesc::newIntermediateStateDesc).toList();
this.intermediateState = Arrays.stream(interStateAnno).map(IntermediateStateDesc::newIntermediateStateDesc).toList();
this.includeTimestampVector = includeTimestampVector;
}

private static Methods.ArgumentMatcher combineArgs(AggregationState aggState, boolean includeTimestampVector) {
if (includeTimestampVector) {
return requireArgs(
requireType(aggState.declaredType()),
requireType(TypeName.LONG), // @timestamp
requireAnyType("<aggregation input column type>")
);
} else {
return requireArgs(requireType(aggState.declaredType()), requireAnyType("<aggregation input column type>"));
}
}

ClassName implementation() {
Expand Down Expand Up @@ -295,10 +312,18 @@ private MethodSpec addRawInput() {
builder.addComment("No masking");
builder.addStatement("$T block = page.getBlock(channels.get(0))", blockType(aggParam.type()));
builder.addStatement("$T vector = block.asVector()", vectorType(aggParam.type()));
if (includeTimestampVector) {
builder.addStatement("$T timestampsBlock = page.getBlock(channels.get(1))", LONG_BLOCK);
builder.addStatement("$T timestampsVector = timestampsBlock.asVector()", LONG_VECTOR);

builder.beginControlFlow("if (timestampsVector == null) ");
builder.addStatement("throw new IllegalStateException($S)", "expected @timestamp vector; but got a block");
builder.endControlFlow();
}
builder.beginControlFlow("if (vector != null)");
builder.addStatement("addRawVector(vector)");
builder.addStatement(includeTimestampVector ? "addRawVector(vector, timestampsVector)" : "addRawVector(vector)");
builder.nextControlFlow("else");
builder.addStatement("addRawBlock(block)");
builder.addStatement(includeTimestampVector ? "addRawBlock(block, timestampsVector)" : "addRawBlock(block)");
builder.endControlFlow();
builder.addStatement("return");
}
Expand All @@ -307,17 +332,28 @@ private MethodSpec addRawInput() {
builder.addComment("Some positions masked away, others kept");
builder.addStatement("$T block = page.getBlock(channels.get(0))", blockType(aggParam.type()));
builder.addStatement("$T vector = block.asVector()", vectorType(aggParam.type()));
if (includeTimestampVector) {
builder.addStatement("$T timestampsBlock = page.getBlock(channels.get(1))", LONG_BLOCK);
builder.addStatement("$T timestampsVector = timestampsBlock.asVector()", LONG_VECTOR);

builder.beginControlFlow("if (timestampsVector == null) ");
builder.addStatement("throw new IllegalStateException($S)", "expected @timestamp vector; but got a block");
builder.endControlFlow();
}
builder.beginControlFlow("if (vector != null)");
builder.addStatement("addRawVector(vector, mask)");
builder.addStatement(includeTimestampVector ? "addRawVector(vector, timestampsVector, mask)" : "addRawVector(vector, mask)");
builder.nextControlFlow("else");
builder.addStatement("addRawBlock(block, mask)");
builder.addStatement(includeTimestampVector ? "addRawBlock(block, timestampsVector, mask)" : "addRawBlock(block, mask)");
builder.endControlFlow();
return builder.build();
}

private MethodSpec addRawVector(boolean masked) {
MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawVector");
builder.addModifiers(Modifier.PRIVATE).addParameter(vectorType(aggParam.type()), "vector");
if (includeTimestampVector) {
builder.addParameter(LONG_VECTOR, "timestamps");
}
if (masked) {
builder.addParameter(BOOLEAN_VECTOR, "mask");
}
Expand Down Expand Up @@ -348,6 +384,9 @@ private MethodSpec addRawVector(boolean masked) {
private MethodSpec addRawBlock(boolean masked) {
MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawBlock");
builder.addModifiers(Modifier.PRIVATE).addParameter(blockType(aggParam.type()), "block");
if (includeTimestampVector) {
builder.addParameter(LONG_VECTOR, "timestamps");
}
if (masked) {
builder.addParameter(BOOLEAN_VECTOR, "mask");
}
Expand Down Expand Up @@ -401,33 +440,57 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable) {
});
}

private void combineRawInputForPrimitive(TypeName returnType, MethodSpec.Builder builder, String blockVariable) {
builder.addStatement(
"state.$TValue($T.combine(state.$TValue(), $L.get$L(i)))",
returnType,
declarationType,
returnType,
blockVariable,
capitalize(combine.getParameters().get(1).asType().toString())
);
private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable) {
// scratch is a BytesRef var that must have been defined before the iteration starts
if (includeTimestampVector) {
builder.addStatement("$T.combine(state, timestamps.getLong(i), $L.getBytesRef(i, scratch))", declarationType, blockVariable);
} else {
builder.addStatement("$T.combine(state, $L.getBytesRef(i, scratch))", declarationType, blockVariable);
}
}

private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
warningsBlock(builder, () -> builder.addStatement("$T.combine(state, $L)", declarationType, arrayVariable));
private void combineRawInputForPrimitive(TypeName returnType, MethodSpec.Builder builder, String blockVariable) {
if (includeTimestampVector) {
builder.addStatement(
"state.$TValue($T.combine(state.$TValue(), timestamps.getLong(i), $L.get$L(i)))",
returnType,
declarationType,
returnType,
blockVariable,
capitalize(combine.getParameters().get(1).asType().toString())
);
} else {
builder.addStatement(
"state.$TValue($T.combine(state.$TValue(), $L.get$L(i)))",
returnType,
declarationType,
returnType,
blockVariable,
capitalize(combine.getParameters().get(1).asType().toString())
);
}
}

private void combineRawInputForVoid(MethodSpec.Builder builder, String blockVariable) {
builder.addStatement(
"$T.combine(state, $L.get$L(i))",
declarationType,
blockVariable,
capitalize(combine.getParameters().get(1).asType().toString())
);
if (includeTimestampVector) {
builder.addStatement(
"$T.combine(state, timestamps.getLong(i), $L.get$L(i))",
declarationType,
blockVariable,
capitalize(combine.getParameters().get(1).asType().toString())
);
} else {
builder.addStatement(
"$T.combine(state, $L.get$L(i))",
declarationType,
blockVariable,
capitalize(combine.getParameters().get(1).asType().toString())
);
}
}

private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable) {
// scratch is a BytesRef var that must have been defined before the iteration starts
builder.addStatement("$T.combine(state, $L.getBytesRef(i, scratch))", declarationType, blockVariable);
private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
warningsBlock(builder, () -> builder.addStatement("$T.combine(state, $L)", declarationType, arrayVariable));
}

private void warningsBlock(MethodSpec.Builder builder, Runnable block) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,13 @@ public boolean process(Set<? extends TypeElement> set, RoundEnvironment roundEnv
);
if (aggClass.getAnnotation(Aggregator.class) != null) {
IntermediateState[] intermediateState = aggClass.getAnnotation(Aggregator.class).value();
implementer = new AggregatorImplementer(env.getElementUtils(), aggClass, intermediateState, warnExceptionsTypes);
implementer = new AggregatorImplementer(
env.getElementUtils(),
aggClass,
intermediateState,
warnExceptionsTypes,
aggClass.getAnnotation(Aggregator.class).includeTimestamps()
);
write(aggClass, "aggregator", implementer.sourceFile(), env);
}
GroupingAggregatorImplementer groupingAggregatorImplementer = null;
Expand All @@ -96,13 +102,12 @@ public boolean process(Set<? extends TypeElement> set, RoundEnvironment roundEnv
if (intermediateState.length == 0 && aggClass.getAnnotation(Aggregator.class) != null) {
intermediateState = aggClass.getAnnotation(Aggregator.class).value();
}
boolean includeTimestamps = aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps();
groupingAggregatorImplementer = new GroupingAggregatorImplementer(
env.getElementUtils(),
aggClass,
intermediateState,
warnExceptionsTypes,
includeTimestamps
aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps()
);
write(aggClass, "grouping aggregator", groupingAggregatorImplementer.sourceFile(), env);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ public GroupingAggregatorImplementer(
requireName("combine"),
combineArgs(aggState, includeTimestampVector)
);
this.aggParam = AggregationParameter.create(combine.getParameters().get(combine.getParameters().size() - 1).asType());
// TODO support multiple parameters
this.aggParam = AggregationParameter.create(combine.getParameters().getLast().asType());

this.createParameters = init.getParameters()
.stream()
Expand All @@ -125,7 +126,7 @@ public GroupingAggregatorImplementer(
(declarationType.getSimpleName() + "GroupingAggregatorFunction").replace("AggregatorGroupingAggregator", "GroupingAggregator")
);

intermediateState = Arrays.stream(interStateAnno)
this.intermediateState = Arrays.stream(interStateAnno)
.map(AggregatorImplementer.IntermediateStateDesc::newIntermediateStateDesc)
.toList();
this.includeTimestampVector = includeTimestampVector;
Expand Down Expand Up @@ -370,8 +371,7 @@ private TypeSpec addInput(Consumer<MethodSpec.Builder> addBlock) {
private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) {
boolean groupsIsBlock = groupsType.toString().endsWith("Block");
boolean valuesIsBlock = valuesType.toString().endsWith("Block");
String methodName = "addRawInput";
MethodSpec.Builder builder = MethodSpec.methodBuilder(methodName);
MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawInput");
builder.addModifiers(Modifier.PRIVATE);
builder.addParameter(TypeName.INT, "positionOffset").addParameter(groupsType, "groups").addParameter(valuesType, "values");
if (includeTimestampVector) {
Expand Down Expand Up @@ -443,8 +443,6 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable, S
warningsBlock(builder, () -> {
if (aggParam.isBytesRef()) {
combineRawInputForBytesRef(builder, blockVariable, offsetVariable);
} else if (includeTimestampVector) {
combineRawInputWithTimestamp(builder, offsetVariable);
} else if (valueType.isPrimitive() == false) {
throw new IllegalArgumentException("second parameter to combine must be a primitive, array or BytesRef: " + valueType);
} else if (returnType.isPrimitive()) {
Expand All @@ -457,48 +455,75 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable, S
});
}

private void combineRawInputForPrimitive(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
builder.addStatement(
"state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.get$L($L)))",
declarationType,
blockVariable,
capitalize(aggParam.type().toString()),
offsetVariable
);
private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
// scratch is a BytesRef var that must have been defined before the iteration starts
if (includeTimestampVector) {
if (offsetVariable.contains(" + ")) {
builder.addStatement("var valuePosition = $L", offsetVariable);
offsetVariable = "valuePosition";
}
builder.addStatement(
"$T.combine(state, groupId, timestamps.getLong($L), $L.getBytesRef($L, scratch))",
declarationType,
offsetVariable,
blockVariable,
offsetVariable
);
} else {
builder.addStatement("$T.combine(state, groupId, $L.getBytesRef($L, scratch))", declarationType, blockVariable, offsetVariable);
}
}

private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable));
private void combineRawInputForPrimitive(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
if (includeTimestampVector) {
if (offsetVariable.contains(" + ")) {
builder.addStatement("var valuePosition = $L", offsetVariable);
offsetVariable = "valuePosition";
}
builder.addStatement(
"$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))",
declarationType,
offsetVariable,
capitalize(aggParam.type().toString()),
offsetVariable
);
} else {
builder.addStatement(
"state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.get$L($L)))",
declarationType,
blockVariable,
capitalize(aggParam.type().toString()),
offsetVariable
);
}
}

private void combineRawInputForVoid(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
builder.addStatement(
"$T.combine(state, groupId, $L.get$L($L))",
declarationType,
blockVariable,
capitalize(aggParam.type().toString()),
offsetVariable
);
}

private void combineRawInputWithTimestamp(MethodSpec.Builder builder, String offsetVariable) {
String blockType = capitalize(aggParam.type().toString());
if (offsetVariable.contains(" + ")) {
builder.addStatement("var valuePosition = $L", offsetVariable);
offsetVariable = "valuePosition";
if (includeTimestampVector) {
if (offsetVariable.contains(" + ")) {
builder.addStatement("var valuePosition = $L", offsetVariable);
offsetVariable = "valuePosition";
}
builder.addStatement(
"$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))",
declarationType,
offsetVariable,
capitalize(aggParam.type().toString()),
offsetVariable
);
} else {
builder.addStatement(
"$T.combine(state, groupId, $L.get$L($L))",
declarationType,
blockVariable,
capitalize(aggParam.type().toString()),
offsetVariable
);
}
builder.addStatement(
"$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))",
declarationType,
offsetVariable,
blockType,
offsetVariable
);
}

private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
// scratch is a BytesRef var that must have been defined before the iteration starts
builder.addStatement("$T.combine(state, groupId, $L.getBytesRef($L, scratch))", declarationType, blockVariable, offsetVariable);
private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable));
}

private void warningsBlock(MethodSpec.Builder builder, Runnable block) {
Expand Down