Skip to content

Commit

Permalink
sql: implement struct types
Browse files Browse the repository at this point in the history
Depends on src-d#720

This PR implements struct types, which will be serialized to JSON
before being sent to the mysql client using the mysql wire proto.

Internally, structs are just `map[string]interface{}`, but they can
actually be anything that's convertible to that, because of the way
the Convert method of the struct type is implemented. That means,
the result of an UDF or a table that returns a struct may be an
actual Go struct and it will then be transformed into the internal
map. It does have a penalty, though, because structs require
encoding to JSON and then decoding into `map[string]interface{}`.
Structs have a schema, identical to a table schema, except their
`Source` will always be empty.

Resolution of columns has also been slightly change in order to
resolve getting fields from structs using the `.` operator, which
required some trade-offs in some rules, such as not erroring
anymore in `qualify_columns` when the table is not found. That
error was delegated to `resolve_columns` in order to make resolution
possible, as the syntax is a bit ambiguous.
The advantage of using dot is the fact that no changes have to be made
to the parser in order for it to work.

Signed-off-by: Miguel Molina <miguel@erizocosmi.co>
  • Loading branch information
erizocosmico committed May 30, 2019
1 parent 814b219 commit 385c7a3
Show file tree
Hide file tree
Showing 10 changed files with 508 additions and 14 deletions.
58 changes: 58 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2406,6 +2406,64 @@ func TestGenerators(t *testing.T) {
}
}

var structQueries = []struct {
query string
expected []sql.Row
}{
{
`SELECT s.i, t.s.t FROM t ORDER BY s.i`,
[]sql.Row{
{int64(1), "first"},
{int64(2), "second"},
{int64(3), "third"},
},
},
{
`SELECT s.i, s.t FROM t ORDER BY s.i`,
[]sql.Row{
{int64(1), "first"},
{int64(2), "second"},
{int64(3), "third"},
},
},
{
`SELECT s.i, COUNT(*) FROM t GROUP BY s.i`,
[]sql.Row{
{int64(1), int64(1)},
{int64(2), int64(1)},
{int64(3), int64(1)},
},
},
}

func TestStructs(t *testing.T) {
schema := sql.Schema{
{Name: "i", Type: sql.Int64},
{Name: "t", Type: sql.Text},
}
table := mem.NewPartitionedTable("t", sql.Schema{
{Name: "s", Type: sql.Struct(schema), Source: "t"},
}, testNumPartitions)

insertRows(
t, table,
sql.NewRow(map[string]interface{}{"i": int64(1), "t": "first"}),
sql.NewRow(map[string]interface{}{"i": int64(2), "t": "second"}),
sql.NewRow(map[string]interface{}{"i": int64(3), "t": "third"}),
)

db := mem.NewDatabase("db")
db.AddTable("t", table)

catalog := sql.NewCatalog()
catalog.AddDatabase(db)
e := sqle.New(catalog, analyzer.NewDefault(catalog), new(sqle.Config))

for _, q := range structQueries {
testQuery(t, e, q.query, q.expected)
}
}

func insertRows(t *testing.T, table sql.Inserter, rows ...sql.Row) {
t.Helper()

Expand Down
95 changes: 86 additions & 9 deletions sql/analyzer/resolve_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"strings"

"gopkg.in/src-d/go-errors.v1"
"gopkg.in/src-d/go-mysql-server.v0/internal/similartext"
"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
Expand Down Expand Up @@ -147,17 +146,24 @@ func qualifyExpression(
return col, nil
}

name, table := strings.ToLower(col.Name()), strings.ToLower(col.Table())
name, tableName := strings.ToLower(col.Name()), strings.ToLower(col.Table())
availableTables := dedupStrings(columns[name])
if table != "" {
table, ok := tables[table]
if tableName != "" {
table, ok := tables[tableName]
if !ok {
if len(tables) == 0 {
return nil, sql.ErrTableNotFound.New(col.Table())
// If the table does not exist but the column does, it may be a
// struct field access.
if columnExists(columns, tableName) {
return expression.NewUnresolvedField(
expression.NewUnresolvedColumn(col.Table()),
col.Name(),
), nil
}

similar := similartext.FindFromMap(tables, col.Table())
return nil, sql.ErrTableNotFound.New(col.Table() + similar)
// If it cannot be resolved, then pass along and let it fail
// somewhere else. Maybe we're missing some steps before this
// can be resolved.
return col, nil
}

// If the table exists but it's not available for this node it
Expand Down Expand Up @@ -207,6 +213,15 @@ func qualifyExpression(
}
}

func columnExists(columns map[string][]string, col string) bool {
for c := range columns {
if strings.ToLower(c) == strings.ToLower(col) {
return true
}
}
return false
}

func getNodeAvailableColumns(n sql.Node) map[string][]string {
var columns = make(map[string][]string)
getColumnsInNodes(n.Children(), columns)
Expand Down Expand Up @@ -369,7 +384,23 @@ func resolveColumnExpression(
return &deferredColumn{uc}, nil
default:
if table != "" {
return nil, ErrColumnTableNotFound.New(e.Table(), e.Name())
if isStructField(uc, columns) {
return expression.NewUnresolvedField(
expression.NewUnresolvedColumn(uc.Table()),
uc.Name(),
), nil
}

// If we manage to find any column with the given table, it's because
// the column does not exist.
for col := range columns {
if col.table == table {
return nil, ErrColumnTableNotFound.New(e.Table(), e.Name())
}
}

// In any other case, it's the table the one that does not exist.
return nil, sql.ErrTableNotFound.New(e.Table())
}

return nil, ErrColumnNotFound.New(e.Name())
Expand All @@ -385,6 +416,52 @@ func resolveColumnExpression(
), nil
}

func isStructField(c column, columns map[tableCol]indexedCol) bool {
for _, col := range columns {
if strings.ToLower(c.Table()) == strings.ToLower(col.Name) &&
sql.Field(col.Type, c.Name()) != nil {
return true
}
}
return false
}

var errFieldNotFound = errors.NewKind("field %s not found on struct %s")

func resolveStructFields(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
span, ctx := ctx.Span("resolve_struct_fields")
defer span.Finish()

return n.TransformUp(func(n sql.Node) (sql.Node, error) {
if n.Resolved() {
return n, nil
}

expressioner, ok := n.(sql.Expressioner)
if !ok {
return n, nil
}

return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
f, ok := e.(*expression.UnresolvedField)
if !ok {
return e, nil
}

if !f.Struct.Resolved() {
return e, nil
}

field := sql.Field(f.Struct.Type(), f.Name)
if field == nil {
return nil, errFieldNotFound.New(f.Name, f.Struct)
}

return expression.NewGetStructField(f.Struct, f.Name), nil
})
})
}

// resolveGroupingColumns reorders the aggregation in a groupby so aliases
// defined in it can be resolved in the grouping of the groupby. To do so,
// all aliases are pushed down to a projection node under the group by.
Expand Down
20 changes: 15 additions & 5 deletions sql/analyzer/resolve_columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,24 @@ func TestQualifyColumns(t *testing.T) {

node = plan.NewProject(
[]sql.Expression{
expression.NewUnresolvedQualifiedColumn("foo", "i"),
expression.NewUnresolvedQualifiedColumn("i", "some_field"),
},
plan.NewTableAlias("a", plan.NewResolvedTable(table)),
plan.NewResolvedTable(table),
)

_, err = f.Apply(sql.NewEmptyContext(), nil, node)
require.Error(err)
require.True(sql.ErrTableNotFound.Is(err))
expected = plan.NewProject(
[]sql.Expression{
expression.NewUnresolvedField(
expression.NewUnresolvedColumn("i"),
"some_field",
),
},
plan.NewResolvedTable(table),
)

result, err = f.Apply(sql.NewEmptyContext(), nil, node)
require.NoError(err)
require.Equal(expected, result)

node = plan.NewProject(
[]sql.Expression{
Expand Down
1 change: 1 addition & 0 deletions sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ var DefaultRules = []Rule{
{"resolve_grouping_columns", resolveGroupingColumns},
{"qualify_columns", qualifyColumns},
{"resolve_columns", resolveColumns},
{"resolve_struct_fields", resolveStructFields},
{"resolve_database", resolveDatabase},
{"resolve_star", resolveStar},
{"resolve_functions", resolveFunctions},
Expand Down
71 changes: 71 additions & 0 deletions sql/expression/get_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,74 @@ func (f *GetSessionField) String() string { return "@@" + f.name }
func (f *GetSessionField) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
return fn(f)
}

// GetStructField is an expression to get a field from a struct column.
type GetStructField struct {
Struct sql.Expression
Name string
}

// NewGetStructField creates a new GetStructField expression.
func NewGetStructField(s sql.Expression, fieldName string) *GetStructField {
return &GetStructField{s, fieldName}
}

// Children implements the Expression interface.
func (p *GetStructField) Children() []sql.Expression {
return []sql.Expression{p.Struct}
}

// Resolved implements the Expression interface.
func (p *GetStructField) Resolved() bool {
return p.Struct.Resolved()
}

func (p *GetStructField) column() *sql.Column {
return sql.Field(p.Struct.Type(), p.Name)
}

// IsNullable returns whether the field is nullable or not.
func (p *GetStructField) IsNullable() bool {
return p.Struct.IsNullable() || p.column().Nullable
}

// Type returns the type of the field.
func (p *GetStructField) Type() sql.Type {
return p.column().Type
}

// Eval implements the Expression interface.
func (p *GetStructField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
s, err := p.Struct.Eval(ctx, row)
if err != nil {
return nil, err
}

if s == nil {
return nil, nil
}

s, err = p.Struct.Type().Convert(s)
if err != nil {
return nil, err
}

if val, ok := s.(map[string]interface{})[p.Name]; ok {
return p.Type().Convert(val)
}

return nil, nil
}

// TransformUp implements the Expression interface.
func (p *GetStructField) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
s, err := p.Struct.TransformUp(f)
if err != nil {
return nil, err
}
return f(NewGetStructField(s, p.Name))
}

func (p *GetStructField) String() string {
return fmt.Sprintf("%s.%s", p.Struct, p.Name)
}
49 changes: 49 additions & 0 deletions sql/expression/unresolved.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,52 @@ func (uf *UnresolvedFunction) TransformUp(f sql.TransformExprFunc) (sql.Expressi

return f(NewUnresolvedFunction(uf.name, uf.IsAggregate, rc...))
}

// UnresolvedField is an unresolved expression to get a field from a struct column.
type UnresolvedField struct {
Struct sql.Expression
Name string
}

// NewUnresolvedField creates a new UnresolvedField expression.
func NewUnresolvedField(s sql.Expression, fieldName string) *UnresolvedField {
return &UnresolvedField{s, fieldName}
}

// Children implements the Expression interface.
func (p *UnresolvedField) Children() []sql.Expression {
return []sql.Expression{p.Struct}
}

// Resolved implements the Expression interface.
func (p *UnresolvedField) Resolved() bool {
return false
}

// IsNullable returns whether the field is nullable or not.
func (p *UnresolvedField) IsNullable() bool {
panic("unresolved field is a placeholder node, but IsNullable was called")
}

// Type returns the type of the field.
func (p *UnresolvedField) Type() sql.Type {
panic("unresolved field is a placeholder node, but Type was called")
}

// Eval implements the Expression interface.
func (p *UnresolvedField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
panic("unresolved field is a placeholder node, but Eval was called")
}

// TransformUp implements the Expression interface.
func (p *UnresolvedField) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
s, err := p.Struct.TransformUp(f)
if err != nil {
return nil, err
}
return f(NewUnresolvedField(s, p.Name))
}

func (p *UnresolvedField) String() string {
return fmt.Sprintf("%s.%s", p.Struct, p.Name)
}
14 changes: 14 additions & 0 deletions sql/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,20 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) {
return expression.NewLiteral(nil, sql.Null), nil
case *sqlparser.ColName:
if !v.Qualifier.IsEmpty() {
// If we find something of the form A.B.C we're going to treat it
// as a struct field access.
// TODO: this should be handled better when GetFields support being
// qualified with the database.
if !v.Qualifier.Qualifier.IsEmpty() {
return expression.NewUnresolvedField(
expression.NewUnresolvedQualifiedColumn(
v.Qualifier.Qualifier.String(),
v.Qualifier.Name.String(),
),
v.Name.String(),
), nil
}

return expression.NewUnresolvedQualifiedColumn(
v.Qualifier.Name.String(),
v.Name.String(),
Expand Down
Loading

0 comments on commit 385c7a3

Please sign in to comment.