Skip to content

Commit

Permalink
Merge pull request #470 from dolthub/vinai/group-by-aggs
Browse files Browse the repository at this point in the history
More accurate validation of GROUP BY expressions
  • Loading branch information
zachmu committed Jun 16, 2021
2 parents a25dd40 + 97493b6 commit 8571009
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 16 deletions.
46 changes: 46 additions & 0 deletions enginetest/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package enginetest
import (
"gopkg.in/src-d/go-errors.v1"

"github.com/dolthub/go-mysql-server/sql/analyzer"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/plan"
)
Expand Down Expand Up @@ -964,6 +966,50 @@ var ScriptTests = []ScriptTest{
Query: "SELECT SUM(DISTINCT POWER(v1, 2)) FROM mytable",
Expected: []sql.Row{{float64(5)}},
},
{
Query: "SELECT + + 97 FROM tab1 GROUP BY tab1.col1",
Expected: []sql.Row{{97}, {97}, {97}},
},
{
Query: "SELECT rand(10) FROM tab1 GROUP BY tab1.col1",
Expected: []sql.Row{{0.5660920659323543}, {0.5660920659323543}, {0.5660920659323543}},
},
{
Query: "SELECT ALL - cor0.col0 * + cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0",
Expected: []sql.Row{{-2601}, {-7225}, {-8281}},
},
{
Query: "SELECT cor0.col0 * cor0.col0 + cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 order by 1",
Expected: []sql.Row{{2652}, {7310}, {8372}},
},
{
Query: "SELECT - floor(cor0.col0) * ceil(cor0.col0) AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0",
Expected: []sql.Row{{-2601}, {-7225}, {-8281}},
},
{
Query: "SELECT col0 FROM tab1 AS cor0 GROUP BY cor0.col0",
Expected: []sql.Row{{51}, {85}, {91}},
},
{
Query: "SELECT - cor0.col0 FROM tab1 AS cor0 GROUP BY cor0.col0",
Expected: []sql.Row{{-51}, {-85}, {-91}},
},
{
Query: "SELECT col0 BETWEEN 2 and 4 from tab1 group by col0",
Expected: []sql.Row{{false}, {false}, {false}},
},
{
Query: "SELECT col0, col1 FROM tab1 GROUP by col0;",
ExpectedErr: analyzer.ErrValidationGroupBy,
},
{
Query: "SELECT col0, floor(col1) FROM tab1 GROUP by col0;",
ExpectedErr: analyzer.ErrValidationGroupBy,
},
{
Query: "SELECT floor(cor0.col1) * ceil(cor0.col0) AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0",
ExpectedErr: analyzer.ErrValidationGroupBy,
},
},
},
{
Expand Down
49 changes: 33 additions & 16 deletions sql/analyzer/validation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ var (
ErrValidationOrderBy = errors.NewKind("OrderBy does not support aggregation expressions")
// ErrValidationGroupBy is returned when a selected expression does not
// appear in the group by expressions.
ErrValidationGroupBy = errors.NewKind("GroupBy aggregate expression '%v' doesn't appear in the grouping columns")
ErrValidationGroupBy = errors.NewKind("expression '%v' doesn't appear in the group by expressions")
// ErrValidationSchemaSource is returned when there is any column source
// that does not match the table name.
ErrValidationSchemaSource = errors.NewKind("one or more schema sources are empty")
Expand Down Expand Up @@ -228,17 +228,14 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (s
return n, nil
}

var validAggs []string
var groupBys []string
for _, expr := range n.GroupByExprs {
validAggs = append(validAggs, expr.String())
groupBys = append(groupBys, expr.String())
}

// TODO: validate columns inside aggregations
// and allow any kind of expression that make use of the grouping
// columns.
for _, expr := range n.SelectedExprs {
if _, ok := expr.(sql.Aggregation); !ok {
if !isValidAgg(validAggs, expr) {
if !expressionReferencesOnlyGroupBys(groupBys, expr) {
return nil, ErrValidationGroupBy.New(expr.String())
}
}
Expand All @@ -250,15 +247,35 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (s
return n, nil
}

func isValidAgg(validAggs []string, expr sql.Expression) bool {
switch expr := expr.(type) {
case sql.Aggregation:
return true
case *expression.Alias:
return stringContains(validAggs, expr.String()) || isValidAgg(validAggs, expr.Child)
default:
return stringContains(validAggs, expr.String())
}
// expressionReferencesOnlyGroupBys reports whether expr is built only from
// pieces that are allowed in the projection of a GROUP BY query.
//
// groupBys holds the String() form of every GROUP BY expression. The
// expression tree is walked with sql.Inspect and every node is classified:
//
//   - nil, aggregations and literals are always acceptable and are not
//     descended into;
//   - any node whose String() form matches a GROUP BY expression is acceptable
//     as a whole, so its children are not examined;
//   - any other node with no children (for example a column that was not
//     grouped on) makes the whole expression invalid;
//   - any other node with children is judged by its children.
//
// It returns true when no invalid node was found.
func expressionReferencesOnlyGroupBys(groupBys []string, expr sql.Expression) bool {
	valid := true
	sql.Inspect(expr, func(expr sql.Expression) bool {
		// A violation has already been found; there is nothing left to learn
		// from the rest of the tree.
		if !valid {
			return false
		}

		switch expr := expr.(type) {
		case nil, sql.Aggregation, *expression.Literal:
			return false
		case *expression.Alias, sql.FunctionExpression:
			// An alias or function that is itself a GROUP BY expression is
			// valid as a unit; otherwise its children decide.
			return !stringContains(groupBys, expr.String())
		// cc: https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html
		// Each part of the SelectExpr must refer to the aggregated columns in some way
		// TODO: this isn't complete, it's overly restrictive. Dependent columns are fine to reference.
		default:
			// The expression matches a GROUP BY expression, so it is valid as
			// a whole. Stop descending: its children (e.g. the bare column in
			// "col0 + 1") need not appear in the GROUP BY list on their own.
			if stringContains(groupBys, expr.String()) {
				return false
			}

			// A leaf that is not a GROUP BY expression references something
			// outside the grouping.
			if len(expr.Children()) == 0 {
				valid = false
				return false
			}

			return true
		}
	})

	return valid
}

func validateSchemaSource(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.Node, error) {
Expand Down

0 comments on commit 8571009

Please sign in to comment.