Skip to content

Commit

Permalink
sql/analyzer: refactor and fix bugs in qualify_columns rule (src-d#706)
Browse files Browse the repository at this point in the history
sql/analyzer: refactor and fix bugs in qualify_columns rule
  • Loading branch information
ajnavarro committed May 13, 2019
2 parents 756e3bf + c059a12 commit 33c1da4
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 200 deletions.
4 changes: 4 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,10 @@ var queries = []struct {
`SELECT t.date_col FROM (SELECT CONVERT('2019-06-06 00:00:00', DATETIME) as date_col) t GROUP BY t.date_col`,
[]sql.Row{{time.Date(2019, time.June, 6, 0, 0, 0, 0, time.UTC)}},
},
{
`SELECT i AS foo FROM mytable ORDER BY mytable.i`,
[]sql.Row{{int64(1)}, {int64(2)}, {int64(3)}},
},
}

func TestQueries(t *testing.T) {
Expand Down
310 changes: 119 additions & 191 deletions sql/analyzer/resolve_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,228 +110,156 @@ type column interface {
}

func qualifyColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
span, _ := ctx.Span("qualify_columns")
defer span.Finish()

a.Log("qualify columns")
tables := make(map[string]sql.Node)
tableAliases := make(map[string]string)
colIndex := make(map[string][]string)

indexCols := func(table string, schema sql.Schema) {
for _, col := range schema {
name := strings.ToLower(col.Name)
colIndex[name] = append(colIndex[name], strings.ToLower(table))
}
}

var projects, seenProjects int
plan.Inspect(n, func(n sql.Node) bool {
if _, ok := n.(*plan.Project); ok {
projects++
}
return true
})

return n.TransformUp(func(n sql.Node) (sql.Node, error) {
a.Log("transforming node of type: %T", n)
switch n := n.(type) {
case *plan.TableAlias:
switch t := n.Child.(type) {
case *plan.ResolvedTable, *plan.UnresolvedTable:
name := strings.ToLower(t.(sql.Nameable).Name())
tableAliases[strings.ToLower(n.Name())] = name
default:
tables[strings.ToLower(n.Name())] = n.Child
indexCols(n.Name(), n.Schema())
}
case *plan.ResolvedTable, *plan.SubqueryAlias:
name := strings.ToLower(n.(sql.Nameable).Name())
tables[name] = n
indexCols(name, n.Schema())
}

exp, ok := n.(sql.Expressioner)
if !ok {
if !ok || n.Resolved() {
return n, nil
}

result, err := exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
a.Log("transforming expression of type: %T", e)
switch col := e.(type) {
case *expression.UnresolvedColumn:
// Skip this step for global and session variables
if isGlobalOrSessionColumn(col) {
return col, nil
}
columns := getNodeAvailableColumns(n)
tables := getNodeAvailableTables(n)

col = expression.NewUnresolvedQualifiedColumn(col.Table(), col.Name())
name := strings.ToLower(col.Name())
table := strings.ToLower(col.Table())
if table == "" {
// If a column has no table, it might be an alias
// defined in a child projection, so check that instead
// of incorrectly qualifying it.
if isDefinedInChildProject(n, col) {
return col, nil
}

tables := dedupStrings(colIndex[name])
switch len(tables) {
case 0:
// If there are no tables that have any column with the column
// name let's just return it as it is. This may be an alias, so
// we'll wait for the reorder of the projection.
return col, nil
case 1:
col = expression.NewUnresolvedQualifiedColumn(
tables[0],
col.Name(),
)
default:
if _, ok := n.(*plan.GroupBy); ok {
return expression.NewUnresolvedColumn(col.Name()), nil
}
return nil, ErrAmbiguousColumnName.New(col.Name(), strings.Join(tables, ", "))
}
} else {
if real, ok := tableAliases[table]; ok {
col = expression.NewUnresolvedQualifiedColumn(
real,
col.Name(),
)
}
return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
return qualifyExpression(e, columns, tables)
})
})
}

if _, ok := tables[col.Table()]; !ok {
if len(tables) == 0 {
return nil, sql.ErrTableNotFound.New(col.Table())
}
func qualifyExpression(
e sql.Expression,
columns map[string][]string,
tables map[string]string,
) (sql.Expression, error) {
switch col := e.(type) {
case column:
// Skip this step for global and session variables
if isGlobalOrSessionColumn(col) {
return col, nil
}

similar := similartext.FindFromMap(tables, col.Table())
return nil, sql.ErrTableNotFound.New(col.Table() + similar)
}
name, table := strings.ToLower(col.Name()), strings.ToLower(col.Table())
availableTables := dedupStrings(columns[name])
if table != "" {
table, ok := tables[table]
if !ok {
if len(tables) == 0 {
return nil, sql.ErrTableNotFound.New(col.Table())
}

a.Log("column %q was qualified with table %q", col.Name(), col.Table())
similar := similartext.FindFromMap(tables, col.Table())
return nil, sql.ErrTableNotFound.New(col.Table() + similar)
}

// If the table exists but it's not available for this node it
// means some work is still needed, so just return the column
// and let it be resolved in the next pass.
if !stringContains(availableTables, table) {
return col, nil
case *expression.Star:
if col.Table != "" {
if real, ok := tableAliases[strings.ToLower(col.Table)]; ok {
col = expression.NewQualifiedStar(real)
}
}

if _, ok := tables[strings.ToLower(col.Table)]; !ok {
return nil, sql.ErrTableNotFound.New(col.Table)
}
return expression.NewUnresolvedQualifiedColumn(table, col.Name()), nil
}

return col, nil
}
default:
// If any other kind of expression has a star, just replace it
// with an unqualified star because it cannot be expanded.
return e.TransformUp(func(e sql.Expression) (sql.Expression, error) {
if _, ok := e.(*expression.Star); ok {
return expression.NewStar(), nil
}
return e, nil
})
switch len(availableTables) {
case 0:
// If there are no tables that have any column with the column
// name let's just return it as it is. This may be an alias, so
// we'll wait for the reorder of the projection.
return col, nil
case 1:
return expression.NewUnresolvedQualifiedColumn(
availableTables[0],
col.Name(),
), nil
default:
return nil, ErrAmbiguousColumnName.New(col.Name(), strings.Join(availableTables, ", "))
}
case *expression.Star:
if col.Table != "" {
if real, ok := tables[strings.ToLower(col.Table)]; ok {
col = expression.NewQualifiedStar(real)
}

if _, ok := tables[strings.ToLower(col.Table)]; !ok {
return nil, sql.ErrTableNotFound.New(col.Table)
}
}
return col, nil
default:
// If any other kind of expression has a star, just replace it
// with an unqualified star because it cannot be expanded.
return e.TransformUp(func(e sql.Expression) (sql.Expression, error) {
if _, ok := e.(*expression.Star); ok {
return expression.NewStar(), nil
}
return e, nil
})
}
}

if err != nil {
return nil, err
}
// getNodeAvailableColumns returns a fresh index of the columns reachable
// from the children of n. The map is filled in by getColumnsInNodes, which
// keys it by column name and stores the table names recorded for each column.
func getNodeAvailableColumns(n sql.Node) map[string][]string {
	index := map[string][]string{}
	children := n.Children()
	getColumnsInNodes(children, index)
	return index
}

// We should ignore the topmost project, because some nodes are
// reordered, such as Sort, and they would not be resolved well.
if n, ok := result.(*plan.Project); ok && projects-seenProjects > 1 {
seenProjects++

// We need to modify the indexed columns to only contain what is
// projected in this project. If the column is not qualified by any
// table, just keep the ones that are currently in the index.
// If it is, then just make those tables available for the column.
// If we don't do this, columns that are not projected will be
// available in this step and may cause false errors or unintended
// results.
var projected = make(map[string][]string)
for _, p := range n.Projections {
var table, col string
switch p := p.(type) {
case column:
table = p.Table()
col = p.Name()
default:
continue
}
func getColumnsInNodes(nodes []sql.Node, columns map[string][]string) {
indexCol := func(table, col string) {
col = strings.ToLower(col)
columns[col] = append(columns[col], strings.ToLower(table))
}

col = strings.ToLower(col)
table = strings.ToLower(table)
if table != "" {
projected[col] = append(projected[col], table)
} else {
projected[col] = append(projected[col], colIndex[col]...)
}
indexExpressions := func(exprs []sql.Expression) {
for _, e := range exprs {
switch e := e.(type) {
case *expression.Alias:
indexCol("", e.Name())
case *expression.GetField:
indexCol(e.Table(), e.Name())
case *expression.UnresolvedColumn:
indexCol(e.Table(), e.Name())
}
}
}

colIndex = make(map[string][]string)
for col, tables := range projected {
colIndex[col] = dedupStrings(tables)
for _, node := range nodes {
switch n := node.(type) {
case *plan.ResolvedTable, *plan.SubqueryAlias:
for _, col := range n.Schema() {
indexCol(col.Source, col.Name)
}
case *plan.Project:
indexExpressions(n.Projections)
case *plan.GroupBy:
indexExpressions(n.Aggregate)
default:
getColumnsInNodes(n.Children(), columns)
}

return result, nil
})
}
}

func isDefinedInChildProject(n sql.Node, col *expression.UnresolvedColumn) bool {
var x sql.Node
for _, child := range n.Children() {
plan.Inspect(child, func(n sql.Node) bool {
func getNodeAvailableTables(n sql.Node) map[string]string {
var tables = make(map[string]string)
for _, c := range n.Children() {
plan.Inspect(c, func(n sql.Node) bool {
switch n := n.(type) {
case *plan.SubqueryAlias:
case *plan.SubqueryAlias, *plan.ResolvedTable:
name := strings.ToLower(n.(sql.Nameable).Name())
tables[name] = name
return false
case *plan.Project, *plan.GroupBy:
if x == nil {
x = n
case *plan.TableAlias:
switch t := n.Child.(type) {
case *plan.ResolvedTable, *plan.UnresolvedTable:
name := strings.ToLower(t.(sql.Nameable).Name())
alias := strings.ToLower(n.Name())
tables[alias] = name
}
return false
default:
return true
}
})

if x != nil {
break
}
}

if x == nil {
return false
}

var found bool
for _, expr := range x.(sql.Expressioner).Expressions() {
switch expr := expr.(type) {
case *expression.Alias:
if strings.ToLower(expr.Name()) == strings.ToLower(col.Name()) {
found = true
}
case column:
if strings.ToLower(expr.Name()) == strings.ToLower(col.Name()) &&
strings.ToLower(expr.Table()) == strings.ToLower(col.Table()) {
found = true
}
}

if found {
break
}
return true
})
}

return found
return tables
}

var errGlobalVariablesNotSupported = errors.NewKind("can't resolve global variable, %s was requested")
Expand Down Expand Up @@ -659,6 +587,6 @@ func dedupStrings(in []string) []string {
return result
}

func isGlobalOrSessionColumn(col *expression.UnresolvedColumn) bool {
// isGlobalOrSessionColumn reports whether either the column name or its
// table qualifier starts with "@@". The table is only consulted when the
// name does not already carry the prefix.
func isGlobalOrSessionColumn(col column) bool {
	const variablePrefix = "@@"
	if strings.HasPrefix(col.Name(), variablePrefix) {
		return true
	}
	return strings.HasPrefix(col.Table(), variablePrefix)
}
8 changes: 4 additions & 4 deletions sql/analyzer/resolve_columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ func TestQualifyColumns(t *testing.T) {
require := require.New(t)
f := getRule("qualify_columns")

table := mem.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32}})
table2 := mem.NewTable("mytable2", sql.Schema{{Name: "i", Type: sql.Int32}})
sessionTable := mem.NewTable("@@session", sql.Schema{{Name: "autocommit", Type: sql.Int64}})
globalTable := mem.NewTable("@@global", sql.Schema{{Name: "max_allowed_packet", Type: sql.Int64}})
table := mem.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32, Source: "mytable"}})
table2 := mem.NewTable("mytable2", sql.Schema{{Name: "i", Type: sql.Int32, Source: "mytable2"}})
sessionTable := mem.NewTable("@@session", sql.Schema{{Name: "autocommit", Type: sql.Int64, Source: "@@session"}})
globalTable := mem.NewTable("@@global", sql.Schema{{Name: "max_allowed_packet", Type: sql.Int64, Source: "@@global"}})

node := plan.NewProject(
[]sql.Expression{
Expand Down
Loading

0 comments on commit 33c1da4

Please sign in to comment.