Skip to content

Commit

Permalink
Basic conceptual view behind InlineStats
Browse files Browse the repository at this point in the history
Introduce InlineStats as InlineAggregate (to reuse the existing
 terminology), which acts only as syntactic sugar; it gets optimized to a
 join (the type depends on whether or not there are groups).
The logical optimizer tests contain a series of basic tests showing
 what the plan looks like - note these are failing.
  • Loading branch information
costin committed Apr 19, 2024
1 parent 8efe77b commit 0ab07e1
Show file tree
Hide file tree
Showing 8 changed files with 485 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,11 @@
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.Grok;
import org.elasticsearch.xpack.esql.plan.logical.InlineAggregate;
import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinType;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.DissectExec;
Expand Down Expand Up @@ -297,7 +299,7 @@ public static List<PlanNameRegistry.Entry> namedTypeEntries() {
of(LogicalPlan.class, EsqlProject.class, PlanNamedTypes::writeEsqlProject, PlanNamedTypes::readEsqlProject),
of(LogicalPlan.class, Filter.class, PlanNamedTypes::writeFilter, PlanNamedTypes::readFilter),
of(LogicalPlan.class, Grok.class, PlanNamedTypes::writeGrok, PlanNamedTypes::readGrok),
of(LogicalPlan.class, InlineAggregate.class, PlanNamedTypes::writeInlineAggregate, PlanNamedTypes::readInlineAggregate),
of(LogicalPlan.class, Join.class, PlanNamedTypes::writeJoin, PlanNamedTypes::readJoin),
of(LogicalPlan.class, Limit.class, PlanNamedTypes::writeLimit, PlanNamedTypes::readLimit),
of(LogicalPlan.class, MvExpand.class, PlanNamedTypes::writeMvExpand, PlanNamedTypes::readMvExpand),
of(LogicalPlan.class, OrderBy.class, PlanNamedTypes::writeOrderBy, PlanNamedTypes::readOrderBy),
Expand Down Expand Up @@ -920,20 +922,16 @@ static void writeGrok(PlanStreamOutput out, Grok grok) throws IOException {
writeAttributes(out, grok.extractedFields());
}

static InlineAggregate readInlineAggregate(PlanStreamInput in) throws IOException {
return new InlineAggregate(
in.readSource(),
in.readLogicalPlanNode(),
in.readCollectionAsList(readerFromPlanReader(PlanStreamInput::readExpression)),
readNamedExpressions(in)
);
// Deserializes a Join. Fields are consumed in this order: source, left plan, right plan, join type, condition.
static Join readJoin(PlanStreamInput in) throws IOException {
    var source = in.readSource();
    var leftPlan = in.readLogicalPlanNode();
    var rightPlan = in.readLogicalPlanNode();
    var type = readJoinType(in);
    var joinCondition = in.readExpression();
    return new Join(source, leftPlan, rightPlan, type, joinCondition);
}

static void writeInlineAggregate(PlanStreamOutput out, InlineAggregate aggregate) throws IOException {
static void writeJoin(PlanStreamOutput out, Join join) throws IOException {
out.writeNoSource();
out.writeLogicalPlanNode(aggregate.child());
out.writeCollection(aggregate.groupings(), writerFromPlanWriter(PlanStreamOutput::writeExpression));
writeNamedExpressions(out, aggregate.aggregates());
out.writeLogicalPlanNode(join.left());
out.writeLogicalPlanNode(join.right());
writeJoinType(out, join.type());
out.writeExpression(join.condition());
}

static Limit readLimit(PlanStreamInput in) throws IOException {
Expand Down Expand Up @@ -1941,6 +1939,30 @@ static void writeDissectParser(PlanStreamOutput out, Parser dissectParser) throw
out.writeString(dissectParser.appendSeparator());
}

/**
 * Reads a {@link JoinType} from the stream.
 * <p>
 * The encoding starts with a one-byte tag:
 * <ul>
 *   <li>{@code 0} - a {@link JoinTypes.CoreJoinType} enum value follows;</li>
 *   <li>{@code 1} - a {@link JoinTypes.CoreJoinType} enum value followed by a list of attributes, which together
 *       form a {@link JoinTypes.UsingJoinType}.</li>
 * </ul>
 *
 * @param in the stream to read from
 * @return the decoded join type
 * @throws IOException if reading from the stream fails
 * @throws IllegalArgumentException if the tag byte is neither {@code 0} nor {@code 1}
 */
static JoinType readJoinType(PlanStreamInput in) throws IOException {
    // keep the tag in a local so it can be reported if it is not recognised
    byte tag = in.readByte();
    return switch (tag) {
        case 0 -> in.readEnum(JoinTypes.CoreJoinType.class);
        case 1 -> new JoinTypes.UsingJoinType(in.readEnum(JoinTypes.CoreJoinType.class), readAttributes(in));
        default -> throw new IllegalArgumentException("Unknown join type: " + tag);
    };
}

/**
 * Writes a {@link JoinType} to the stream as a one-byte tag followed by a type-specific payload.
 * <p>
 * A {@link JoinTypes.CoreJoinType} is written as tag {@code 0} plus its enum value. A
 * {@link JoinTypes.UsingJoinType} is written as tag {@code 1}, its core join enum value and its columns.
 *
 * @param out the stream to write to
 * @param joinType the join type to encode
 * @throws IOException if writing to the stream fails
 * @throws IllegalArgumentException if {@code joinType} is neither of the two supported kinds
 */
static void writeJoinType(PlanStreamOutput out, JoinType joinType) throws IOException {
    if (joinType instanceof JoinTypes.CoreJoinType core) {
        out.writeByte((byte) 0);
        out.writeEnum(core);
        return;
    }
    if (joinType instanceof JoinTypes.UsingJoinType using) {
        out.writeByte((byte) 1);
        out.writeEnum(using.coreJoin());
        writeAttributes(out, using.columns());
        return;
    }
    // anything else has no wire representation
    throw new IllegalArgumentException("Unknown join type: " + joinType);
}

// Deserializes a Log. Fields are consumed in this order: source, one expression, then one optional named expression.
static Log readLog(PlanStreamInput in) throws IOException {
    var source = in.readSource();
    var firstArgument = in.readExpression();
    var optionalSecondArgument = in.readOptionalNamed(Expression.class);
    return new Log(source, firstArgument, optionalSecondArgument);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.InlineAggregate;
import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
import org.elasticsearch.xpack.esql.plan.logical.RegexExtract;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinType;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
import org.elasticsearch.xpack.esql.plan.logical.local.LocalSupplier;
import org.elasticsearch.xpack.esql.planner.PlannerUtils;
import org.elasticsearch.xpack.esql.type.EsqlDataTypes;
import org.elasticsearch.xpack.ql.analyzer.AnalyzerRules;
import org.elasticsearch.xpack.ql.common.Failures;
import org.elasticsearch.xpack.ql.expression.Alias;
import org.elasticsearch.xpack.ql.expression.Attribute;
Expand Down Expand Up @@ -116,6 +119,7 @@ protected static Batch<LogicalPlan> substitutions() {
return new Batch<>(
"Substitutions",
Limiter.ONCE,
new ReplaceInlineAggsWithJoin(),
new RemoveStatsOverride(),
// first extract nested expressions inside aggs
new ReplaceStatsNestedExpressionWithEval(),
Expand Down Expand Up @@ -1305,7 +1309,11 @@ protected Expression regexToEquals(RegexMatch<?> regexMatch, Literal literal) {
static class ReplaceStatsNestedExpressionWithEval extends OptimizerRules.OptimizerRule<Aggregate> {

@Override
protected LogicalPlan rule(Aggregate aggregate) {
protected Aggregate rule(Aggregate aggregate) {
return replaceNestedExpressions(aggregate);
}

static Aggregate replaceNestedExpressions(Aggregate aggregate) {
List<Alias> evals = new ArrayList<>();
Map<String, Attribute> evalNames = new HashMap<>();
List<Expression> newGroupings = new ArrayList<>(aggregate.groupings());
Expand Down Expand Up @@ -1616,19 +1624,18 @@ private LogicalPlan rule(Eval eval) {
* becomes
* STATS max($x + 1) BY $x = a + b
*/
private static class RemoveStatsOverride extends AnalyzerRules.AnalyzerRule<Aggregate> {
private static class RemoveStatsOverride extends OptimizerRules.OptimizerRule<Aggregate> {

@Override
protected boolean skipResolved() {
return false;
RemoveStatsOverride() {
super(TransformDirection.UP);
}

@Override
protected LogicalPlan rule(Aggregate agg) {
return agg.resolved() ? removeAggDuplicates(agg) : agg;
protected Aggregate rule(Aggregate agg) {
return removeAggDuplicates(agg);
}

private static Aggregate removeAggDuplicates(Aggregate agg) {
static Aggregate removeAggDuplicates(Aggregate agg) {
var groupings = agg.groupings();
var aggregates = agg.aggregates();

Expand All @@ -1655,6 +1662,64 @@ private static <T extends Expression> List<T> removeDuplicateNames(List<T> list)
}
}

/**
 * Replace inline aggregations with a logical join, mainly to reuse the join optimizations across
 * different commands.
 * An inlinestats with no grouping is replaced by a cross join:
 * FROM index | INLINESTATS mx = max(x)
 * becomes
 * FROM index | CROSS JOIN [FROM index | STATS mx = max(x)]
 * If grouping is present, an inner join on the grouping keys is used instead (a left join would be equivalent
 * since this is a self equi-join):
 * FROM index | INLINESTATS mx = max(x) by a, b
 * becomes
 * FROM index | INNER JOIN [FROM index | STATS mx = max(x) by a, b] ON a, b
 * Expressions specified in the grouping are extracted as EVALs before grouping so they can be used later for joining
 * the two sides:
 * FROM index | INLINESTATS mx = max(x) by g = a + b
 * becomes
 * FROM index | EVAL g = a + b | INNER JOIN [FROM index | STATS mx = max(x) by g] ON g
 */
private static class ReplaceInlineAggsWithJoin extends OptimizerRules.OptimizerRule<InlineAggregate> {

    ReplaceInlineAggsWithJoin() {
        super(TransformDirection.UP);
    }

    /**
     * Rewrites an {@link InlineAggregate} as a {@link Join} whose right side is a regular {@link Aggregate}
     * and whose left side is that aggregate's child.
     *
     * @param inlineAgg the inline aggregation to rewrite
     * @return the join that replaces it
     */
    @Override
    protected LogicalPlan rule(InlineAggregate inlineAgg) {
        // The groupings double as join keys, so normalise them on a plain Aggregate first.
        Aggregate plain = new Aggregate(
            inlineAgg.source(),
            inlineAgg.child(),
            new ArrayList<>(inlineAgg.groupings()),
            inlineAgg.aggregates()
        );
        // 1. drop duplicated groups/aggregates
        Aggregate deduplicated = RemoveStatsOverride.removeAggDuplicates(plain);
        // 2. handle nested expressions (optimising the aggregates themselves is left to the dedicated rule)
        Aggregate right = ReplaceStatsNestedExpressionWithEval.replaceNestedExpressions(deduplicated);

        // The left side is whatever the aggregate now sits on, so both sides see the same join keys
        // (including any evals introduced for the aggregate).
        return new Join(inlineAgg.source(), right.child(), right, joinTypeFor(right.groupings()), Literal.TRUE);
    }

    /**
     * Picks the join type from the grouping list: CROSS when there is nothing to group by, otherwise an INNER
     * join using the grouping attributes (a left join would be equivalent here since this is a self equi-join).
     */
    private static JoinType joinTypeFor(List<Expression> groupings) {
        if (groupings.isEmpty()) {
            return JoinTypes.CROSS;
        }
        List<Attribute> keys = new ArrayList<>(groupings.size());
        for (Expression grouping : groupings) {
            keys.add(Expressions.attribute(grouping));
        }
        return new JoinTypes.UsingJoinType(JoinTypes.INNER, keys);
    }
}

private abstract static class ParameterizedOptimizerRule<SubPlan extends LogicalPlan, P> extends ParameterizedRule<
SubPlan,
LogicalPlan,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.core.Tuple;
import org.elasticsearch.dissect.DissectException;
import org.elasticsearch.dissect.DissectParser;
import org.elasticsearch.monitor.os.OsStats;
import org.elasticsearch.xpack.esql.VerificationException;
import org.elasticsearch.xpack.esql.expression.UnresolvedNamePattern;
import org.elasticsearch.xpack.esql.parser.EsqlBaseParser.MetadataOptionContext;
Expand Down Expand Up @@ -273,7 +274,7 @@ public PlanFactory visitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandCont
List<NamedExpression> aggregates = new ArrayList<>(visitFields(ctx.stats));
List<NamedExpression> groupings = visitGrouping(ctx.grouping);
aggregates.addAll(groupings);
return input -> new InlineAggregate(source(ctx), input, new ArrayList<>(groupings), aggregates);
return input -> new InlineAggregate(source(ctx), input, groupings, aggregates);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,27 @@
import org.elasticsearch.xpack.esql.expression.NamedExpressions;
import org.elasticsearch.xpack.ql.capabilities.Resolvables;
import org.elasticsearch.xpack.ql.expression.Attribute;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.Expressions;
import org.elasticsearch.xpack.ql.expression.NamedExpression;
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.ql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.ql.tree.NodeInfo;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.util.CollectionUtils;

import java.util.Collection;
import java.util.List;
import java.util.Objects;

public class InlineAggregate extends UnaryPlan {

private final List<Expression> groupings;
private final List<? extends NamedExpression> groupings;
private final List<? extends NamedExpression> aggregates;
private List<Attribute> lazyOutput;

public InlineAggregate(Source source, LogicalPlan child, List<Expression> groupings, List<? extends NamedExpression> aggregates) {
public InlineAggregate(
Source source,
LogicalPlan child,
List<? extends NamedExpression> groupings,
List<? extends NamedExpression> aggregates
) {
super(source, child);
this.groupings = groupings;
this.aggregates = aggregates;
Expand All @@ -45,7 +46,7 @@ public InlineAggregate replaceChild(LogicalPlan newChild) {
return new InlineAggregate(source(), newChild, groupings, aggregates);
}

public List<Expression> groupings() {
public List<? extends NamedExpression> groupings() {
return groupings;
}

Expand Down
Loading

0 comments on commit 0ab07e1

Please sign in to comment.