Skip to content

Commit

Permalink
Add rate aggregation function (#106703)
Browse files Browse the repository at this point in the history
This PR introduces a rate aggregation function to ESQL, with two main 
changes:

1. Definition of the grouping state for the rate aggregation function:

- For raw input, the function expects data to arrive in descending 
timestamp order without gaps; hence, it performs a reduction on each
incoming entry. Each grouping should consist of at most two entries: one
for the starting time and one for the ending time.

- For intermediate input, the function buffers data since entries can 
arrive out of order, although they are non-overlapping. This shouldn't
cause significant issues, as we expect at most two entries per
participating pipeline.

- The intermediate output consists of three blocks: timestamps, values, 
and resets. Both timestamps and values can contain multiple
entries sorted in descending order by timestamp.

- This rate function does not support non-grouping aggregation. However, 
I can enable it if we think otherwise.

2. Modifies the GroupingAggregatorImplementer code generator to include 
the timestamp vector block. I explored several options to generate
multiple input blocks. However, both the generated code and the code
generator become much more complicated in a generic solution, and it's
unlikely that we will need another function that requires multiple input
blocks. Hence, I decided to tweak this class to append the timestamps
long vector block when specified.
  • Loading branch information
dnhatn committed Mar 25, 2024
1 parent 96230f7 commit 49cf3cb
Show file tree
Hide file tree
Showing 16 changed files with 2,145 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,9 @@
public @interface GroupingAggregator {

/**
 * The description of the intermediate state of this grouping aggregation. When left empty,
 * the code generator falls back to the intermediate state declared on the {@code Aggregator}
 * annotation of the same class, if that annotation is present.
 */
IntermediateState[] value() default {};

/**
 * If {@code true} then the @timestamp LongVector will be appended to the input blocks of the aggregation function.
 */
boolean includeTimestamps() default false;
}
20 changes: 19 additions & 1 deletion x-pack/plugin/esql/compute/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,25 @@ tasks.named('stringTemplates').configure {
it.inputFile = valuesAggregatorInputFile
it.outputFile = "org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java"
}
File multivalueDedupeInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/operator/X-MultivalueDedupe.java.st")

// Generate the type-specialized Rate aggregators (int/long/double) from the shared
// X-RateAggregator.java.st string template; each template {} entry substitutes one
// primitive type's properties into the same input file.
File rateAggregatorInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-RateAggregator.java.st")
template {
it.properties = intProperties
it.inputFile = rateAggregatorInputFile
it.outputFile = "org/elasticsearch/compute/aggregation/RateIntAggregator.java"
}
template {
it.properties = longProperties
it.inputFile = rateAggregatorInputFile
it.outputFile = "org/elasticsearch/compute/aggregation/RateLongAggregator.java"
}
template {
it.properties = doubleProperties
it.inputFile = rateAggregatorInputFile
it.outputFile = "org/elasticsearch/compute/aggregation/RateDoubleAggregator.java"
}

File multivalueDedupeInputFile = file("src/main/java/org/elasticsearch/compute/operator/X-MultivalueDedupe.java.st")
template {
it.properties = intProperties
it.inputFile = multivalueDedupeInputFile
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ public AggregatorFunctionSupplierImplementer(
this.groupingAggregatorImplementer = groupingAggregatorImplementer;

Set<Parameter> createParameters = new LinkedHashSet<>();
createParameters.addAll(aggregatorImplementer.createParameters());
createParameters.addAll(groupingAggregatorImplementer.createParameters());
if (aggregatorImplementer != null) {
createParameters.addAll(aggregatorImplementer.createParameters());
}
if (groupingAggregatorImplementer != null) {
createParameters.addAll(groupingAggregatorImplementer.createParameters());
}
this.createParameters = new ArrayList<>(createParameters);
this.createParameters.add(0, new Parameter(LIST_INTEGER, "channels"));

Expand Down Expand Up @@ -84,7 +88,11 @@ private TypeSpec type() {

createParameters.stream().forEach(p -> p.declareField(builder));
builder.addMethod(ctor());
builder.addMethod(aggregator());
if (aggregatorImplementer != null) {
builder.addMethod(aggregator());
} else {
builder.addMethod(unsupportedNonGroupingAggregator());
}
builder.addMethod(groupingAggregator());
builder.addMethod(describe());
return builder.build();
Expand All @@ -96,6 +104,15 @@ private MethodSpec ctor() {
return builder.build();
}

/**
 * Builds an {@code aggregator(DriverContext)} override that always throws
 * {@link UnsupportedOperationException}. Used when the annotated function only provides a
 * grouping implementation (e.g. rate) and therefore has no non-grouping aggregator to emit.
 */
private MethodSpec unsupportedNonGroupingAggregator() {
    return MethodSpec.methodBuilder("aggregator")
        .addAnnotation(Override.class)
        .addModifiers(Modifier.PUBLIC)
        .addParameter(DRIVER_CONTEXT, "driverContext")
        .returns(Types.AGGREGATOR_FUNCTION)
        .addStatement("throw new UnsupportedOperationException($S)", "non-grouping aggregator is not supported")
        .build();
}

private MethodSpec aggregator() {
MethodSpec.Builder builder = MethodSpec.methodBuilder("aggregator")
.addParameter(DRIVER_CONTEXT, "driverContext")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,21 @@ public boolean process(Set<? extends TypeElement> set, RoundEnvironment roundEnv
write(aggClass, "aggregator", implementer.sourceFile(), env);
}
GroupingAggregatorImplementer groupingAggregatorImplementer = null;
if (aggClass.getAnnotation(Aggregator.class) != null) {
assert aggClass.getAnnotation(GroupingAggregator.class) != null;
if (aggClass.getAnnotation(GroupingAggregator.class) != null) {
IntermediateState[] intermediateState = aggClass.getAnnotation(GroupingAggregator.class).value();
if (intermediateState.length == 0) {
if (intermediateState.length == 0 && aggClass.getAnnotation(Aggregator.class) != null) {
intermediateState = aggClass.getAnnotation(Aggregator.class).value();
}

groupingAggregatorImplementer = new GroupingAggregatorImplementer(env.getElementUtils(), aggClass, intermediateState);
boolean includeTimestamps = aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps();
groupingAggregatorImplementer = new GroupingAggregatorImplementer(
env.getElementUtils(),
aggClass,
intermediateState,
includeTimestamps
);
write(aggClass, "grouping aggregator", groupingAggregatorImplementer.sourceFile(), env);
}
if (implementer != null && groupingAggregatorImplementer != null) {
if (implementer != null || groupingAggregatorImplementer != null) {
write(
aggClass,
"aggregator function supplier",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
import static org.elasticsearch.compute.gen.Types.INT_VECTOR;
import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC;
import static org.elasticsearch.compute.gen.Types.LIST_INTEGER;
import static org.elasticsearch.compute.gen.Types.LONG_BLOCK;
import static org.elasticsearch.compute.gen.Types.LONG_VECTOR;
import static org.elasticsearch.compute.gen.Types.PAGE;
import static org.elasticsearch.compute.gen.Types.SEEN_GROUP_IDS;

Expand All @@ -71,8 +73,14 @@ public class GroupingAggregatorImplementer {
private final List<Parameter> createParameters;
private final ClassName implementation;
private final List<AggregatorImplementer.IntermediateStateDesc> intermediateState;
private final boolean includeTimestampVector;

public GroupingAggregatorImplementer(Elements elements, TypeElement declarationType, IntermediateState[] interStateAnno) {
public GroupingAggregatorImplementer(
Elements elements,
TypeElement declarationType,
IntermediateState[] interStateAnno,
boolean includeTimestampVector
) {
this.declarationType = declarationType;

this.init = findRequiredMethod(declarationType, new String[] { "init", "initGrouping" }, e -> true);
Expand Down Expand Up @@ -103,6 +111,7 @@ public GroupingAggregatorImplementer(Elements elements, TypeElement declarationT
intermediateState = Arrays.stream(interStateAnno)
.map(AggregatorImplementer.IntermediateStateDesc::newIntermediateStateDesc)
.toList();
this.includeTimestampVector = includeTimestampVector;
}

public ClassName implementation() {
Expand Down Expand Up @@ -264,15 +273,24 @@ private MethodSpec prepareProcessPage() {

builder.addStatement("$T valuesBlock = page.getBlock(channels.get(0))", valueBlockType(init, combine));
builder.addStatement("$T valuesVector = valuesBlock.asVector()", valueVectorType(init, combine));
if (includeTimestampVector) {
builder.addStatement("$T timestampsBlock = page.getBlock(channels.get(1))", LONG_BLOCK);
builder.addStatement("$T timestampsVector = timestampsBlock.asVector()", LONG_VECTOR);

builder.beginControlFlow("if (timestampsVector == null) ");
builder.addStatement("throw new IllegalStateException($S)", "expected @timestamp vector; but got a block");
builder.endControlFlow();
}
builder.beginControlFlow("if (valuesVector == null)");
String extra = includeTimestampVector ? ", timestampsVector" : "";
{
builder.beginControlFlow("if (valuesBlock.mayHaveNulls())");
builder.addStatement("state.enableGroupIdTracking(seenGroupIds)");
builder.endControlFlow();
builder.addStatement("return $L", addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock)")));
builder.addStatement("return $L", addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock$L)", extra)));
}
builder.endControlFlow();
builder.addStatement("return $L", addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesVector)")));
builder.addStatement("return $L", addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesVector$L)", extra)));
return builder.build();
}

Expand Down Expand Up @@ -308,6 +326,9 @@ private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) {
MethodSpec.Builder builder = MethodSpec.methodBuilder(methodName);
builder.addModifiers(Modifier.PRIVATE);
builder.addParameter(TypeName.INT, "positionOffset").addParameter(groupsType, "groups").addParameter(valuesType, "values");
if (includeTimestampVector) {
builder.addParameter(LONG_VECTOR, "timestamps");
}
if (valuesIsBytesRef) {
// Add bytes_ref scratch var that will be used for bytes_ref blocks/vectors
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
Expand Down Expand Up @@ -354,6 +375,10 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable, S
combineRawInputForBytesRef(builder, blockVariable, offsetVariable);
return;
}
if (includeTimestampVector) {
combineRawInputWithTimestamp(builder, offsetVariable);
return;
}
TypeName valueType = TypeName.get(combine.getParameters().get(combine.getParameters().size() - 1).asType());
if (valueType.isPrimitive() == false) {
throw new IllegalArgumentException("second parameter to combine must be a primitive");
Expand Down Expand Up @@ -403,6 +428,22 @@ private void combineRawInputForVoid(
);
}

/**
 * Emits the generated call {@code <Declaration>.combine(state, groupId, timestamp, value)}
 * for aggregations that consume the appended {@code @timestamp} long vector alongside the
 * value vector/block.
 */
private void combineRawInputWithTimestamp(MethodSpec.Builder builder, String offsetVariable) {
    var combineParameters = combine.getParameters();
    // The last parameter of combine(...) is the raw value; its primitive type picks the accessor.
    String typeName = TypeName.get(combineParameters.get(combineParameters.size() - 1).asType()).toString();
    // Capitalize the primitive name to form the vector accessor suffix, e.g. "double" -> getDouble.
    String getterSuffix = typeName.substring(0, 1).toUpperCase(Locale.ROOT) + typeName.substring(1);
    if (offsetVariable.contains(" + ")) {
        // Hoist a compound offset expression into a local so the generated code evaluates it once.
        builder.addStatement("var valuePosition = $L", offsetVariable);
        offsetVariable = "valuePosition";
    }
    builder.addStatement(
        "$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))",
        declarationType,
        offsetVariable,
        getterSuffix,
        offsetVariable
    );
}

private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
// scratch is a BytesRef var that must have been defined before the iteration starts
builder.addStatement("$T.combine(state, groupId, $L.getBytesRef($L, scratch))", declarationType, blockVariable, offsetVariable);
Expand Down

0 comments on commit 49cf3cb

Please sign in to comment.