Skip to content

Commit

Permalink
Fix small bug that caused replacement of expressions inside aggregations
Browse files Browse the repository at this point in the history
 to be skipped despite being applied
Improved Verifier to not repeat error messages in the case of Aggregates
Removed the verification heuristic that reported missing columns as invalid function declarations, as it
 was too broad
  • Loading branch information
costin committed Apr 4, 2024
1 parent c9ae0ff commit 91b9134
Show file tree
Hide file tree
Showing 6 changed files with 323 additions and 156 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,58 @@ e:i | l:i
4 | 3
;

nestedAggsOverGroupingExpressionWithoutAlias#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS e = max(languages + emp_no) + 1 by languages + emp_no
| SORT e
| LIMIT 3
;

e:i | languages + emp_no:i
10004 | 10003
10007 | 10006
10008 | 10007
;

nestedAggsOverGroupingExpressionMultiGroupWithoutAlias#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS e = max(languages + emp_no + 10) + 1 by languages + emp_no, emp_no % 3
| SORT e
| LIMIT 3
;

e:i | languages + emp_no:i | emp_no % 3:i
10014 | 10003 | 2
10017 | 10006 | 0
10018 | 10007 | 1
;

nestedAggsOverGroupingExpressionWithAlias#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS e = max(languages + emp_no + 10) + 1 by languages + emp_no
| SORT e
| LIMIT 3
;

e:i | languages + emp_no:i
10014 | 10003
10017 | 10006
10018 | 10007
;

nestedAggsOverGroupingExpressionWithAlias#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS e = max(a), f = min(a), g = count(a) + 1 by a = languages + emp_no
| SORT a
| LIMIT 3
;

e:i | f:i | g:l | a:i
10003 | 10003 | 2 | 10003
10006 | 10006 | 2 | 10006
10007 | 10007 | 3 | 10007
;

nestedAggsOverGroupingTwiceWithAlias#[skip:-8.12.99,reason:supported in 8.13]
FROM employees
| STATS vals = COUNT() BY x = emp_no, x = languages
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.elasticsearch.xpack.ql.expression.Expressions;
import org.elasticsearch.xpack.ql.expression.NamedExpression;
import org.elasticsearch.xpack.ql.expression.TypeResolutions;
import org.elasticsearch.xpack.ql.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.ql.expression.predicate.BinaryOperator;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison;
Expand All @@ -47,6 +46,7 @@
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Stream;

import static org.elasticsearch.xpack.ql.analyzer.VerifierChecks.checkFilterConditionType;
Expand Down Expand Up @@ -89,16 +89,8 @@ else if (p.resolved()) {
p.forEachExpressionUp(Alias.class, a -> aliases.put(a.toAttribute(), a.child()));
return;
}
// handle aggregate first to disambiguate between missing fields or incorrect function declaration
if (p instanceof Aggregate aggregate) {
for (NamedExpression agg : aggregate.aggregates()) {
var child = Alias.unwrap(agg);
if (child instanceof UnresolvedAttribute) {
failures.add(fail(child, "invalid stats declaration; [{}] is not an aggregate function", child.sourceText()));
}
}
}
p.forEachExpression(e -> {

Consumer<Expression> unresolvedExpressions = e -> {
// everything is fine, skip expression
if (e.resolved()) {
return;
Expand All @@ -120,7 +112,20 @@ else if (p.resolved()) {
failures.add(fail(ae, ae.typeResolved().message()));
}
});
});
};

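// an Aggregate repeats its groupings inside its aggregates (the code below relies on them being the trailing entries) -
// to avoid reporting the same failure twice, check the groupings once and then only the remaining aggregates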
if (p instanceof Aggregate agg) {
// do groupings first
var groupings = agg.groupings();
groupings.forEach(unresolvedExpressions);
// followed by just the aggregates (to avoid going through the groups again)
var aggs = agg.aggregates();
int size = aggs.size() - groupings.size();
aggs.subList(0, size).forEach(unresolvedExpressions);
} else {
p.forEachExpression(unresolvedExpressions);
}
});

// in case of failures bail-out as all other checks will be redundant
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ protected static Batch<LogicalPlan> substitutions() {
return new Batch<>(
"Substitutions",
Limiter.ONCE,
new RemoveAggregateOverrides(),
// first extract nested aggs top-level - this simplifies the rest of the rules
new ReplaceStatsAggExpressionWithEval(),
// second extract nested aggs inside of them
new RemoveStatsOverride(),
// first extract nested expressions inside aggs
new ReplaceStatsNestedExpressionWithEval(),
// then extract nested aggs top-level
new ReplaceStatsAggExpressionWithEval(),
// lastly replace surrogate functions
new SubstituteSurrogates(),
new ReplaceRegexMatch(),
Expand Down Expand Up @@ -1259,9 +1259,9 @@ protected LogicalPlan rule(Aggregate aggregate) {
Attribute attr = expToAttribute.computeIfAbsent(field.canonical(), k -> {
Alias newAlias = new Alias(k.source(), syntheticName(k, af, counter[0]++), null, k, null, true);
evals.add(newAlias);
aggsChanged.set(true);
return newAlias.toAttribute();
});
aggsChanged.set(true);
// replace field with attribute
List<Expression> newChildren = new ArrayList<>(af.children());
newChildren.set(0, attr);
Expand Down Expand Up @@ -1506,8 +1506,12 @@ private LogicalPlan rule(Eval eval) {
* STATS BY x = c + 10
* That is, the last declaration for a given alias overrides all the other declarations, with
* groups having priority over aggregates.
* Separately, it replaces expressions used as group keys inside the aggregates with references:
* STATS max(a + b + 1) BY a + b
* becomes
* STATS max($x + 1) BY $x = a + b
*/
private static class RemoveAggregateOverrides extends AnalyzerRules.AnalyzerRule<Aggregate> {
private static class RemoveStatsOverride extends AnalyzerRules.AnalyzerRule<Aggregate> {

@Override
protected boolean skipResolved() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1759,7 +1759,7 @@ public void testFoldableInGrouping() {
|stats x by 1
"""));

assertThat(e.getMessage(), containsString("[x] is not an aggregate function"));
assertThat(e.getMessage(), containsString("Unknown column [x]"));
}

public void testScalarFunctionsInStats() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,11 @@ public void testNestedAggField() {
assertEquals("1:27: Unknown column [avg]", error("from test | stats c = avg(avg)"));
}

public void testUnfinishedAggFunction() {
assertEquals("1:23: invalid stats declaration; [avg] is not an aggregate function", error("from test | stats c = avg"));
public void testNotFoundFieldInNestedFunction() {
assertEquals("""
1:30: Unknown column [missing]
line 1:43: Unknown column [not_found]
line 1:23: Unknown column [avg]""", error("from test | stats c = avg by missing + 1, not_found"));
}

public void testSpatialSort() {
Expand Down
Loading

0 comments on commit 91b9134

Please sign in to comment.