Skip to content

Commit

Permalink
Improve verifier to not allow scalar functions over grouping
Browse files Browse the repository at this point in the history
  • Loading branch information
costin committed Feb 6, 2024
1 parent 4acfa4a commit feeedfd
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1147,41 +1147,15 @@ x:s | y:d | z:i
1.010097 | -16701.0 | 1
;

nestedAggsAliasedOverGrouping#[skip:-8.12.99,reason:supported in 8.13]
nestedAggsOverGroupingWithAlias#[skip:-8.12.99,reason:supported in 8.13]
FROM employees
| STATS e = max(languages) + languages by l = languages
;

e:l | l:i
0 | 0
;

nestedAggsAliasedOverGrouping#[skip:-8.12.99,reason:supported in 8.13]
FROM employees
| STATS max(languages) + languages by l = languages
;

max(languages) + languages:l | l:i
0 | 0
;

nestedAggsOverGroupingWithMulti#[skip:-8.12.99,reason:supported in 8.13]
FROM employees
| STATS max(languages + 1) , m = languages + min(salary + 1) by l = languages, s = salary
;

max(languages + 1): l | m:l | l:i | s:i
0 | 0 | 0 | 0
;

scalarFunctionOverGroupingColumn
FROM employees
| STATS length(first_name), count(1) by first_name
| SORT first_name
| STATS e = max(languages) + 1 by l = languages
| SORT l
| LIMIT 3
;

// TODO: update results
length(first_name):i| count(1):i | first_name:s
12 | 100 | 0
e:i | l:i
2 | 1
3 | 2
4 | 3
;
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,9 @@
import org.elasticsearch.xpack.ql.analyzer.AnalyzerRules;
import org.elasticsearch.xpack.ql.analyzer.AnalyzerRules.BaseAnalyzerRule;
import org.elasticsearch.xpack.ql.analyzer.AnalyzerRules.ParameterizedAnalyzerRule;
import org.elasticsearch.xpack.ql.capabilities.Resolvables;
import org.elasticsearch.xpack.ql.common.Failure;
import org.elasticsearch.xpack.ql.expression.Alias;
import org.elasticsearch.xpack.ql.expression.Attribute;
import org.elasticsearch.xpack.ql.expression.AttributeMap;
import org.elasticsearch.xpack.ql.expression.EmptyAttribute;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.Expressions;
Expand Down Expand Up @@ -298,7 +296,6 @@ protected LogicalPlan doRule(LogicalPlan plan) {
childrenOutput.addAll(output);
}


if (plan instanceof Drop d) {
return resolveDrop(d, childrenOutput);
}
Expand Down Expand Up @@ -326,45 +323,6 @@ protected LogicalPlan doRule(LogicalPlan plan) {
return plan.transformExpressionsUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
}

/**
 * Resolves the unresolved attributes of an aggregate, first against the children output and then,
 * for groupings that are still unresolved, against the aliases declared in the aggregates.
 * <p>
 * The second step only runs when the aggregate still has unresolved expressions while all of its
 * aggregates are resolved, e.g. {@code STATS x AS a ... GROUP BY a}.
 *
 * @param a              the aggregate to resolve
 * @param childrenOutput the attributes used for the first resolution pass
 * @return the aggregate after the first pass, rebuilt with new groupings if any grouping was replaced
 */
private LogicalPlan resolveAggregate(Aggregate a, List<Attribute> childrenOutput) {
    // first pass: try to resolve every unresolved attribute against the children output
    a = (Aggregate) a.transformExpressionsUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));

    // if the grouping is unresolved but the aggs are, use the latter to resolve the former.
    // e.g. STATS x AS a ... GROUP BY a
    if (a.expressionsResolved() == false && Resolvables.resolved(a.aggregates())) {
        // map the attribute of each alias to the expression the alias wraps
        AttributeMap<Expression> resolved = new AttributeMap<>();
        for (NamedExpression ne : a.aggregates()) {
            if (ne instanceof Alias as) {
                resolved.put(as.toAttribute(), as.child());
            }
        }
        List<Attribute> keyList = new ArrayList<>(resolved.keySet());

        List<Expression> newGroupings = new ArrayList<>(a.groupings().size());
        boolean changed = false;
        for (Expression grouping : a.groupings()) {
            if (grouping instanceof UnresolvedAttribute unresolvedGrouping) {
                Attribute maybeResolved = maybeResolveAttribute(unresolvedGrouping, keyList);
                // a null result leaves the grouping untouched
                if (maybeResolved != null) {
                    changed = true;
                    if (maybeResolved.resolved()) {
                        // the grouping names an alias: group by the aliased expression itself
                        grouping = resolved.get(maybeResolved);
                    } else {
                        // keep whatever non-null (still unresolved) attribute was returned
                        grouping = maybeResolved;
                    }
                }
            }
            newGroupings.add(grouping);
        }

        // only build a new node when at least one grouping was actually replaced
        a = changed ? new Aggregate(a.source(), a.child(), newGroupings, a.aggregates()) : a;
    }
    return a;
}

private LogicalPlan resolveMvExpand(MvExpand p, List<Attribute> childrenOutput) {
if (p.target() instanceof UnresolvedAttribute ua) {
Attribute resolved = maybeResolveAttribute(ua, childrenOutput);
Expand Down Expand Up @@ -716,6 +674,7 @@ private static LogicalPlan removeAggDuplicates(Aggregate agg) {
return agg;
}
}

private static class AddImplicitLimit extends ParameterizedRule<LogicalPlan, LogicalPlan, AnalyzerContext> {
@Override
public LogicalPlan apply(LogicalPlan logicalPlan, AnalyzerContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,26 +165,31 @@ private static void checkAggregate(LogicalPlan p, Set<Failure> failures, Attribu
// specified in the grouping clause
agg.aggregates().forEach(e -> {
var exp = Alias.unwrap(e);
if (exp.foldable()) {
failures.add(fail(exp, "expected an aggregate function but found [{}]", exp.sourceText()));
}
// traverse the tree to find invalid matches
checkInvalidNamedExpressionUsage(exp, nakedGroups, failures);
checkInvalidNamedExpressionUsage(exp, nakedGroups, failures, 0);
});
}
}

// traverse the expression and look either for an agg function or a grouping match
// stop either when no children are left, the leaves are literals or a reference attribute is given
private static void checkInvalidNamedExpressionUsage(Expression e, List<Expression> groups, Set<Failure> failures) {
private static void checkInvalidNamedExpressionUsage(Expression e, List<Expression> groups, Set<Failure> failures, int level) {
// found an aggregate, constant or a group, bail out
if (e instanceof AggregateFunction af) {
af.field().forEachDown(AggregateFunction.class, f -> {
failures.add(fail(f, "nested aggregations [{}] not allowed inside other aggregations [{}]", f, af));
});
return;
} else if (e.foldable()) {
// don't do anything
}
if (e.foldable() || groups.contains(e))

{
return;
// don't allow scalar functions over groupings for now (e.g. STATS substring(group) BY group) as we don't optimize for them yet
else if (groups.contains(e)) {
if (level != 0) {
failures.add(fail(e, "scalar functions over groupings [{}] not allowed yet", e.sourceText()));
}
}
// if a reference is found, mark it as an error
else if (e instanceof NamedExpression ne) {
Expand All @@ -193,7 +198,7 @@ else if (e instanceof NamedExpression ne) {
// otherwise keep on going
else {
for (Expression child : e.children()) {
checkInvalidNamedExpressionUsage(child, groups, failures);
checkInvalidNamedExpressionUsage(child, groups, failures, level + 1);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,7 @@ static String toString(Expression ex) {
}

static String extractString(Expression ex) {
return ex instanceof NamedExpression ne
? ne.name()
: limitToString(ex.sourceText()).replace(' ', '_');
return ex instanceof NamedExpression ne ? ne.name() : limitToString(ex.sourceText()).replace(' ', '_');
}

static String limitToString(String string) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1194,18 +1194,6 @@ public void testAggsWithoutAggAndFollowingCommand() throws Exception {
assertEquals(agg.groupings(), agg.aggregates());
}

/**
 * Runs {@code analyze} on a STATS whose grouping is an aliased arithmetic expression over a field.
 * There is no explicit assertion; presumably the test passes as long as {@code analyze} does not throw.
 */
public void testAggsWithPartialGrouping() {
    String query = "from test| stats max(languages) by l = languages + 1 + 2 + 3";
    analyze(query);
}

/**
 * Runs {@code analyze} on a STATS that mixes an aggregate over an expression with an expression
 * combining a grouping column and an aggregate. There is no explicit assertion; presumably the
 * test passes as long as {@code analyze} does not throw.
 */
public void testAggsWithExpressionOverAggs() {
    String query = "from test | stats max(languages + 1) , m = languages + min(salary + 1) by l = languages, s = salary";
    analyze(query);
}

/**
 * Runs {@code analyze} on a STATS that applies a scalar function to the grouping column.
 * There is no explicit assertion; presumably the test passes as long as {@code analyze} does not throw.
 */
public void testAggScalarOverGroupingColumn() {
    String query = "from test | stats length(first_name), count(1) by first_name";
    analyze(query);
}

public void testEmptyEsRelationOnLimitZeroWithCount() throws IOException {
var query = """
from test*
Expand Down Expand Up @@ -1576,7 +1564,7 @@ public void testLiteralInAggregateNoGrouping() {
|stats 1
"""));

assertThat(e.getMessage(), containsString("expected an aggregate function but got [1]"));
assertThat(e.getMessage(), containsString("expected an aggregate function but found [1]"));
}

public void testLiteralBehindEvalInAggregateNoGrouping() {
Expand All @@ -1586,7 +1574,7 @@ public void testLiteralBehindEvalInAggregateNoGrouping() {
|stats x
"""));

assertThat(e.getMessage(), containsString("expected an aggregate function but got [x] referencing [1]"));
assertThat(e.getMessage(), containsString("column [x] must appear in the STATS BY clause or be used in an aggregate function"));
}

public void testLiteralsInAggregateNoGrouping() {
Expand All @@ -1595,7 +1583,7 @@ public void testLiteralsInAggregateNoGrouping() {
|stats 1 + 2
"""));

assertThat(e.getMessage(), containsString("expected an aggregate function but got [1 + 2]"));
assertThat(e.getMessage(), containsString("expected an aggregate function but found [1 + 2]"));
}

public void testLiteralsBehindEvalInAggregateNoGrouping() {
Expand All @@ -1605,7 +1593,7 @@ public void testLiteralsBehindEvalInAggregateNoGrouping() {
|stats x
"""));

assertThat(e.getMessage(), containsString("expected an aggregate function but got [x]"));
assertThat(e.getMessage(), containsString("column [x] must appear in the STATS BY clause or be used in an aggregate function"));
}

public void testFoldableInAggregateWithGrouping() {
Expand All @@ -1614,7 +1602,7 @@ public void testFoldableInAggregateWithGrouping() {
|stats 1 + 2 by languages
"""));

assertThat(e.getMessage(), containsString("expected an aggregate function but got [1 + 2]"));
assertThat(e.getMessage(), containsString("expected an aggregate function but found [1 + 2]"));
}

public void testLiteralsInAggregateWithGrouping() {
Expand All @@ -1623,7 +1611,7 @@ public void testLiteralsInAggregateWithGrouping() {
|stats "a" by languages
"""));

assertThat(e.getMessage(), containsString("expected an aggregate function but got [\"a\"] of type [Literal]"));
assertThat(e.getMessage(), containsString("expected an aggregate function but found [\"a\"]"));
}

public void testFoldableBehindEvalInAggregateWithGrouping() {
Expand All @@ -1633,7 +1621,7 @@ public void testFoldableBehindEvalInAggregateWithGrouping() {
|stats x by languages
"""));

assertThat(e.getMessage(), containsString("expected an aggregate function but got [x] referencing [1 + 2]"));
assertThat(e.getMessage(), containsString("column [x] must appear in the STATS BY clause or be used in an aggregate function"));
}

public void testFoldableInGrouping() {
Expand All @@ -1651,14 +1639,10 @@ public void testScalarFunctionsInStats() {
|stats salary % 3 by languages
"""));

assertThat(e.getMessage(), containsString("expected an aggregate function but got [salary % 3] of type [Mod]"));
}

public void testGroupingInAggs() {
assertProjection("""
from test
|stats e = salary + max(salary) by languages
""", "e", "languages");
assertThat(
e.getMessage(),
containsString("column [salary] must appear in the STATS BY clause or be used in an aggregate function")
);
}

public void testDeferredGroupingInStats() {
Expand All @@ -1668,7 +1652,7 @@ public void testDeferredGroupingInStats() {
|stats x by first_name
"""));

assertThat(e.getMessage(), containsString("expected an aggregate function but got [x] referencing [first_name]"));
assertThat(e.getMessage(), containsString("column [x] must appear in the STATS BY clause or be used in an aggregate function"));
}

public void testUnsupportedTypesInStats() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ public void testAggsExpressionsInStatsAggs() {
"1:44: column [salary] must appear in the STATS BY clause or be used in an aggregate function",
error("from test | eval z = 2 | stats x = avg(z), salary by emp_no")
);
assertEquals(
"1:26: scalar functions over groupings [first_name] not allowed yet",
error("from test | stats length(first_name), count(1) by first_name")
);
assertEquals(
"1:36: scalar functions over groupings [languages] not allowed yet",
error("from test | stats max(languages) + languages by l = languages")
);
assertEquals(
"1:23: nested aggregations [max(salary)] not allowed inside other aggregations [max(max(salary))]",
error("from test | stats max(max(salary)) by first_name")
Expand All @@ -81,6 +89,7 @@ public void testAggsExpressionsInStatsAggs() {
"1:23: second argument of [count_distinct(languages, languages)] must be a constant, received [languages]",
error("from test | stats x = count_distinct(languages, languages) by emp_no")
);

}

public void testAggsInsideGrouping() {
Expand Down Expand Up @@ -118,6 +127,27 @@ public void testAggsInsideEval() throws Exception {
assertEquals("1:29: aggregate function [max(b)] not allowed outside STATS command", error("row a = 1, b = 2 | eval x = max(b)"));
}

/**
 * Checks that combining a grouping column with an aggregate inside a STATS expression is
 * reported as a "scalar functions over groupings" error at position 1:44.
 */
public void testAggsWithExpressionOverAggs() {
    String query = "from test | stats max(languages + 1) , m = languages + min(salary + 1) by l = languages, s = salary";
    String expected = "1:44: scalar functions over groupings [languages] not allowed yet";
    var actual = error(query);
    assertEquals(expected, actual);
}

/**
 * Checks that applying a scalar function to the grouping column inside STATS is reported as a
 * "scalar functions over groupings" error at position 1:26.
 */
public void testAggScalarOverGroupingColumn() {
    String query = "from test | stats length(first_name), count(1) by first_name";
    String expected = "1:26: scalar functions over groupings [first_name] not allowed yet";
    var actual = error(query);
    assertEquals(expected, actual);
}

/**
 * Checks that using a column that is neither grouped nor aggregated next to an aggregate is
 * reported as a "must appear in the STATS BY clause" error at position 2:12.
 */
public void testGroupingInAggs() {
    String query = """
        from test
        |stats e = salary + max(salary) by languages
        """;
    String expected = "2:12: column [salary] must appear in the STATS BY clause or be used in an aggregate function";
    var actual = error(query);
    assertEquals(expected, actual);
}

public void testDoubleRenamingField() {
assertEquals(
"1:44: Column [emp_no] renamed to [r1] and is no longer available [emp_no as r3]",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1800,11 +1800,11 @@ public void testSimpleAvgReplacement() {
var agg = as(limit.child(), Aggregate.class);
var aggs = agg.aggregates();
var a = as(aggs.get(0), Alias.class);
assertThat(a.name(), startsWith("__a_SUM@"));
assertThat(a.name(), startsWith("$$SUM$a$"));
var sum = as(a.child(), Sum.class);

a = as(aggs.get(1), Alias.class);
assertThat(a.name(), startsWith("__a_COUNT@"));
assertThat(a.name(), startsWith("$$COUNT$a$"));
var count = as(a.child(), Count.class);

assertThat(Expressions.names(agg.groupings()), contains("last_name"));
Expand Down Expand Up @@ -1861,7 +1861,7 @@ public void testSemiClashingAvgReplacement() {
var agg = as(limit.child(), Aggregate.class);
var aggs = agg.aggregates();
var a = as(aggs.get(0), Alias.class);
assertThat(a.name(), startsWith("__a_COUNT@"));
assertThat(a.name(), startsWith("$$COUNT$a$0"));
var sum = as(a.child(), Count.class);

a = as(aggs.get(1), Alias.class);
Expand Down

0 comments on commit feeedfd

Please sign in to comment.