Skip to content

Commit

Permalink
sql: support corr()
Browse files Browse the repository at this point in the history
Support aggregate function `corr()`. The behavior is modeled on the PostgreSQL implementation, but the implementation details follow SQL:2003. See: https://www.postgresql.org/docs/10/functions-aggregate.html

See #41274

Release note (sql change): Support aggregate function `corr()`
  • Loading branch information
hueypark committed Feb 21, 2020
1 parent a33c58c commit 865e011
Show file tree
Hide file tree
Showing 15 changed files with 607 additions and 226 deletions.
8 changes: 8 additions & 0 deletions docs/generated/sql/aggregates.md
Expand Up @@ -53,6 +53,14 @@
</span></td></tr>
<tr><td><a name="concat_agg"></a><code>concat_agg(arg1: <a href="string.html">string</a>) &rarr; <a href="string.html">string</a></code></td><td><span class="funcdesc"><p>Concatenates all selected values.</p>
</span></td></tr>
<tr><td><a name="corr"></a><code>corr(arg1: <a href="float.html">float</a>, arg2: <a href="float.html">float</a>) &rarr; <a href="float.html">float</a></code></td><td><span class="funcdesc"><p>Calculates the correlation coefficient of the selected values.</p>
</span></td></tr>
<tr><td><a name="corr"></a><code>corr(arg1: <a href="float.html">float</a>, arg2: <a href="int.html">int</a>) &rarr; <a href="float.html">float</a></code></td><td><span class="funcdesc"><p>Calculates the correlation coefficient of the selected values.</p>
</span></td></tr>
<tr><td><a name="corr"></a><code>corr(arg1: <a href="int.html">int</a>, arg2: <a href="float.html">float</a>) &rarr; <a href="float.html">float</a></code></td><td><span class="funcdesc"><p>Calculates the correlation coefficient of the selected values.</p>
</span></td></tr>
<tr><td><a name="corr"></a><code>corr(arg1: <a href="int.html">int</a>, arg2: <a href="int.html">int</a>) &rarr; <a href="float.html">float</a></code></td><td><span class="funcdesc"><p>Calculates the correlation coefficient of the selected values.</p>
</span></td></tr>
<tr><td><a name="count"></a><code>count(arg1: anyelement) &rarr; <a href="int.html">int</a></code></td><td><span class="funcdesc"><p>Calculates the number of selected elements.</p>
</span></td></tr>
<tr><td><a name="count_rows"></a><code>count_rows() &rarr; <a href="int.html">int</a></code></td><td><span class="funcdesc"><p>Calculates the number of rows.</p>
Expand Down
361 changes: 182 additions & 179 deletions pkg/sql/execinfrapb/processors_sql.pb.go

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pkg/sql/execinfrapb/processors_sql.proto
Expand Up @@ -468,6 +468,7 @@ message AggregatorSpec {
STRING_AGG = 21;
BIT_AND = 22;
BIT_OR = 23;
CORR = 24;
}

enum Type {
Expand Down
88 changes: 88 additions & 0 deletions pkg/sql/logictest/testdata/logic_test/aggregate
Expand Up @@ -1223,6 +1223,94 @@ SELECT 123 FROM kv ORDER BY max(v)
----
123

subtest statistics

statement OK
CREATE TABLE statistics_agg_test (
y float,
x float,
int_y int,
int_x int
)

statement OK
INSERT INTO statistics_agg_test (y, x, int_y, int_x) VALUES
(1.0, 10.0, 1, 10),
(2.0, 25.0, 2, 25),
(2.0, 25.0, 2, 25),
(3.0, 40.0, 3, 40),
(3.0, 40.0, 3, 40),
(3.0, 40.0, 3, 40),
(4.0, 100.0, 4, 100),
(4.0, 100.0, 4, 100),
(4.0, 100.0, 4, 100),
(4.0, 100.0, 4, 100),
(NULL, NULL, NULL, NULL)

query RRRR
SELECT corr(y, x)::decimal, corr(int_y, int_x)::decimal, corr(y, int_x)::decimal, corr(int_y, x)::decimal FROM statistics_agg_test
----
0.933007822647968 0.933007822647968 0.933007822647968 0.933007822647968

query R
SELECT corr(DISTINCT y, x)::decimal FROM statistics_agg_test
----
0.9326733179802503

query R
SELECT CAST(corr(DISTINCT y, x) FILTER (WHERE x > 3 AND y < 30) AS decimal) FROM statistics_agg_test
----
0.9326733179802503

query error pq: unknown signature: corr\(string, string\)
SELECT corr(y::string, x::string) FROM statistics_agg_test

statement OK
INSERT INTO statistics_agg_test (y, x, int_y, int_x) VALUES
(1.797693134862315708145274237317043567981e+308, 0, 0, 0)

query error float out of range
SELECT corr(y, x)::decimal, corr(int_y, int_x)::decimal FROM statistics_agg_test

statement OK
TRUNCATE statistics_agg_test

statement OK
INSERT INTO statistics_agg_test (y, x, int_y, int_x) VALUES
(1.0, 10.0, 1, 10),
(2.0, 20.0, 2, 20)

query RR
SELECT corr(y, x)::decimal, corr(int_y, int_x)::decimal FROM statistics_agg_test
----
1 1

statement OK
TRUNCATE statistics_agg_test

statement OK
INSERT INTO statistics_agg_test (y, x, int_y, int_x) VALUES
(1.0, 10.0, 1, 10),
(2.0, -20.0, 2, -20)

query RR
SELECT corr(y, x)::decimal, corr(int_y, int_x)::decimal FROM statistics_agg_test
----
-1 -1

statement OK
TRUNCATE statistics_agg_test

statement OK
INSERT INTO statistics_agg_test (y, x, int_y, int_x) VALUES
(1.0, -1.0, 1, -1),
(1.0, 1.0, 1, 1)

query RR
SELECT corr(y, x)::decimal, corr(int_y, int_x)::decimal FROM statistics_agg_test
----
NULL NULL

subtest string_agg

statement OK
Expand Down
12 changes: 12 additions & 0 deletions pkg/sql/logictest/testdata/logic_test/distsql_agg
Expand Up @@ -565,6 +565,18 @@ SELECT sum(a) FROM data WHERE FALSE
----
NULL

# Test statistics aggregate functions.
statement ok
CREATE TABLE statistics_agg_test (y INT, x INT)

statement ok
INSERT INTO statistics_agg_test SELECT y, y%10 FROM generate_series(1, 100) AS y

query R
SELECT corr(y, x)::decimal FROM statistics_agg_test
----
0.045228963191363145

# Regression test for #37211 (incorrect ordering between aggregator stages).
statement ok
CREATE TABLE uv (u INT PRIMARY KEY, v INT);
Expand Down
27 changes: 13 additions & 14 deletions pkg/sql/opt/exec/execbuilder/testdata/distsql_agg
Expand Up @@ -351,20 +351,19 @@ group-by
└── count-rows

query TTTTT
EXPLAIN (verbose) SELECT b, count(*) FROM data2 WHERE a=1 GROUP BY b
----
· distributed true · ·
· vectorized false · ·
group · · (b, count) ·
│ aggregate 0 b · ·
│ aggregate 1 count_rows() · ·
│ group by b · ·
│ ordered +b · ·
└── render · · (b) +b
│ render 0 b · ·
└── scan · · (a, b) +b
· table data2@primary · ·
· spans /1-/2 · ·
EXPLAIN (verbose) SELECT b, count(*), corr(a, b) FROM data2 WHERE a=1 GROUP BY b
----
· distributed true · ·
· vectorized false · ·
group · · (b, count, corr) ·
│ aggregate 0 b · ·
│ aggregate 1 count_rows() · ·
│ aggregate 2 corr(a, b) · ·
│ group by b · ·
│ ordered +b · ·
└── scan · · (a, b) +b
· table data2@primary · ·
· spans /1-/2 · ·

query T
SELECT url FROM [EXPLAIN (DISTSQL) SELECT b, count(*) FROM data2 WHERE a=1 GROUP BY b];
Expand Down
12 changes: 12 additions & 0 deletions pkg/sql/opt/exec/execbuilder/testdata/explain
Expand Up @@ -997,3 +997,15 @@ group · ·
│ scalar ·
└── values · ·
· size 2 columns, 2 rows

query TTT
EXPLAIN SELECT corr(a, b) FROM tc;
----
· distributed false
· vectorized true
group · ·
│ aggregate 0 corr(a, b)
│ scalar ·
└── scan · ·
· table tc@primary
· spans ALL
53 changes: 35 additions & 18 deletions pkg/sql/opt/norm/testdata/rules/agg
Expand Up @@ -7,17 +7,21 @@ CREATE TABLE a (k INT PRIMARY KEY, i INT, f FLOAT, s STRING, j JSON, arr int[])
# --------------------------------------------------

norm expect=EliminateAggDistinct
SELECT min(DISTINCT i), max(DISTINCT i), bool_and(DISTINCT i>f), bool_or(DISTINCT i>f) FROM a
SELECT min(DISTINCT i), max(DISTINCT i), bool_and(DISTINCT i>f), bool_or(DISTINCT i>f), corr(DISTINCT k, i) FROM a
----
scalar-group-by
├── columns: min:7(int) max:8(int) bool_and:10(bool) bool_or:11(bool)
├── columns: min:7(int) max:8(int) bool_and:10(bool) bool_or:11(bool) corr:12(float)
├── cardinality: [1 - 1]
├── key: ()
├── fd: ()-->(7,8,10,11)
├── fd: ()-->(7,8,10-12)
├── project
│ ├── columns: column9:9(bool) i:2(int)
│ ├── columns: column9:9(bool) k:1(int!null) i:2(int)
│ ├── key: (1)
│ ├── fd: (1)-->(2,9)
│ ├── scan a
│ │ └── columns: i:2(int) f:3(float)
│ │ ├── columns: k:1(int!null) i:2(int) f:3(float)
│ │ ├── key: (1)
│ │ └── fd: (1)-->(2,3)
│ └── projections
│ └── i > f [type=bool, outer=(2,3)]
└── aggregations
Expand All @@ -27,33 +31,41 @@ scalar-group-by
│ └── variable: i [type=int]
├── bool-and [type=bool, outer=(9)]
│ └── variable: column9 [type=bool]
└── bool-or [type=bool, outer=(9)]
└── variable: column9 [type=bool]
├── bool-or [type=bool, outer=(9)]
│ └── variable: column9 [type=bool]
└── corr [type=float, outer=(1,2)]
├── variable: k [type=int]
└── variable: i [type=int]

# The rule should still work when FILTER is present.
norm expect=EliminateAggDistinct
SELECT
min(DISTINCT i) FILTER (WHERE i > 5),
max(DISTINCT i) FILTER (WHERE i > 5),
bool_and(DISTINCT i>f) FILTER (WHERE f > 0.0),
bool_or(DISTINCT i>f) FILTER (WHERE f > 1.0)
bool_or(DISTINCT i>f) FILTER (WHERE f > 1.0),
corr(DISTINCT k, i) FILTER(WHERE k > 5 AND i > 5)
FROM a
----
scalar-group-by
├── columns: min:8(int) max:9(int) bool_and:12(bool) bool_or:14(bool)
├── columns: min:8(int) max:9(int) bool_and:12(bool) bool_or:14(bool) corr:16(float)
├── cardinality: [1 - 1]
├── key: ()
├── fd: ()-->(8,9,12,14)
├── fd: ()-->(8,9,12,14,16)
├── project
│ ├── columns: column7:7(bool) column10:10(bool) column11:11(bool) column13:13(bool) i:2(int)
│ ├── fd: (2)-->(7)
│ ├── columns: column7:7(bool) column10:10(bool) column11:11(bool) column13:13(bool) column15:15(bool) k:1(int!null) i:2(int)
│ ├── key: (1)
│ ├── fd: (1)-->(2,7,10,11,13), (2)-->(7), (1,2)-->(15)
│ ├── scan a
│ │ └── columns: i:2(int) f:3(float)
│ │ ├── columns: k:1(int!null) i:2(int) f:3(float)
│ │ ├── key: (1)
│ │ └── fd: (1)-->(2,3)
│ └── projections
│ ├── i > 5 [type=bool, outer=(2)]
│ ├── i > f [type=bool, outer=(2,3)]
│ ├── f > 0.0 [type=bool, outer=(3)]
│ └── f > 1.0 [type=bool, outer=(3)]
│ ├── f > 1.0 [type=bool, outer=(3)]
│ └── (k > 5) AND (i > 5) [type=bool, outer=(1,2)]
└── aggregations
├── agg-filter [type=int, outer=(2,7)]
│ ├── min [type=int]
Expand All @@ -67,10 +79,15 @@ scalar-group-by
│ ├── bool-and [type=bool]
│ │ └── variable: column10 [type=bool]
│ └── variable: column11 [type=bool]
└── agg-filter [type=bool, outer=(10,13)]
├── bool-or [type=bool]
│ └── variable: column10 [type=bool]
└── variable: column13 [type=bool]
├── agg-filter [type=bool, outer=(10,13)]
│ ├── bool-or [type=bool]
│ │ └── variable: column10 [type=bool]
│ └── variable: column13 [type=bool]
└── agg-filter [type=float, outer=(1,2,15)]
├── corr [type=float]
│ ├── variable: k [type=int]
│ └── variable: i [type=int]
└── variable: column15 [type=bool]

# The rule should not apply to these aggregations.
norm expect-not=EliminateAggDistinct
Expand Down
33 changes: 23 additions & 10 deletions pkg/sql/opt/operator.go
Expand Up @@ -179,6 +179,7 @@ var AggregateOpReverseMap = map[Operator]string{
BoolOrOp: "bool_or",
ConcatAggOp: "concat_agg",
CountOp: "count",
CorrOp: "corr",
CountRowsOp: "count_rows",
MaxOp: "max",
MinOp: "min",
Expand Down Expand Up @@ -251,27 +252,39 @@ func BoolOperatorRequiresNotNullArgs(op Operator) bool {
return false
}

// AggregateIgnoresNulls returns true if the given aggregate operator has a
// single input, and if it always evaluates to the same result regardless of
// how many NULL values are included in that input, in any order.
// AggregateIgnoresNulls returns true if the given aggregate operator ignores
// rows where its first argument evaluates to NULL. In other words, it always
// evaluates to the same result even if those rows are filtered. For example:
//
// SELECT string_agg(x, y)
// FROM (VALUES ('foo', ','), ('bar', ','), (NULL, ',')) t(x, y)
//
// In this example, the NULL row can be removed from the input, and the
// string_agg function still returns the same result. Contrast this to the
// array_agg function:
//
// SELECT array_agg(x)
// FROM (VALUES ('foo'), (NULL), ('bar')) t(x)
//
// If the NULL row is removed here, array_agg returns {foo,bar} instead of
// {foo,NULL,bar}.
func AggregateIgnoresNulls(op Operator) bool {
switch op {
case AvgOp, BitAndAggOp, BitOrAggOp, BoolAndOp, BoolOrOp, CountOp, MaxOp, MinOp,
case AvgOp, BitAndAggOp, BitOrAggOp, BoolAndOp, BoolOrOp, CorrOp, CountOp, MaxOp, MinOp,
SumIntOp, SumOp, SqrDiffOp, VarianceOp, StdDevOp, XorAggOp, ConstNotNullAggOp,
AnyNotNullAggOp, StringAggOp:
return true
}
return false
}

// AggregateIsNullOnEmpty returns true if the given aggregate operator has a
// single input, and if it returns NULL when the input set contains no values.
// This group of aggregates overlaps considerably with the AggregateIgnoresNulls
// group, with the notable exception of COUNT, which returns zero instead of
// NULL when its input is empty.
// AggregateIsNullOnEmpty returns true if the given aggregate operator returns
// NULL when the input set contains no values. This group of aggregates overlaps
// considerably with the AggregateIgnoresNulls group, with the notable exception
// of COUNT, which returns zero instead of NULL when its input is empty.
func AggregateIsNullOnEmpty(op Operator) bool {
switch op {
case AvgOp, BitAndAggOp, BitOrAggOp, BoolAndOp, BoolOrOp, MaxOp, MinOp, SumIntOp,
case AvgOp, BitAndAggOp, BitOrAggOp, BoolAndOp, BoolOrOp, CorrOp, MaxOp, MinOp, SumIntOp,
SumOp, SqrDiffOp, VarianceOp, StdDevOp, XorAggOp, ConstAggOp, ConstNotNullAggOp, ArrayAggOp,
ConcatAggOp, JsonAggOp, JsonbAggOp, AnyNotNullAggOp, StringAggOp:
return true
Expand Down
6 changes: 6 additions & 0 deletions pkg/sql/opt/ops/scalar.opt
Expand Up @@ -761,6 +761,12 @@ define ConcatAgg {
Input ScalarExpr
}

# Corr is the aggregate operator for the SQL corr() function, which calculates
# the correlation coefficient of the selected values. It takes two scalar
# inputs, Y and X, in that order.
[Scalar, Aggregate]
define Corr {
Y ScalarExpr
X ScalarExpr
}

[Scalar, Aggregate]
define Count {
Input ScalarExpr
Expand Down
2 changes: 2 additions & 0 deletions pkg/sql/opt/optbuilder/groupby.go
Expand Up @@ -756,6 +756,8 @@ func (b *Builder) constructAggregate(name string, args []opt.ScalarExpr) opt.Sca
return b.factory.ConstructBoolOr(args[0])
case "concat_agg":
return b.factory.ConstructConcatAgg(args[0])
case "corr":
return b.factory.ConstructCorr(args[0], args[1])
case "count":
return b.factory.ConstructCount(args[0])
case "count_rows":
Expand Down

0 comments on commit 865e011

Please sign in to comment.