Skip to content

Commit

Permalink
ESQL: Sum, Min, Max and Avg of constants (#105454)
Browse files Browse the repository at this point in the history
Allow expressions like
... | STATS sum([1, -9]), sum(null), min(21.0*3), avg([1,2,3])
by substituting sum(const) by mv_sum(const)*count(*) and min(const) by
mv_min(const) (and similarly for max and avg).
  • Loading branch information
alex-spies committed Mar 26, 2024
1 parent b39b373 commit 829ea4d
Show file tree
Hide file tree
Showing 10 changed files with 375 additions and 32 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/105454.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 105454
summary: "ESQL: Sum of constants"
area: ES|QL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -1226,3 +1226,99 @@ FROM employees
vals:l
183
;

sumOfConst#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS s1 = sum(1), s2point1 = sum(2.1), s_mv = sum([-1, 0, 3]) * 3, s_null = sum(null), rows = count(*)
;

s1:l | s2point1:d | s_mv:l | s_null:d | rows:l
100 | 210.0 | 600 | null | 100
;

sumOfConstGrouped#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS s2point1 = round(sum(2.1), 1), s_mv = sum([-1, 0, 3]), rows = count(*) by languages
| SORT languages
;

s2point1:d | s_mv:l | rows:l | languages:i
31.5 | 30 | 15 | 1
39.9 | 38 | 19 | 2
35.7 | 34 | 17 | 3
37.8 | 36 | 18 | 4
44.1 | 42 | 21 | 5
21.0 | 20 | 10 | null
;

avgOfConst#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS s1 = avg(1), s_mv = avg([-1, 0, 3]) * 3, s_null = avg(null)
;

s1:d | s_mv:d | s_null:d
1.0 | 2.0 | null
;

avgOfConstGrouped#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS s2point1 = avg(2.1), s_mv = avg([-1, 0, 3]) * 3 by languages
| SORT languages
;

s2point1:d | s_mv:d | languages:i
2.1 | 2.0 | 1
2.1 | 2.0 | 2
2.1 | 2.0 | 3
2.1 | 2.0 | 4
2.1 | 2.0 | 5
2.1 | 2.0 | null
;

minOfConst#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS s1 = min(1), s_mv = min([-1, 0, 3]), s_null = min(null)
;

s1:i | s_mv:i | s_null:null
1 | -1 | null
;

minOfConstGrouped#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS s2point1 = min(2.1), s_mv = min([-1, 0, 3]) by languages
| SORT languages
;

s2point1:d | s_mv:i | languages:i
2.1 | -1 | 1
2.1 | -1 | 2
2.1 | -1 | 3
2.1 | -1 | 4
2.1 | -1 | 5
2.1 | -1 | null
;

maxOfConst#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS s1 = max(1), s_mv = max([-1, 0, 3]), s_null = max(null)
;

s1:i | s_mv:i | s_null:null
1 | 3 | null
;

maxOfConstGrouped#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS s2point1 = max(2.1), s_mv = max([-1, 0, 3]) by languages
| SORT languages
;

s2point1:d | s_mv:i | languages:i
2.1 | 3 | 1
2.1 | 3 | 2
2.1 | 3 | 3
2.1 | 3 | 4
2.1 | 3 | 5
2.1 | 3 | null
;
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,8 @@
*/
public interface SurrogateExpression {

/**
* Returns the expression to be replaced by or {@code null} if this cannot be replaced.
*/
Expression surrogate();
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
Expand Down Expand Up @@ -60,6 +61,7 @@ public Avg replaceChildren(List<Expression> newChildren) {
public Expression surrogate() {
var s = source();
var field = field();
return new Div(s, new Sum(s, field), new Count(s, field), dataType());

return field().foldable() ? new MvAvg(s, field) : new Div(s, new Sum(s, field), new Count(s, field), dataType());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@
import org.elasticsearch.compute.aggregation.MaxDoubleAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.MaxIntAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier;
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.tree.NodeInfo;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;

import java.util.List;

public class Max extends NumericAggregate {
public class Max extends NumericAggregate implements SurrogateExpression {

@FunctionInfo(returnType = { "double", "integer", "long" }, description = "The maximum value of a numeric field.", isAggregation = true)
public Max(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
Expand Down Expand Up @@ -61,4 +63,9 @@ protected AggregatorFunctionSupplier intSupplier(List<Integer> inputChannels) {
protected AggregatorFunctionSupplier doubleSupplier(List<Integer> inputChannels) {
return new MaxDoubleAggregatorFunctionSupplier(inputChannels);
}

@Override
public Expression surrogate() {
return field().foldable() ? new MvMax(source(), field()) : null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@
import org.elasticsearch.compute.aggregation.MinDoubleAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.MinIntAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.MinLongAggregatorFunctionSupplier;
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.tree.NodeInfo;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;

import java.util.List;

public class Min extends NumericAggregate {
public class Min extends NumericAggregate implements SurrogateExpression {

@FunctionInfo(returnType = { "double", "integer", "long" }, description = "The minimum value of a numeric field.", isAggregation = true)
public Min(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
Expand Down Expand Up @@ -61,4 +63,9 @@ protected AggregatorFunctionSupplier intSupplier(List<Integer> inputChannels) {
protected AggregatorFunctionSupplier doubleSupplier(List<Integer> inputChannels) {
return new MinDoubleAggregatorFunctionSupplier(inputChannels);
}

@Override
public Expression surrogate() {
return field().foldable() ? new MvMin(source(), field()) : null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,18 @@
import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.SumIntAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier;
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.Literal;
import org.elasticsearch.xpack.ql.tree.NodeInfo;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypes;
import org.elasticsearch.xpack.ql.util.StringUtils;

import java.util.List;

Expand All @@ -26,7 +32,7 @@
/**
* Sum all values of a field in matching documents.
*/
public class Sum extends NumericAggregate {
public class Sum extends NumericAggregate implements SurrogateExpression {

@FunctionInfo(returnType = "long", description = "The sum of a numeric field.", isAggregation = true)
public Sum(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
Expand Down Expand Up @@ -63,4 +69,15 @@ protected AggregatorFunctionSupplier intSupplier(List<Integer> inputChannels) {
protected AggregatorFunctionSupplier doubleSupplier(List<Integer> inputChannels) {
return new SumDoubleAggregatorFunctionSupplier(inputChannels);
}

@Override
public Expression surrogate() {
var s = source();
var field = field();

// SUM(const) is equivalent to MV_SUM(const)*COUNT(*).
return field.foldable()
? new Mul(s, new MvSum(s, field), new Count(s, new Literal(s, StringUtils.WILDCARD, DataTypes.KEYWORD)))
: null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.elasticsearch.xpack.ql.expression.Attribute;
import org.elasticsearch.xpack.ql.expression.AttributeMap;
import org.elasticsearch.xpack.ql.expression.AttributeSet;
import org.elasticsearch.xpack.ql.expression.EmptyAttribute;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.ExpressionSet;
import org.elasticsearch.xpack.ql.expression.Expressions;
Expand Down Expand Up @@ -107,6 +108,23 @@ protected List<Batch<LogicalPlan>> batches() {
return rules();
}

protected static Batch<LogicalPlan> substitutions() {
return new Batch<>(
"Substitutions",
Limiter.ONCE,
// first extract nested aggs top-level - this simplifies the rest of the rules
new ReplaceStatsAggExpressionWithEval(),
// second extract nested aggs inside of them
new ReplaceStatsNestedExpressionWithEval(),
// lastly replace surrogate functions
new SubstituteSurrogates(),
new ReplaceRegexMatch(),
new ReplaceAliasingEvalWithProject(),
new SkipQueryOnEmptyMappings()
// new NormalizeAggregate(), - waits on https://github.com/elastic/elasticsearch/issues/100634
);
}

protected static Batch<LogicalPlan> operators() {
return new Batch<>(
"Operator Optimization",
Expand Down Expand Up @@ -150,26 +168,11 @@ protected static Batch<LogicalPlan> cleanup() {
}

protected static List<Batch<LogicalPlan>> rules() {
var substitutions = new Batch<>(
"Substitutions",
Limiter.ONCE,
// first extract nested aggs top-level - this simplifies the rest of the rules
new ReplaceStatsAggExpressionWithEval(),
// second extract nested aggs inside of them
new ReplaceStatsNestedExpressionWithEval(),
// lastly replace surrogate functions
new SubstituteSurrogates(),
new ReplaceRegexMatch(),
new ReplaceAliasingEvalWithProject(),
new SkipQueryOnEmptyMappings()
// new NormalizeAggregate(), - waits on https://github.com/elastic/elasticsearch/issues/100634
);

var skip = new Batch<>("Skip Compute", new SkipQueryOnLimitZero());
var defaultTopN = new Batch<>("Add default TopN", new AddDefaultTopN());
var label = new Batch<>("Set as Optimized", Limiter.ONCE, new SetAsOptimized());

return asList(substitutions, operators(), skip, cleanup(), defaultTopN, label);
return asList(substitutions(), operators(), skip, cleanup(), defaultTopN, label);
}

// TODO: currently this rule only works for aggregate functions (AVG)
Expand All @@ -191,16 +194,18 @@ protected LogicalPlan rule(Aggregate aggregate) {

// first pass to check existing aggregates (to avoid duplication and alias waste)
for (NamedExpression agg : aggs) {
if (Alias.unwrap(agg) instanceof AggregateFunction af && af instanceof SurrogateExpression == false) {
aggFuncToAttr.put(af, agg.toAttribute());
if (Alias.unwrap(agg) instanceof AggregateFunction af) {
if ((af instanceof SurrogateExpression se && se.surrogate() != null) == false) {
aggFuncToAttr.put(af, agg.toAttribute());
}
}
}

int[] counter = new int[] { 0 };
// 0. check list of surrogate expressions
for (NamedExpression agg : aggs) {
Expression e = Alias.unwrap(agg);
if (e instanceof SurrogateExpression sf) {
if (e instanceof SurrogateExpression sf && sf.surrogate() != null) {
changed = true;
Expression s = sf.surrogate();

Expand Down Expand Up @@ -240,9 +245,22 @@ protected LogicalPlan rule(Aggregate aggregate) {
LogicalPlan plan = aggregate;
if (changed) {
var source = aggregate.source();
plan = new Aggregate(aggregate.source(), aggregate.child(), aggregate.groupings(), newAggs);
if (newAggs.isEmpty() == false) {
plan = new Aggregate(source, aggregate.child(), aggregate.groupings(), newAggs);
} else {
// All aggs actually have been surrogates for (foldable) expressions, e.g.
// \_Aggregate[[],[AVG([1, 2][INTEGER]) AS s]]
// Replace by a local relation with one row, followed by an eval, e.g.
// \_Eval[[MVAVG([1, 2][INTEGER]) AS s]]
// \_LocalRelation[[{e}#21],[ConstantNullBlock[positions=1]]]
plan = new LocalRelation(
source,
List.of(new EmptyAttribute(source)),
LocalSupplier.of(new Block[] { BlockUtils.constantBlock(PlannerUtils.NON_BREAKING_BLOCK_FACTORY, null, 1) })
);
}
// 5. force the initial projection in place
if (transientEval.size() > 0) {
if (transientEval.isEmpty() == false) {
plan = new Eval(source, plan, transientEval);
// project away transient fields and re-enforce the original order using references (not copies) to the original aggs
// this works since the replaced aliases have their nameId copied to avoid having to update all references (which has
Expand Down Expand Up @@ -500,6 +518,8 @@ public LogicalPlan apply(LogicalPlan plan) {

plan = plan.transformUp(p -> {
// Apply the replacement inside Filter and Eval (which shouldn't make a difference)
// TODO: also allow aggregates once aggs on constants are supported.
// C.f. https://github.com/elastic/elasticsearch/issues/100634
if (p instanceof Filter || p instanceof Eval) {
p = p.transformExpressionsOnly(ReferenceAttribute.class, replaceReference);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinct;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Max;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Median;
import org.elasticsearch.xpack.esql.expression.function.aggregate.MedianAbsoluteDeviation;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.esql.expression.function.aggregate.NumericAggregate;
Expand Down Expand Up @@ -43,7 +41,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand All @@ -55,12 +52,11 @@ public class AggregateMapper {
static final List<String> NUMERIC = List.of("Int", "Long", "Double");
static final List<String> SPATIAL = List.of("GeoPoint", "CartesianPoint");

/** List of all ESQL agg functions. */
/** List of all mappable ESQL agg functions (excludes surrogates like AVG = SUM/COUNT). */
static final List<? extends Class<? extends Function>> AGG_FUNCTIONS = List.of(
Count.class,
CountDistinct.class,
Max.class,
Median.class,
MedianAbsoluteDeviation.class,
Min.class,
Percentile.class,
Expand All @@ -79,7 +75,7 @@ record AggDef(Class<?> aggClazz, String type, String extra, boolean grouping) {}
private final HashMap<Expression, List<? extends NamedExpression>> cache = new HashMap<>();

AggregateMapper() {
this(AGG_FUNCTIONS.stream().filter(Predicate.not(SurrogateExpression.class::isAssignableFrom)).toList());
this(AGG_FUNCTIONS);
}

AggregateMapper(List<? extends Class<? extends Function>> aggregateFunctionClasses) {
Expand Down

0 comments on commit 829ea4d

Please sign in to comment.