From 856d899525cac1d4383f31b836cc29b0fbb7b73e Mon Sep 17 00:00:00 2001 From: Vinai Rachakonda Date: Tue, 15 Jun 2021 15:19:31 -0400 Subject: [PATCH 1/7] add small fix that allows most expressions to be run as aggregates on group by --- enginetest/script_queries.go | 20 ++++++++++++++++++++ sql/analyzer/validation_rules.go | 5 ++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/enginetest/script_queries.go b/enginetest/script_queries.go index c2e33114cb..5fc53e53e8 100644 --- a/enginetest/script_queries.go +++ b/enginetest/script_queries.go @@ -890,6 +890,26 @@ var ScriptTests = []ScriptTest{ Query: "SELECT SUM(DISTINCT POWER(v1, 2)) FROM mytable", Expected: []sql.Row{{float64(5)}}, }, + { + Query: "SELECT + + 97 FROM tab1 GROUP BY tab1.col1", + Expected: []sql.Row{{97}, {97}, {97}}, + }, + { + Query: "SELECT ALL - cor0.col0 * + cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0", + Expected: []sql.Row{{-2601}, {-7225}, {-8281}}, + }, + { + Query: "SELECT col0 FROM tab1 AS cor0 GROUP BY cor0.col0", + Expected: []sql.Row{{51}, {85}, {91}}, + }, + { + Query: "SELECT - cor0.col0 FROM tab1 AS cor0 GROUP BY cor0.col0", + Expected: []sql.Row{{-51}, {-85}, {-91}}, + }, + { + Query: "SELECT col0 BETWEEN 2 and 4 from tab1 group by col0", + Expected: []sql.Row{{false}, {false}, {false}}, + }, }, }, { diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index 1587227a61..26ebfc8b3e 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -252,12 +252,11 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (s func isValidAgg(validAggs []string, expr sql.Expression) bool { switch expr := expr.(type) { - case sql.Aggregation: - return true case *expression.Alias: return stringContains(validAggs, expr.String()) || isValidAgg(validAggs, expr.Child) + // Pretty much all expressions can be treated as an aggregations on GROUP BY default: - return stringContains(validAggs, expr.String()) + return true } } From 1b504c17812c519508d1e1e65444eea47c038903 Mon Sep 17 00:00:00 2001 From: Vinai Rachakonda Date: Tue, 15 Jun 2021 15:49:10 -0400 Subject: [PATCH 2/7] new test case --- enginetest/script_queries.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/enginetest/script_queries.go b/enginetest/script_queries.go index 5fc53e53e8..68f35d1b72 100644 --- a/enginetest/script_queries.go +++ b/enginetest/script_queries.go @@ -910,6 +910,10 @@ var ScriptTests = []ScriptTest{ Query: "SELECT col0 BETWEEN 2 and 4 from tab1 group by col0", Expected: []sql.Row{{false}, {false}, {false}}, }, + { + Query: "SELECT col1 * cor0.col1 * 56 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col1", + Expected: []sql.Row{{332024}, {251384}, {145656}}, + }, }, }, { From 747f3d5a79df5abca45fddfae204a81e3cfdc02f Mon Sep 17 00:00:00 2001 From: Vinai Rachakonda Date: Tue, 15 Jun 2021 15:49:29 -0400 Subject: [PATCH 3/7] fmt --- enginetest/script_queries.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enginetest/script_queries.go b/enginetest/script_queries.go index 68f35d1b72..c088c7a854 100644 --- a/enginetest/script_queries.go +++ b/enginetest/script_queries.go @@ -891,11 +891,11 @@ var ScriptTests = []ScriptTest{ Expected: []sql.Row{{float64(5)}}, }, { - Query: "SELECT + + 97 FROM tab1 GROUP BY tab1.col1", + Query: "SELECT + + 97 FROM tab1 GROUP BY tab1.col1", Expected: []sql.Row{{97}, {97}, {97}}, }, { - Query: "SELECT ALL - cor0.col0 * + cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0", + Query: "SELECT ALL - cor0.col0 * + cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0", Expected: []sql.Row{{-2601}, {-7225}, {-8281}}, }, { From bfa6f965aea712a59183e8b1f7cf0378b29259f9 Mon Sep 17 00:00:00 2001 From: Vinai Rachakonda Date: Wed, 16 Jun 2021 09:25:28 -0400 Subject: [PATCH 4/7] cleanup --- enginetest/script_queries.go | 5 +++-- sql/analyzer/validation_rules.go | 17 +++++++++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/enginetest/script_queries.go b/enginetest/script_queries.go index c088c7a854..ca73e0fd2f 100644 --- a/enginetest/script_queries.go +++ b/enginetest/script_queries.go @@ -15,6 +15,7 @@ package enginetest import ( + "github.com/dolthub/go-mysql-server/sql/analyzer" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" @@ -911,8 +912,8 @@ var ScriptTests = []ScriptTest{ Expected: []sql.Row{{false}, {false}, {false}}, }, { - Query: "SELECT col1 * cor0.col1 * 56 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col1", - Expected: []sql.Row{{332024}, {251384}, {145656}}, + Query: "SELECT col0, col1 FROM tab1 GROUP by col0;", + ExpectedErr: analyzer.ErrValidationGroupBy, }, }, }, diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index 26ebfc8b3e..56f153b0c8 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -252,11 +252,24 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (s func isValidAgg(validAggs []string, expr sql.Expression) bool { switch expr := expr.(type) { + case sql.Aggregation, *expression.Literal: + return true case *expression.Alias: return stringContains(validAggs, expr.String()) || isValidAgg(validAggs, expr.Child) - // Pretty much all expressions can be treated as an aggregations on GROUP BY + // cc: https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html + // Each part of the SelectExpr must refer to the aggregated columns in some way default: - return true + if stringContains(validAggs, expr.String()) { + return true + } + + for _, child := range expr.Children() { + if isValidAgg(validAggs, child) { + return true + } + } + + return false } } From 99d16bb38885b762343a18478a9548267ed3368c Mon Sep 17 00:00:00 2001 From: Vinai Rachakonda Date: Wed, 16 Jun 2021 09:25:46 -0400 Subject: [PATCH 5/7] fmt --- enginetest/script_queries.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/enginetest/script_queries.go b/enginetest/script_queries.go index ca73e0fd2f..0a5f8c6683 100644 --- a/enginetest/script_queries.go +++ b/enginetest/script_queries.go @@ -15,9 +15,10 @@ package enginetest import ( - "github.com/dolthub/go-mysql-server/sql/analyzer" "gopkg.in/src-d/go-errors.v1" + "github.com/dolthub/go-mysql-server/sql/analyzer" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/plan" ) @@ -912,7 +913,7 @@ var ScriptTests = []ScriptTest{ Expected: []sql.Row{{false}, {false}, {false}}, }, { - Query: "SELECT col0, col1 FROM tab1 GROUP by col0;", + Query: "SELECT col0, col1 FROM tab1 GROUP by col0;", ExpectedErr: analyzer.ErrValidationGroupBy, }, }, From 24e3e676f72a1c672f8dc4206b3c956178c4f651 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Wed, 16 Jun 2021 16:31:44 -0700 Subject: [PATCH 6/7] Added more complete checks of group by validation --- enginetest/script_queries.go | 20 +++++++++++++ sql/analyzer/validation_rules.go | 51 ++++++++++++++++++-------------- 2 files changed, 48 insertions(+), 23 deletions(-) diff --git a/enginetest/script_queries.go b/enginetest/script_queries.go index 0a5f8c6683..6bc34d1d93 100644 --- a/enginetest/script_queries.go +++ b/enginetest/script_queries.go @@ -896,10 +896,22 @@ var ScriptTests = []ScriptTest{ Query: "SELECT + + 97 FROM tab1 GROUP BY tab1.col1", Expected: []sql.Row{{97}, {97}, {97}}, }, + { + Query: "SELECT rand(10) FROM tab1 GROUP BY tab1.col1", + Expected: []sql.Row{{0.5660920659323543}, {0.5660920659323543}, {0.5660920659323543}}, + }, { Query: "SELECT ALL - cor0.col0 * + cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0", Expected: []sql.Row{{-2601}, {-7225}, {-8281}}, }, + { + Query: "SELECT cor0.col0 * cor0.col0 + cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 order by 1", + Expected: []sql.Row{{2652}, {7310}, {8372}}, + }, + { + Query: "SELECT - floor(cor0.col0) * ceil(cor0.col0) AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0", + Expected: []sql.Row{{-2601}, {-7225}, {-8281}}, + }, { Query: "SELECT col0 FROM tab1 AS cor0 GROUP BY cor0.col0", Expected: []sql.Row{{51}, {85}, {91}}, @@ -916,6 +928,14 @@ var ScriptTests = []ScriptTest{ Query: "SELECT col0, col1 FROM tab1 GROUP by col0;", ExpectedErr: analyzer.ErrValidationGroupBy, }, + { + Query: "SELECT col0, floor(col1) FROM tab1 GROUP by col0;", + ExpectedErr: analyzer.ErrValidationGroupBy, + }, + { + Query: "SELECT floor(cor0.col1) * ceil(cor0.col0) AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0", + ExpectedErr: analyzer.ErrValidationGroupBy, + }, }, }, { diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index 56f153b0c8..2b68edc8e6 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -49,7 +49,7 @@ var ( ErrValidationOrderBy = errors.NewKind("OrderBy does not support aggregation expressions") // ErrValidationGroupBy is returned when the aggregation expression does not // appear in the grouping columns. - ErrValidationGroupBy = errors.NewKind("GroupBy aggregate expression '%v' doesn't appear in the grouping columns") + ErrValidationGroupBy = errors.NewKind("expression '%v' doesn't appear in the group by expressions") // ErrValidationSchemaSource is returned when there is any column source // that does not match the table name. ErrValidationSchemaSource = errors.NewKind("one or more schema sources are empty") @@ -228,17 +228,14 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (s return n, nil } - var validAggs []string + var groupBys []string for _, expr := range n.GroupByExprs { - validAggs = append(validAggs, expr.String()) + groupBys = append(groupBys, expr.String()) } - // TODO: validate columns inside aggregations - // and allow any kind of expression that make use of the grouping - // columns. for _, expr := range n.SelectedExprs { if _, ok := expr.(sql.Aggregation); !ok { - if !isValidAgg(validAggs, expr) { + if !expressionReferencesOnlyGroupBys(groupBys, expr) { return nil, ErrValidationGroupBy.New(expr.String()) } } @@ -250,27 +247,35 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (s return n, nil } -func isValidAgg(validAggs []string, expr sql.Expression) bool { - switch expr := expr.(type) { - case sql.Aggregation, *expression.Literal: - return true - case *expression.Alias: - return stringContains(validAggs, expr.String()) || isValidAgg(validAggs, expr.Child) - // cc: https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html - // Each part of the SelectExpr must refer to the aggregated columns in some way - default: - if stringContains(validAggs, expr.String()) { +func expressionReferencesOnlyGroupBys(groupBys []string, expr sql.Expression) bool { + valid := true + sql.Inspect(expr, func(expr sql.Expression) bool { + switch expr := expr.(type) { + case nil, sql.Aggregation, *expression.Literal: + return false + case *expression.Alias, sql.FunctionExpression: + if stringContains(groupBys, expr.String()) { + return false + } return true - } - - for _, child := range expr.Children() { - if isValidAgg(validAggs, child) { + // cc: https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html + // Each part of the SelectExpr must refer to the aggregated columns in some way + // TODO: this isn't complete, it's overly restrictive. Dependant columns are fine to reference. + default: + if stringContains(groupBys, expr.String()) { return true } + + if len(expr.Children()) == 0 { + valid = false + return false + } + + return true } + }) - return false - } + return valid } func validateSchemaSource(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.Node, error) { From 97493b68e2d9551f3e03c2345460e306896db275 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Wed, 16 Jun 2021 16:36:47 -0700 Subject: [PATCH 7/7] Formatting --- enginetest/script_queries.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enginetest/script_queries.go b/enginetest/script_queries.go index 6bc34d1d93..8a21bddb20 100644 --- a/enginetest/script_queries.go +++ b/enginetest/script_queries.go @@ -933,7 +933,7 @@ var ScriptTests = []ScriptTest{ ExpectedErr: analyzer.ErrValidationGroupBy, }, { - Query: "SELECT floor(cor0.col1) * ceil(cor0.col0) AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0", + Query: "SELECT floor(cor0.col1) * ceil(cor0.col0) AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0", ExpectedErr: analyzer.ErrValidationGroupBy, }, },