Skip to content

Commit

Permalink
SQL: Fix behaviour of COUNT(DISTINCT <literal>) (#56869) (#56931)
Browse files Browse the repository at this point in the history
Previously, `COUNT(DISTINCT <literal>)` returned the same result
as `COUNT(<literal>)`. This is incorrect: it should always return 1
if there is at least one matching row (or, when there is a GROUP BY, at least one row in the bucket),
and 0 otherwise.

(cherry picked from commit 7f7d756)
  • Loading branch information
matriv committed May 19, 2020
1 parent 91d26c1 commit d1d6605
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 32 deletions.
9 changes: 8 additions & 1 deletion x-pack/plugin/sql/qa/src/main/resources/agg.sql-spec
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,21 @@ SELECT gender g, languages l, COUNT(*) c FROM "test_emp" GROUP BY g, l ORDER BY
aggCountDistinctWithAliasAndGroupBy
SELECT COUNT(*) cnt, COUNT(DISTINCT first_name) as names, gender FROM test_emp GROUP BY gender ORDER BY gender;

localCount
SELECT COUNT(1), COUNT(22), COUNT('foo'), COUNT(DISTINCT 1), COUNT(DISTINCT 22), COUNT(DISTINCT 'foo');
localSum
SELECT CAST(SUM(1) AS BIGINT);
localSumWithAlias
SELECT CAST(SUM(1) AS BIGINT) AS s, CAST(SUM(1) AS BIGINT);
localMax
SELECT MAX(1);
localAggregates
SELECT CAST(SUM(1) AS BIGINT), CAST(SUM(123) AS BIGINT), MAX(1), MAX(32), MIN(3), MIN(55+2) AS mn, CAST(AVG(33/3) AS INTEGER) AS av, CAST(AVG(1) AS INTEGER);
SELECT CAST(SUM(1) AS BIGINT), CAST(SUM(123) AS BIGINT), MAX(1), MAX(32), MIN(3), MIN(55+2) AS mn, CAST(AVG(33/3) AS INTEGER) AS av, CAST(AVG(1) AS INTEGER);

countOfLiteralsFromIndex
SELECT COUNT(1), COUNT(22), COUNT('foo'), COUNT(DISTINCT 1), COUNT(DISTINCT 22), COUNT(DISTINCT 'foo') FROM test_emp;
countOfLiteralsFromIndexWithGroupBy
SELECT COUNT(1), COUNT(22), COUNT('foo'), COUNT(DISTINCT 1), COUNT(DISTINCT 22), COUNT(DISTINCT 'foo') FROM test_emp GROUP BY gender ORDER BY gender;
aggregatesOfLiteralsFromIndex
SELECT MAX(1), MIN(1), CAST(SUM(1) AS BIGINT), CAST(AVG(1) AS INTEGER), COUNT(1) FROM test_emp;
aggregatesOfLiteralsFromIndex_WithNoMatchingFilter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixStatsEnclosed;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Max;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.sql.expression.function.aggregate.NumericAggregate;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRank;
import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRanks;
Expand Down Expand Up @@ -788,27 +787,34 @@ private Expression simplify(BinaryComparison bc) {
/**
* Numeric aggregates (avg, min, max, sum) acting on literals are converted to iif(count(1)=0, null, literal*count(1)) for sum,
* and to iif(count(1)=0, null, literal) for the other three.
* Additionally, count(DISTINCT literal) is converted to iif(count(1)=0, 0, 1).
*/
private static class ReplaceAggregatesWithLiterals extends OptimizerRule<LogicalPlan> {

@Override
protected LogicalPlan rule(LogicalPlan p) {
return p.transformExpressionsDown(e -> {
if (e instanceof Min || e instanceof Max || e instanceof Avg || e instanceof Sum) {
NumericAggregate a = (NumericAggregate) e;
if (e instanceof Min || e instanceof Max || e instanceof Avg || e instanceof Sum ||
(e instanceof Count && ((Count) e).distinct())) {

AggregateFunction a = (AggregateFunction) e;

if (a.field().foldable()) {
Expression countOne = new Count(a.source(), new Literal(Source.EMPTY, 1, a.dataType()), false);
Equals countEqZero = new Equals(a.source(), countOne, new Literal(Source.EMPTY, 0, a.dataType()));
Expression argument = a.field();
Literal foldedArgument = new Literal(argument.source(), argument.fold(), a.dataType());

Expression iifResult = Literal.NULL;
Expression iifElseResult = foldedArgument;
if (e instanceof Sum) {
iifElseResult = new Mul(a.source(), countOne, foldedArgument);
} else if (e instanceof Count) {
iifResult = new Literal(Source.EMPTY, 0, e.dataType());
iifElseResult = new Literal(Source.EMPTY, 1, e.dataType());
}

return new Iif(a.source(), countEqZero, Literal.NULL, iifElseResult);
return new Iif(a.source(), countEqZero, iifResult, iifElseResult);
}
}
return e;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
import java.time.Duration;
import java.time.ZonedDateTime;
import java.time.ZoneId;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.StringJoiner;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
import static org.elasticsearch.test.ESTestCase.randomBoolean;
Expand Down Expand Up @@ -97,5 +102,41 @@ public static Literal literal(Source source, Object value) {
}
return new Literal(source, value, SqlDataTypes.fromJava(value));
}

/**
 * Builds a random SQL suffix made of an optional {@code ORDER BY} clause followed by an
 * optional {@code LIMIT} clause, for appending to a {@code SELECT} statement in tests.
 * <p>
 * When the {@code ORDER BY} clause is generated it references every select argument by its
 * 1-based ordinal, in an order shuffled with {@code rnd}. Each ordinal may be followed by
 * {@code DESC} or {@code ASC} and by {@code NULLS FIRST} or {@code NULLS LAST}; when the
 * drawn value matches neither option the keyword is simply left out.
 *
 * @param noOfSelectArgs number of items in the select list; ordinals 1..noOfSelectArgs are used
 * @param rnd            source of randomness used only for shuffling the ordinals
 * @return the generated suffix; empty when neither clause was selected
 */
public static String randomOrderByAndLimit(int noOfSelectArgs, Random rnd) {
    StringBuilder suffix = new StringBuilder();

    boolean withOrderBy = randomBoolean();
    if (withOrderBy) {
        suffix.append(" ORDER BY ");

        // 1-based positions of the select arguments, visited in random order
        List<Integer> ordinals = IntStream.range(1, noOfSelectArgs + 1).boxed().collect(Collectors.toList());
        Collections.shuffle(ordinals, rnd);

        for (int pos = 0; pos < noOfSelectArgs; pos++) {
            if (pos > 0) {
                suffix.append(", ");
            }
            suffix.append(ordinals.get(pos));

            // direction: 0 -> DESC, 1 -> ASC, anything else -> no explicit direction
            int direction = randomInt(2);
            if (direction == 0) {
                suffix.append(" DESC");
            } else if (direction == 1) {
                suffix.append(" ASC");
            }

            // null ordering: 0 -> NULLS FIRST, 1 -> NULLS LAST, anything else -> unspecified
            int nullOrdering = randomInt(2);
            if (nullOrdering == 0) {
                suffix.append(" NULLS FIRST");
            } else if (nullOrdering == 1) {
                suffix.append(" NULLS LAST");
            }
        }
    }

    boolean withLimit = randomBoolean();
    if (withLimit) {
        int limit = randomIntBetween(1, 100);
        suffix.append(" LIMIT ").append(limit);
    }

    return suffix.toString();
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -175,62 +175,76 @@ public void testLocalExecWithoutFromClauseWithPrunedFilter() {
assertThat(ee.output().get(0).toString(), startsWith("E(){r}#"));
}

public void testLocalExecWithAggs() {
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(1), AVG(0)");
public void testLocalExecWithCount() {
PhysicalPlan p = plan("SELECT COUNT(10), COUNT(DISTINCT 20)" + randomOrderByAndLimit(2));
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(SingletonExecutable.class, le.executable().getClass());
SingletonExecutable ee = (SingletonExecutable) le.executable();
assertEquals(4, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("MIN(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("MAX(123){r}#"));
assertThat(ee.output().get(2).toString(), startsWith("SUM(1){r}#"));
assertThat(ee.output().get(3).toString(), startsWith("AVG(0){r}#"));
assertEquals(2, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("COUNT(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("COUNT(DISTINCT 20){r}#"));
}

public void testLocalExecWithAggsAndWhereFalseFilter() {
PhysicalPlan p = plan("SELECT SUM(10) WHERE 2 > 3");
public void testLocalExecWithCountAndWhereFalseFilter() {
PhysicalPlan p = plan("SELECT COUNT(10), COUNT(DISTINCT 20) WHERE 1 = 2" + randomOrderByAndLimit(2));
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(EmptyExecutable.class, le.executable().getClass());
EmptyExecutable ee = (EmptyExecutable) le.executable();
assertEquals(1, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("SUM(10){r}#"));
assertEquals(2, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("COUNT(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("COUNT(DISTINCT 20){r}#"));
}

public void testLocalExecWithAggsAndWhereTrueFilter() {
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(1), AVG(0) WHERE 1 = 1");
/**
 * COUNT and COUNT(DISTINCT) over literals with an always-true WHERE clause (and a random
 * ORDER BY / LIMIT suffix) must be planned as a LocalExec backed by a SingletonExecutable
 * exposing both aggregates as output.
 */
public void testLocalExecWithCountAndWhereTrueFilter() {
    String sql = "SELECT COUNT(10), COUNT(DISTINCT 20) WHERE 1 = 1" + randomOrderByAndLimit(2);
    PhysicalPlan physicalPlan = plan(sql);

    // the whole query must be answered locally
    assertEquals(LocalExec.class, physicalPlan.getClass());
    LocalExec localExec = (LocalExec) physicalPlan;

    // the filter is always true, so a single-row executable is expected
    assertEquals(SingletonExecutable.class, localExec.executable().getClass());
    SingletonExecutable singleton = (SingletonExecutable) localExec.executable();

    assertEquals(2, singleton.output().size());
    assertThat(singleton.output().get(0).toString(), startsWith("COUNT(10){r}#"));
    assertThat(singleton.output().get(1).toString(), startsWith("COUNT(DISTINCT 20){r}#"));
}

public void testLocalExecWithAggs() {
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(20), AVG(30)" + randomOrderByAndLimit(4));
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(SingletonExecutable.class, le.executable().getClass());
SingletonExecutable ee = (SingletonExecutable) le.executable();
assertEquals(4, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("MIN(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("MAX(123){r}#"));
assertThat(ee.output().get(2).toString(), startsWith("SUM(1){r}#"));
assertThat(ee.output().get(3).toString(), startsWith("AVG(0){r}#"));
assertThat(ee.output().get(2).toString(), startsWith("SUM(20){r}#"));
assertThat(ee.output().get(3).toString(), startsWith("AVG(30){r}#"));
}

public void testLocalExecWithAggsAndWhereTrueFilterAndOrderBy() {
PhysicalPlan p = plan("SELECT MAX(23), SUM(1) WHERE 1 = 1 ORDER BY 1, 2 DESC");
public void testLocalExecWithAggsAndWhereFalseFilter() {
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(20), AVG(30) WHERE 2 > 3" + randomOrderByAndLimit(4));
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(SingletonExecutable.class, le.executable().getClass());
SingletonExecutable ee = (SingletonExecutable) le.executable();
assertEquals(2, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("MAX(23){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("SUM(1){r}#"));
assertEquals(EmptyExecutable.class, le.executable().getClass());
EmptyExecutable ee = (EmptyExecutable) le.executable();
assertEquals(4, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("MIN(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("MAX(123){r}#"));
assertThat(ee.output().get(2).toString(), startsWith("SUM(20){r}#"));
assertThat(ee.output().get(3).toString(), startsWith("AVG(30){r}#"));
}

public void testLocalExecWithAggsAndWhereTrueFilterAndOrderByAndLimit() {
PhysicalPlan p = plan("SELECT AVG(10), SUM(2) WHERE 1 = 1 ORDER BY 1, 2 DESC LIMIT 5");
public void testLocalExecWithAggsAndWhereTrueFilter() {
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(20), AVG(30) WHERE 1 = 1" + randomOrderByAndLimit(4));
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(SingletonExecutable.class, le.executable().getClass());
SingletonExecutable ee = (SingletonExecutable) le.executable();
assertEquals(2, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("AVG(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("SUM(2){r}#"));
assertEquals(4, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("MIN(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("MAX(123){r}#"));
assertThat(ee.output().get(2).toString(), startsWith("SUM(20){r}#"));
assertThat(ee.output().get(3).toString(), startsWith("AVG(30){r}#"));
}

public void testFoldingOfIsNull() {
Expand Down Expand Up @@ -489,4 +503,8 @@ public void testFoldingOfPivot() {
assertThat(a, containsString("\"terms\":{\"field\":\"keyword\""));
assertThat(a, containsString("{\"avg\":{\"field\":\"int\"}"));
}

/**
 * Convenience overload for the tests in this class: delegates to
 * {@code SqlTestUtils.randomOrderByAndLimit}, passing the value returned by {@code random()}
 * as the randomness source.
 *
 * @param noOfSelectArgs number of items in the select list of the query being built
 * @return the suffix produced by {@code SqlTestUtils.randomOrderByAndLimit}
 */
private static String randomOrderByAndLimit(int noOfSelectArgs) {
return SqlTestUtils.randomOrderByAndLimit(noOfSelectArgs, random());
}
}

0 comments on commit d1d6605

Please sign in to comment.