Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import javax.lang.model.util.Elements;

import static java.util.stream.Collectors.joining;
import static org.elasticsearch.compute.gen.Methods.getMethod;
import static org.elasticsearch.compute.gen.Methods.optionalStaticMethod;
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
import static org.elasticsearch.compute.gen.Methods.requireAnyType;
Expand All @@ -53,7 +54,6 @@
import static org.elasticsearch.compute.gen.Types.BLOCK;
import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_VECTOR;
import static org.elasticsearch.compute.gen.Types.BYTES_REF;
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
import static org.elasticsearch.compute.gen.Types.ELEMENT_TYPE;
import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC;
Expand All @@ -62,6 +62,8 @@
import static org.elasticsearch.compute.gen.Types.PAGE;
import static org.elasticsearch.compute.gen.Types.WARNINGS;
import static org.elasticsearch.compute.gen.Types.blockType;
import static org.elasticsearch.compute.gen.Types.fromString;
import static org.elasticsearch.compute.gen.Types.scratchType;
import static org.elasticsearch.compute.gen.Types.vectorType;

/**
Expand All @@ -85,7 +87,7 @@ public class AggregatorImplementer {

private final AggregationState aggState;
private final List<Argument> aggParams;
private final boolean hasOnlyBlockArguments;
private final boolean attemptToUseVectors;

public AggregatorImplementer(
Elements elements,
Expand Down Expand Up @@ -119,7 +121,8 @@ public AggregatorImplementer(
return a;
}).filter(a -> a instanceof PositionArgument == false).toList();

this.hasOnlyBlockArguments = this.aggParams.stream().allMatch(a -> a instanceof BlockArgument);
this.attemptToUseVectors = aggParams.stream().anyMatch(a -> (a instanceof BlockArgument) == false)
&& aggParams.stream().noneMatch(a -> a instanceof StandardArgument && a.dataType(false) == null);

this.createParameters = init.getParameters()
.stream()
Expand Down Expand Up @@ -199,7 +202,7 @@ private TypeSpec type() {
builder.addMethod(addRawInput());
builder.addMethod(addRawInputExploded(true));
builder.addMethod(addRawInputExploded(false));
if (hasOnlyBlockArguments == false) {
if (attemptToUseVectors) {
builder.addMethod(addRawVector(false));
builder.addMethod(addRawVector(true));
}
Expand Down Expand Up @@ -340,16 +343,18 @@ private MethodSpec addRawInputExploded(boolean hasMask) {
builder.addStatement("$T $L = page.getBlock(channels.get($L))", a.dataType(true), a.blockName(), i);
}

for (Argument a : aggParams) {
String rawBlock = "addRawBlock("
+ aggParams.stream().map(arg -> arg.blockName()).collect(joining(", "))
+ (hasMask ? ", mask" : "")
+ ")";
if (attemptToUseVectors) {
for (Argument a : aggParams) {
String rawBlock = "addRawBlock("
+ aggParams.stream().map(arg -> arg.blockName()).collect(joining(", "))
+ (hasMask ? ", mask" : "")
+ ")";

a.resolveVectors(builder, rawBlock, "return");
a.resolveVectors(builder, rawBlock, "return");
}
}

builder.addStatement(invokeAddRaw(hasOnlyBlockArguments, hasMask));
builder.addStatement(invokeAddRaw(attemptToUseVectors == false, hasMask));
return builder.build();
}

Expand Down Expand Up @@ -499,9 +504,9 @@ private MethodSpec.Builder initAddRaw(boolean blockStyle, boolean masked) {
builder.addParameter(BOOLEAN_VECTOR, "mask");
}
for (Argument a : aggParams) {
if (a.isBytesRef()) {
// Add bytes_ref scratch var that will be used for bytes_ref blocks/vectors
builder.addStatement("$T $L = new $T()", BYTES_REF, a.scratchName(), BYTES_REF);
if (a.scratchType() != null) {
// Add scratch var that will be used for some blocks/vectors, e.g. for bytes_ref
builder.addStatement("$T $L = new $T()", a.scratchType(), a.scratchName(), a.scratchType());
}
}
return builder;
Expand Down Expand Up @@ -610,8 +615,8 @@ private MethodSpec addIntermediateInput() {
).map(Methods::requireType).toArray(TypeMatcher[]::new)
)
);
if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
for (IntermediateStateDesc interState : intermediateState) {
interState.addScratchDeclaration(builder);
}
builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType);
}
Expand Down Expand Up @@ -706,13 +711,25 @@ public String access(String position) {
if (block) {
return name();
}
String s = name() + "." + vectorAccessorName(elementType()) + "(" + position;
if (elementType().equals("BYTES_REF")) {
s += ", scratch";
String s = name() + ".";
if (vectorType(elementType) != null) {
s += vectorAccessorName(elementType()) + "(" + position;
} else {
s += getMethod(fromString(elementType())) + "(" + name() + ".getFirstValueIndex(" + position + ")";
}
if (scratchType(elementType()) != null) {
s += ", " + name() + "Scratch";
}
return s + ")";
}

public void addScratchDeclaration(MethodSpec.Builder builder) {
ClassName scratchType = scratchType(elementType());
if (scratchType != null) {
builder.addStatement("$T $L = new $T()", scratchType, name() + "Scratch", scratchType);
}
}

public void assignToVariable(MethodSpec.Builder builder, int offset) {
builder.addStatement("Block $L = page.getBlock(channels.get($L))", name + "Uncast", offset);
ClassName blockType = blockType(elementType());
Expand All @@ -721,7 +738,7 @@ public void assignToVariable(MethodSpec.Builder builder, int offset) {
builder.addStatement("return");
builder.endControlFlow();
}
if (block) {
if (block || vectorType(elementType) == null) {
builder.addStatement("$T $L = ($T) $L", blockType, name, blockType, name + "Uncast");
} else {
builder.addStatement("$T $L = (($T) $L).asVector()", vectorType(elementType), name, blockType, name + "Uncast");
Expand All @@ -732,6 +749,7 @@ public TypeName combineArgType() {
var type = Types.fromString(elementType);
return block ? blockType(type) : type;
}

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.compute.gen.argument.BlockArgument;
import org.elasticsearch.compute.gen.argument.BuilderArgument;
import org.elasticsearch.compute.gen.argument.FixedArgument;
import org.elasticsearch.compute.gen.argument.PositionArgument;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -51,6 +52,7 @@ public class EvaluatorImplementer {
private final ProcessFunction processFunction;
private final ClassName implementation;
private final boolean processOutputsMultivalued;
private final boolean vectorsUnsupported;
private final boolean allNullsIsNull;

public EvaluatorImplementer(
Expand All @@ -69,6 +71,10 @@ public EvaluatorImplementer(
declarationType.getSimpleName() + extraName + "Evaluator"
);
this.processOutputsMultivalued = this.processFunction.hasBlockType;
boolean anyParameterNotSupportingVectors = this.processFunction.args.stream()
.filter(a -> a instanceof FixedArgument == false && a instanceof PositionArgument == false)
.anyMatch(a -> a.dataType(false) == null);
vectorsUnsupported = processOutputsMultivalued || anyParameterNotSupportingVectors;
this.allNullsIsNull = allNullsIsNull;
}

Expand Down Expand Up @@ -101,7 +107,7 @@ private TypeSpec type() {
builder.addMethod(eval());
builder.addMethod(processFunction.baseRamBytesUsed());

if (processOutputsMultivalued) {
if (vectorsUnsupported) {
if (processFunction.args.stream().anyMatch(x -> x instanceof FixedArgument == false)) {
builder.addMethod(realEval(true));
}
Expand Down Expand Up @@ -145,7 +151,7 @@ private MethodSpec eval() {
builder.addModifiers(Modifier.PUBLIC).returns(BLOCK).addParameter(PAGE, "page");
processFunction.args.forEach(a -> a.evalToBlock(builder));
String invokeBlockEval = invokeRealEval(true);
if (processOutputsMultivalued) {
if (vectorsUnsupported) {
builder.addStatement(invokeBlockEval);
} else {
processFunction.args.forEach(a -> a.resolveVectors(builder, invokeBlockEval));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
import static org.elasticsearch.compute.gen.Types.BLOCK;
import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY;
import static org.elasticsearch.compute.gen.Types.BYTES_REF;
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
import static org.elasticsearch.compute.gen.Types.ELEMENT_TYPE;
import static org.elasticsearch.compute.gen.Types.GROUPING_AGGREGATOR_EVALUATOR_CONTEXT;
Expand Down Expand Up @@ -93,6 +92,7 @@ public class GroupingAggregatorImplementer {
private final AggregationState aggState;
private final List<Argument> aggParams;
private final boolean hasOnlyBlockArguments;
private final boolean allArgumentsSupportVectors;

public GroupingAggregatorImplementer(
Elements elements,
Expand Down Expand Up @@ -128,6 +128,7 @@ public GroupingAggregatorImplementer(
}).filter(a -> a instanceof PositionArgument == false).toList();

this.hasOnlyBlockArguments = this.aggParams.stream().allMatch(a -> a instanceof BlockArgument);
this.allArgumentsSupportVectors = aggParams.stream().noneMatch(a -> a instanceof StandardArgument && a.dataType(false) == null);

this.createParameters = init.getParameters()
.stream()
Expand Down Expand Up @@ -204,7 +205,7 @@ private TypeSpec type() {
builder.addMethod(prepareProcessRawInputPage());
for (ClassName groupIdClass : GROUP_IDS_CLASSES) {
builder.addMethod(addRawInputLoop(groupIdClass, false));
if (hasOnlyBlockArguments == false) {
if (hasOnlyBlockArguments == false && allArgumentsSupportVectors) {
builder.addMethod(addRawInputLoop(groupIdClass, true));
}
builder.addMethod(addIntermediateInput(groupIdClass));
Expand Down Expand Up @@ -330,26 +331,31 @@ private MethodSpec prepareProcessRawInputPage() {
builder.addStatement("$T $L = page.getBlock(channels.get($L))", a.dataType(true), a.blockName(), i);
}

for (Argument a : aggParams) {
builder.addStatement(
"$T $L = $L.asVector()",
vectorType(a.elementType()),
(a instanceof BlockArgument) ? (a.name() + "Vector") : a.vectorName(),
a.blockName()
);
builder.beginControlFlow("if ($L == null)", (a instanceof BlockArgument) ? (a.name() + "Vector") : a.vectorName());
{
String groupIdTrackingStatement = "maybeEnableGroupIdTracking(seenGroupIds, "
+ aggParams.stream().map(arg -> arg.blockName()).collect(joining(", "))
+ ")";

if (allArgumentsSupportVectors) {

for (Argument a : aggParams) {
builder.addStatement(
"maybeEnableGroupIdTracking(seenGroupIds, "
+ aggParams.stream().map(arg -> arg.blockName()).collect(joining(", "))
+ ")"
"$T $L = $L.asVector()",
vectorType(a.elementType()),
(a instanceof BlockArgument) ? (a.name() + "Vector") : a.vectorName(),
a.blockName()
);
returnAddInput(builder, false);
builder.beginControlFlow("if ($L == null)", (a instanceof BlockArgument) ? (a.name() + "Vector") : a.vectorName());
{
builder.addStatement(groupIdTrackingStatement);
returnAddInput(builder, false);
}
builder.endControlFlow();
}
builder.endControlFlow();
returnAddInput(builder, true);
} else {
builder.addStatement(groupIdTrackingStatement);
returnAddInput(builder, false);
}

returnAddInput(builder, true);
return builder.build();
}

Expand Down Expand Up @@ -443,9 +449,9 @@ private MethodSpec addRawInputLoop(TypeName groupsType, boolean valuesAreVector)
);
}
for (Argument a : aggParams) {
if (a.isBytesRef()) {
// Add bytes_ref scratch var that will be used for bytes_ref blocks/vectors
builder.addStatement("$T $L = new $T()", BYTES_REF, a.scratchName(), BYTES_REF);
if (a.scratchType() != null) {
// Add scratch var that will be used for some blocks/vectors, e.g. for bytes_ref
builder.addStatement("$T $L = new $T()", a.scratchType(), a.scratchName(), a.scratchType());
}
}

Expand Down Expand Up @@ -645,11 +651,7 @@ private MethodSpec addIntermediateInput(TypeName groupsType) {
.collect(Collectors.joining(", "));
builder.addStatement("$T.combineIntermediate(state, positionOffset, groups, " + states + ")", declarationType);
} else {
if (intermediateState.stream()
.map(AggregatorImplementer.IntermediateStateDesc::elementType)
.anyMatch(n -> n.equals("BYTES_REF"))) {
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
}
intermediateState.forEach(state -> state.addScratchDeclaration(builder));
builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)");
{
if (groupsIsBlock) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ public static String getMethod(TypeName elementType) {
if (elementType.equals(TypeName.FLOAT)) {
return "getFloat";
}
if (elementType.equals(Types.EXPONENTIAL_HISTOGRAM)) {
return "getExponentialHistogram";
}
throw new IllegalArgumentException("unknown get method for [" + elementType + "]");
}

Expand Down
Loading