Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.ann;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
* Used on parameters on methods annotated with {@link Evaluator} or in
* {@link Aggregator} or {@link GroupingAggregator} to indicate an argument
* that is the position in a block.
*/
@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.SOURCE)
public @interface Position {
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import org.elasticsearch.compute.ann.IntermediateState;
import org.elasticsearch.compute.gen.Methods.TypeMatcher;
import org.elasticsearch.compute.gen.argument.Argument;
import org.elasticsearch.compute.gen.argument.ArrayArgument;
import org.elasticsearch.compute.gen.argument.BlockArgument;
import org.elasticsearch.compute.gen.argument.PositionArgument;
import org.elasticsearch.compute.gen.argument.StandardArgument;

import java.util.ArrayList;
Expand Down Expand Up @@ -110,12 +111,13 @@ public AggregatorImplementer(
requireName("combine"),
requireArgsStartsWith(requireType(aggState.declaredType()), requireAnyType("<aggregation input column type>"))
);
this.aggParams = combine.getParameters().stream().skip(1).map(v -> {
this.aggParams = combine.getParameters().stream().skip(1).flatMap(v -> {
Argument a = Argument.fromParameter(types, v);
return switch (a) {
case StandardArgument sa -> new AggregationParameter(sa.name(), sa.type(), false);
case ArrayArgument aa -> new AggregationParameter(aa.name(), aa.componentType(), true);
default -> throw new IllegalArgumentException("unsupported argument [" + a + "]");
case StandardArgument sa -> Stream.of(new AggregationParameter(sa.name(), sa.type(), false));
case BlockArgument ba -> Stream.of(new AggregationParameter(ba.name(), Types.elementType(ba.type()), true));
case PositionArgument pa -> Stream.of();
default -> throw new IllegalArgumentException("unsupported argument [" + declarationType + "][" + a + "]");
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the next PR I should be ale to remove this instanceof and just us the Argument directly. It'll need a few more methods, but should be fine.

};
}).toList();

Expand Down Expand Up @@ -435,22 +437,10 @@ private MethodSpec addRawBlock(boolean masked) {
if (aggParams.size() > 1) {
throw new IllegalArgumentException("array mode not supported for multiple args");
}
builder.addStatement("int start = $L.getFirstValueIndex(p)", aggParams.getFirst().blockName());
builder.addStatement("int end = start + $L.getValueCount(p)", aggParams.getFirst().blockName());
// TODO move this to the top of the loop
builder.addStatement(
"$L[] valuesArray = new $L[end - start]",
aggParams.getFirst().arrayType(),
aggParams.getFirst().arrayType()
warningsBlock(
builder,
() -> builder.addStatement("$T.combine(state, p, $L)", declarationType, aggParams.getFirst().blockName())
);
builder.beginControlFlow("for (int i = start; i < end; i++)");
builder.addStatement(
"valuesArray[i-start] = $L.get$L(i)",
aggParams.getFirst().blockName(),
capitalize(aggParams.getFirst().arrayType())
);
builder.endControlFlow();
combineRawInputForArray(builder, "valuesArray");
} else {
if (first == null && aggState.hasSeen()) {
builder.addStatement("state.seen(true)");
Expand Down Expand Up @@ -547,10 +537,6 @@ private void invokeCombineRawInput(TypeName returnType, MethodSpec.Builder build
builder.addStatement(pattern.toString(), params.toArray());
}

private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
warningsBlock(builder, () -> builder.addStatement("$T.combine(state, $L)", declarationType, arrayVariable));
}

private void warningsBlock(MethodSpec.Builder builder, Runnable block) {
if (warnExceptions.isEmpty() == false) {
builder.beginControlFlow("try");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.compute.gen;

import org.elasticsearch.compute.ann.Fixed;
import org.elasticsearch.compute.ann.Position;

import java.util.List;
import java.util.Set;
Expand Down Expand Up @@ -44,7 +45,8 @@ public Set<String> getSupportedAnnotationTypes() {
"org.elasticsearch.xpack.esql.expression.function.MapParam",
"org.elasticsearch.rest.ServerlessScope",
"org.elasticsearch.xcontent.ParserConstructor",
Fixed.class.getName()
Fixed.class.getName(),
Position.class.getName()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ private MethodSpec realEval(boolean blockStyle) {

StringBuilder pattern = new StringBuilder();
List<Object> args = new ArrayList<>();
pattern.append(processOutputsMultivalued ? "$T.$N(result, p, " : "$T.$N(");
pattern.append("$T.$N(");
args.add(declarationType);
args.add(processFunction.function.getSimpleName());
processFunction.args.stream().forEach(a -> {
Expand Down Expand Up @@ -312,10 +312,7 @@ static class ProcessFunction {
}
builderArg = ba;
} else if (arg instanceof BlockArgument) {
if (builderArg != null && args.size() == 2 && hasBlockType == false) {
args.clear();
hasBlockType = true;
}
hasBlockType = true;
}
args.add(arg);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import org.elasticsearch.compute.gen.AggregatorImplementer.AggregationParameter;
import org.elasticsearch.compute.gen.AggregatorImplementer.AggregationState;
import org.elasticsearch.compute.gen.argument.Argument;
import org.elasticsearch.compute.gen.argument.ArrayArgument;
import org.elasticsearch.compute.gen.argument.BlockArgument;
import org.elasticsearch.compute.gen.argument.PositionArgument;
import org.elasticsearch.compute.gen.argument.StandardArgument;

import java.util.ArrayList;
Expand All @@ -37,7 +38,6 @@
import javax.lang.model.util.Elements;

import static java.util.stream.Collectors.joining;
import static org.elasticsearch.compute.gen.AggregatorImplementer.capitalize;
import static org.elasticsearch.compute.gen.Methods.optionalStaticMethod;
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
import static org.elasticsearch.compute.gen.Methods.requireAnyType;
Expand Down Expand Up @@ -118,12 +118,13 @@ public GroupingAggregatorImplementer(
requireName("combine"),
combineArgs(aggState)
);
this.aggParams = combine.getParameters().stream().skip(aggState.declaredType().isPrimitive() ? 1 : 2).map(v -> {
this.aggParams = combine.getParameters().stream().skip(aggState.declaredType().isPrimitive() ? 1 : 2).flatMap(v -> {
Argument a = Argument.fromParameter(types, v);
return switch (a) {
case StandardArgument sa -> new AggregationParameter(sa.name(), sa.type(), false);
case ArrayArgument aa -> new AggregationParameter(aa.name(), aa.componentType(), true);
default -> throw new IllegalArgumentException("unsupported argument [" + a + "]");
case StandardArgument sa -> Stream.of(new AggregationParameter(sa.name(), sa.type(), false));
case BlockArgument ba -> Stream.of(new AggregationParameter(ba.name(), Types.elementType(ba.type()), true));
case PositionArgument pa -> Stream.of();
default -> throw new IllegalArgumentException("unsupported argument [" + declarationType + "][" + a + "]");
};
}).toList();

Expand Down Expand Up @@ -476,21 +477,14 @@ private MethodSpec addRawInputLoop(TypeName groupsType, boolean valuesAreVector)
if (aggParams.size() > 1) {
throw new IllegalArgumentException("array mode not supported for multiple args");
}
String arrayType = aggParams.getFirst().type().toString().replace("[]", "");
builder.addStatement("int valuesStart = $L.getFirstValueIndex(valuesPosition)", aggParams.getFirst().blockName());
builder.addStatement(
"int valuesEnd = valuesStart + $L.getValueCount(valuesPosition)",
aggParams.getFirst().blockName()
);
builder.addStatement("$L[] valuesArray = new $L[valuesEnd - valuesStart]", arrayType, arrayType);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

builder.beginControlFlow("for (int v = valuesStart; v < valuesEnd; v++)");
builder.addStatement(
"valuesArray[v-valuesStart] = $L.get$L(v)",
aggParams.getFirst().blockName(),
capitalize(aggParams.getFirst().arrayType())
warningsBlock(
builder,
() -> builder.addStatement(
"$T.combine(state, groupId, valuesPosition, $L)",
declarationType,
aggParams.getFirst().blockName()
)
);
builder.endControlFlow();
combineRawInputForArray(builder, "valuesArray");
} else {
for (AggregationParameter p : aggParams) {
builder.addStatement("int $L = $L.getFirstValueIndex(valuesPosition)", p.startName(), p.blockName());
Expand Down Expand Up @@ -536,6 +530,9 @@ private void invokeCombineRawInput(TypeName returnType, MethodSpec.Builder build
pattern.append("$T.combine(state, groupId");
params.add(declarationType);
}
if (aggParams.getFirst().isArray()) {
pattern.append(", p");
}
for (AggregationParameter p : aggParams) {
pattern.append(", $L");
params.add(p.valueName());
Expand All @@ -547,10 +544,6 @@ private void invokeCombineRawInput(TypeName returnType, MethodSpec.Builder build
builder.addStatement(pattern.toString(), params.toArray());
}

private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable));
}

private boolean shouldWrapAddInput(boolean valuesAreVector) {
return optionalStaticMethod(
declarationType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.squareup.javapoet.TypeSpec;

import org.elasticsearch.compute.ann.Fixed;
import org.elasticsearch.compute.ann.Position;
import org.elasticsearch.compute.gen.Types;

import java.util.List;
Expand Down Expand Up @@ -41,6 +42,12 @@ static Argument fromParameter(javax.lang.model.util.Types types, VariableElement
Types.extendsSuper(types, v.asType(), "org.elasticsearch.core.Releasable")
);
}

Position position = v.getAnnotation(Position.class);
if (position != null) {
return new PositionArgument();
}

if (type instanceof ClassName c
&& c.simpleName().equals("Builder")
&& c.enclosingClassName() != null
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.gen.argument;

import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;

import java.util.List;

/**
* The position in a block.
*/
public record PositionArgument() implements Argument {
@Override
public TypeName dataType(boolean blockStyle) {
return TypeName.INT;
}

@Override
public String paramName(boolean blockStyle) {
// No need to pass it
return null;
}

@Override
public void declareField(TypeSpec.Builder builder) {
// Nothing to do
}

@Override
public void declareFactoryField(TypeSpec.Builder builder) {
// Nothing to do
}

@Override
public void implementCtor(MethodSpec.Builder builder) {
// Nothing to do
}

@Override
public void implementFactoryCtor(MethodSpec.Builder builder) {
// Nothing to do
}

@Override
public String factoryInvocation(MethodSpec.Builder factoryMethodBuilder) {
return null;
}

@Override
public void evalToBlock(MethodSpec.Builder builder) {
// nothing to do
}

@Override
public void closeEvalToBlock(MethodSpec.Builder builder) {
// nothing to do
}

@Override
public void resolveVectors(MethodSpec.Builder builder, String invokeBlockEval) {
// nothing to do
}

@Override
public void createScratch(MethodSpec.Builder builder) {
// nothing to do
}

@Override
public void skipNull(MethodSpec.Builder builder) {
// nothing to do
}

@Override
public void allBlocksAreNull(MethodSpec.Builder builder) {
// nothing to do
}

@Override
public void read(MethodSpec.Builder builder, boolean blockStyle) {
// nothing to do
}

@Override
public void buildInvocation(StringBuilder pattern, List<Object> args, boolean blockStyle) {
pattern.append("p");
}

@Override
public void buildToStringInvocation(StringBuilder pattern, List<Object> args, String prefix) {
// nothing to do
}

@Override
public String closeInvocation() {
return null;
}

@Override
public void sumBaseRamBytesUsed(MethodSpec.Builder builder) {}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading