diff --git a/enginetest/script_queries.go b/enginetest/script_queries.go index 503c5fa1a1..a3c33ba185 100644 --- a/enginetest/script_queries.go +++ b/enginetest/script_queries.go @@ -17,6 +17,8 @@ package enginetest import ( "gopkg.in/src-d/go-errors.v1" + "github.com/dolthub/go-mysql-server/sql/analyzer" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/plan" ) @@ -964,6 +966,50 @@ var ScriptTests = []ScriptTest{ Query: "SELECT SUM(DISTINCT POWER(v1, 2)) FROM mytable", Expected: []sql.Row{{float64(5)}}, }, + { + Query: "SELECT + + 97 FROM tab1 GROUP BY tab1.col1", + Expected: []sql.Row{{97}, {97}, {97}}, + }, + { + Query: "SELECT rand(10) FROM tab1 GROUP BY tab1.col1", + Expected: []sql.Row{{0.5660920659323543}, {0.5660920659323543}, {0.5660920659323543}}, + }, + { + Query: "SELECT ALL - cor0.col0 * + cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0", + Expected: []sql.Row{{-2601}, {-7225}, {-8281}}, + }, + { + Query: "SELECT cor0.col0 * cor0.col0 + cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 order by 1", + Expected: []sql.Row{{2652}, {7310}, {8372}}, + }, + { + Query: "SELECT - floor(cor0.col0) * ceil(cor0.col0) AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0", + Expected: []sql.Row{{-2601}, {-7225}, {-8281}}, + }, + { + Query: "SELECT col0 FROM tab1 AS cor0 GROUP BY cor0.col0", + Expected: []sql.Row{{51}, {85}, {91}}, + }, + { + Query: "SELECT - cor0.col0 FROM tab1 AS cor0 GROUP BY cor0.col0", + Expected: []sql.Row{{-51}, {-85}, {-91}}, + }, + { + Query: "SELECT col0 BETWEEN 2 and 4 from tab1 group by col0", + Expected: []sql.Row{{false}, {false}, {false}}, + }, + { + Query: "SELECT col0, col1 FROM tab1 GROUP by col0;", + ExpectedErr: analyzer.ErrValidationGroupBy, + }, + { + Query: "SELECT col0, floor(col1) FROM tab1 GROUP by col0;", + ExpectedErr: analyzer.ErrValidationGroupBy, + }, + { + Query: "SELECT floor(cor0.col1) * ceil(cor0.col0) AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0", + ExpectedErr: analyzer.ErrValidationGroupBy, + }, }, }, { diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index 1587227a61..2b68edc8e6 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -49,7 +49,7 @@ var ( ErrValidationOrderBy = errors.NewKind("OrderBy does not support aggregation expressions") // ErrValidationGroupBy is returned when the aggregation expression does not // appear in the grouping columns. - ErrValidationGroupBy = errors.NewKind("GroupBy aggregate expression '%v' doesn't appear in the grouping columns") + ErrValidationGroupBy = errors.NewKind("expression '%v' doesn't appear in the group by expressions") // ErrValidationSchemaSource is returned when there is any column source // that does not match the table name. ErrValidationSchemaSource = errors.NewKind("one or more schema sources are empty") @@ -228,17 +228,14 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (s return n, nil } - var validAggs []string + var groupBys []string for _, expr := range n.GroupByExprs { - validAggs = append(validAggs, expr.String()) + groupBys = append(groupBys, expr.String()) } - // TODO: validate columns inside aggregations - // and allow any kind of expression that make use of the grouping - // columns. for _, expr := range n.SelectedExprs { if _, ok := expr.(sql.Aggregation); !ok { - if !isValidAgg(validAggs, expr) { + if !expressionReferencesOnlyGroupBys(groupBys, expr) { return nil, ErrValidationGroupBy.New(expr.String()) } } @@ -250,15 +247,35 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (s return n, nil } -func isValidAgg(validAggs []string, expr sql.Expression) bool { - switch expr := expr.(type) { - case sql.Aggregation: - return true - case *expression.Alias: - return stringContains(validAggs, expr.String()) || isValidAgg(validAggs, expr.Child) - default: - return stringContains(validAggs, expr.String()) - } +func expressionReferencesOnlyGroupBys(groupBys []string, expr sql.Expression) bool { + valid := true + sql.Inspect(expr, func(expr sql.Expression) bool { + switch expr := expr.(type) { + case nil, sql.Aggregation, *expression.Literal: + return false + case *expression.Alias, sql.FunctionExpression: + if stringContains(groupBys, expr.String()) { + return false + } + return true + // cc: https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html + // Each part of the SelectExpr must refer to the aggregated columns in some way + // TODO: this isn't complete, it's overly restrictive. Dependant columns are fine to reference. + default: + if stringContains(groupBys, expr.String()) { + return true + } + + if len(expr.Children()) == 0 { + valid = false + return false + } + + return true + } + }) + + return valid } func validateSchemaSource(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.Node, error) {