diff --git a/engine_test.go b/engine_test.go
index 851b25f52..5b10ff14f 100644
--- a/engine_test.go
+++ b/engine_test.go
@@ -2406,6 +2406,68 @@ func TestGenerators(t *testing.T) {
 	}
 }
 
+// structQueries holds the queries run by TestStructs together with the rows
+// each of them is expected to return.
+var structQueries = []struct {
+	query    string
+	expected []sql.Row
+}{
+	{
+		`SELECT s.i, t.s.t FROM t ORDER BY s.i`,
+		[]sql.Row{
+			{int64(1), "first"},
+			{int64(2), "second"},
+			{int64(3), "third"},
+		},
+	},
+	{
+		`SELECT s.i, s.t FROM t ORDER BY s.i`,
+		[]sql.Row{
+			{int64(1), "first"},
+			{int64(2), "second"},
+			{int64(3), "third"},
+		},
+	},
+	{
+		`SELECT s.i, COUNT(*) FROM t GROUP BY s.i`,
+		[]sql.Row{
+			{int64(1), int64(1)},
+			{int64(2), int64(1)},
+			{int64(3), int64(1)},
+		},
+	},
+}
+
+// TestStructs runs structQueries against a partitioned in-memory table whose
+// only column, "s", is a struct with an integer field "i" and a text field "t".
+func TestStructs(t *testing.T) {
+	schema := sql.Schema{
+		{Name: "i", Type: sql.Int64},
+		{Name: "t", Type: sql.Text},
+	}
+	table := mem.NewPartitionedTable("t", sql.Schema{
+		{Name: "s", Type: sql.Struct(schema), Source: "t"},
+	}, testNumPartitions)
+
+	insertRows(
+		t, table,
+		sql.NewRow(map[string]interface{}{"i": int64(1), "t": "first"}),
+		sql.NewRow(map[string]interface{}{"i": int64(2), "t": "second"}),
+		sql.NewRow(map[string]interface{}{"i": int64(3), "t": "third"}),
+	)
+
+	db := mem.NewDatabase("db")
+	db.AddTable("t", table)
+
+	catalog := sql.NewCatalog()
+	catalog.AddDatabase(db)
+	e := sqle.New(catalog, analyzer.NewDefault(catalog), new(sqle.Config))
+
+	for _, q := range structQueries {
+		testQuery(t, e, q.query, q.expected)
+	}
+}
+
 func insertRows(t *testing.T, table sql.Inserter, rows ...sql.Row) {
 	t.Helper()
 
diff --git a/sql/analyzer/resolve_columns.go b/sql/analyzer/resolve_columns.go
index 1fa006e07..c5471293d 100644
--- a/sql/analyzer/resolve_columns.go
+++ b/sql/analyzer/resolve_columns.go
@@ -6,7 +6,6 @@ import (
 	"strings"
 
 	"gopkg.in/src-d/go-errors.v1"
-	"gopkg.in/src-d/go-mysql-server.v0/internal/similartext"
 	"gopkg.in/src-d/go-mysql-server.v0/sql"
 	"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
 	"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
@@ -147,17 +146,24 @@ func qualifyExpression(
 			return col, nil
 		}
 
-		name, table := strings.ToLower(col.Name()), strings.ToLower(col.Table())
+		name, tableName := strings.ToLower(col.Name()), strings.ToLower(col.Table())
 		availableTables := dedupStrings(columns[name])
-		if table != "" {
-			table, ok := tables[table]
+		if tableName != "" {
+			table, ok := tables[tableName]
 			if !ok {
-				if len(tables) == 0 {
-					return nil, sql.ErrTableNotFound.New(col.Table())
+				// If the table does not exist but the column does, it may be a
+				// struct field access.
+				if columnExists(columns, tableName) {
+					return expression.NewUnresolvedField(
+						expression.NewUnresolvedColumn(col.Table()),
+						col.Name(),
+					), nil
 				}
 
-				similar := similartext.FindFromMap(tables, col.Table())
-				return nil, sql.ErrTableNotFound.New(col.Table() + similar)
+				// If it cannot be resolved, then pass along and let it fail
+				// somewhere else. Maybe we're missing some steps before this
+				// can be resolved.
+				return col, nil
 			}
 
 			// If the table exists but it's not available for this node it
@@ -207,6 +213,17 @@ func qualifyExpression(
 	}
 }
 
+// columnExists reports whether any key of columns matches col, ignoring
+// case.
+func columnExists(columns map[string][]string, col string) bool {
+	for c := range columns {
+		if strings.EqualFold(c, col) {
+			return true
+		}
+	}
+	return false
+}
+
 func getNodeAvailableColumns(n sql.Node) map[string][]string {
 	var columns = make(map[string][]string)
 	getColumnsInNodes(n.Children(), columns)
@@ -369,7 +386,23 @@ func resolveColumnExpression(
 			return &deferredColumn{uc}, nil
 		default:
 			if table != "" {
-				return nil, ErrColumnTableNotFound.New(e.Table(), e.Name())
+				if isStructField(uc, columns) {
+					return expression.NewUnresolvedField(
+						expression.NewUnresolvedColumn(uc.Table()),
+						uc.Name(),
+					), nil
+				}
+
+				// If we manage to find any column with the given table, it's because
+				// the column does not exist.
+				for col := range columns {
+					if col.table == table {
+						return nil, ErrColumnTableNotFound.New(e.Table(), e.Name())
+					}
+				}
+
+				// In any other case, it's the table the one that does not exist.
+				return nil, sql.ErrTableNotFound.New(e.Table())
 			}
 
 			return nil, ErrColumnNotFound.New(e.Name())
@@ -385,6 +418,61 @@ func resolveColumnExpression(
 	), nil
 }
 
+// isStructField reports whether c could be a field access on a struct
+// column: its table part must name one of the given columns, ignoring case,
+// and sql.Field must find its name part in that column's type.
+func isStructField(c column, columns map[tableCol]indexedCol) bool {
+	structName, fieldName := c.Table(), c.Name()
+	for _, col := range columns {
+		if strings.EqualFold(structName, col.Name) &&
+			sql.Field(col.Type, fieldName) != nil {
+			return true
+		}
+	}
+	return false
+}
+
+// errFieldNotFound is returned when a struct type has no field with the
+// requested name.
+var errFieldNotFound = errors.NewKind("field %s not found on struct %s")
+
+// resolveStructFields replaces every UnresolvedField whose struct expression
+// is already resolved with a GetStructField. It fails with errFieldNotFound
+// if the struct type has no such field. Nodes that are already resolved or
+// that hold no expressions are returned untouched.
+func resolveStructFields(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
+	span, _ := ctx.Span("resolve_struct_fields")
+	defer span.Finish()
+
+	return n.TransformUp(func(n sql.Node) (sql.Node, error) {
+		if n.Resolved() {
+			return n, nil
+		}
+
+		expressioner, ok := n.(sql.Expressioner)
+		if !ok {
+			return n, nil
+		}
+
+		return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
+			f, ok := e.(*expression.UnresolvedField)
+			if !ok {
+				return e, nil
+			}
+
+			if !f.Struct.Resolved() {
+				return e, nil
+			}
+
+			if sql.Field(f.Struct.Type(), f.Name) == nil {
+				return nil, errFieldNotFound.New(f.Name, f.Struct)
+			}
+
+			return expression.NewGetStructField(f.Struct, f.Name), nil
+		})
+	})
+}
+
 // resolveGroupingColumns reorders the aggregation in a groupby so aliases
 // defined in it can be resolved in the grouping of the groupby. To do so,
 // all aliases are pushed down to a projection node under the group by.
diff --git a/sql/analyzer/resolve_columns_test.go b/sql/analyzer/resolve_columns_test.go
index a46544367..ada2f8649 100644
--- a/sql/analyzer/resolve_columns_test.go
+++ b/sql/analyzer/resolve_columns_test.go
@@ -186,14 +186,24 @@ func TestQualifyColumns(t *testing.T) {
 
 	node = plan.NewProject(
 		[]sql.Expression{
-			expression.NewUnresolvedQualifiedColumn("foo", "i"),
+			expression.NewUnresolvedQualifiedColumn("i", "some_field"),
 		},
-		plan.NewTableAlias("a", plan.NewResolvedTable(table)),
+		plan.NewResolvedTable(table),
 	)
 
-	_, err = f.Apply(sql.NewEmptyContext(), nil, node)
-	require.Error(err)
-	require.True(sql.ErrTableNotFound.Is(err))
+	expected = plan.NewProject(
+		[]sql.Expression{
+			expression.NewUnresolvedField(
+				expression.NewUnresolvedColumn("i"),
+				"some_field",
+			),
+		},
+		plan.NewResolvedTable(table),
+	)
+
+	result, err = f.Apply(sql.NewEmptyContext(), nil, node)
+	require.NoError(err)
+	require.Equal(expected, result)
 
 	node = plan.NewProject(
 		[]sql.Expression{
diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go
index c2b8daf00..ce4423cc1 100644
--- a/sql/analyzer/rules.go
+++ b/sql/analyzer/rules.go
@@ -12,6 +12,7 @@ var DefaultRules = []Rule{
 	{"resolve_grouping_columns", resolveGroupingColumns},
 	{"qualify_columns", qualifyColumns},
 	{"resolve_columns", resolveColumns},
+	{"resolve_struct_fields", resolveStructFields},
 	{"resolve_database", resolveDatabase},
 	{"resolve_star", resolveStar},
 	{"resolve_functions", resolveFunctions},
diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go
index 75690a5d2..b2fe5657e 100644
--- a/sql/expression/get_field.go
+++ b/sql/expression/get_field.go
@@ -128,3 +128,86 @@ func (f *GetSessionField) String() string { return "@@" + f.name }
 func (f *GetSessionField) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
 	return fn(f)
 }
+
+// GetStructField is an expression to get a field from a struct column.
+type GetStructField struct {
+	Struct sql.Expression
+	Name   string
+}
+
+// NewGetStructField creates a new GetStructField expression.
+func NewGetStructField(s sql.Expression, fieldName string) *GetStructField {
+	return &GetStructField{s, fieldName}
+}
+
+// Children implements the Expression interface.
+func (p *GetStructField) Children() []sql.Expression {
+	return []sql.Expression{p.Struct}
+}
+
+// Resolved implements the Expression interface.
+func (p *GetStructField) Resolved() bool {
+	return p.Struct.Resolved()
+}
+
+// column returns the column of the struct's schema that sql.Field matches
+// to the field name, or nil when there is none.
+func (p *GetStructField) column() *sql.Column {
+	return sql.Field(p.Struct.Type(), p.Name)
+}
+
+// IsNullable returns whether the field is nullable or not.
+func (p *GetStructField) IsNullable() bool {
+	return p.Struct.IsNullable() || p.column().Nullable
+}
+
+// Type returns the type of the field.
+func (p *GetStructField) Type() sql.Type {
+	return p.column().Type
+}
+
+// Eval implements the Expression interface. It evaluates the struct
+// expression, converts the value to the struct type and returns the value
+// of the field, or nil if the struct is nil or the field is not present.
+func (p *GetStructField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
+	s, err := p.Struct.Eval(ctx, row)
+	if err != nil {
+		return nil, err
+	}
+
+	if s == nil {
+		return nil, nil
+	}
+
+	s, err = p.Struct.Type().Convert(s)
+	if err != nil {
+		return nil, err
+	}
+
+	// sql.Field matches names ignoring case, so index the converted map
+	// with the column's own name rather than the name used in the query.
+	col := p.column()
+	if col == nil {
+		return nil, nil
+	}
+
+	if val, ok := s.(map[string]interface{})[col.Name]; ok {
+		return col.Type.Convert(val)
+	}
+
+	return nil, nil
+}
+
+// TransformUp implements the Expression interface.
+func (p *GetStructField) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
+	s, err := p.Struct.TransformUp(f)
+	if err != nil {
+		return nil, err
+	}
+	return f(NewGetStructField(s, p.Name))
+}
+
+// String returns the struct expression and the field name joined by a dot.
+func (p *GetStructField) String() string {
+	return fmt.Sprintf("%s.%s", p.Struct, p.Name)
+}
diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go
index 537ed0973..32dcea9af 100644
--- a/sql/expression/unresolved.go
+++ b/sql/expression/unresolved.go
@@ -139,3 +139,52 @@ func (uf *UnresolvedFunction) TransformUp(f sql.TransformExprFunc) (sql.Expressi
 
 	return f(NewUnresolvedFunction(uf.name, uf.IsAggregate, rc...))
 }
+
+// UnresolvedField is an unresolved expression to get a field from a struct column.
+type UnresolvedField struct {
+	Struct sql.Expression
+	Name   string
+}
+
+// NewUnresolvedField creates a new UnresolvedField expression.
+func NewUnresolvedField(s sql.Expression, fieldName string) *UnresolvedField {
+	return &UnresolvedField{s, fieldName}
+}
+
+// Children implements the Expression interface.
+func (p *UnresolvedField) Children() []sql.Expression {
+	return []sql.Expression{p.Struct}
+}
+
+// Resolved implements the Expression interface.
+func (p *UnresolvedField) Resolved() bool {
+	return false
+}
+
+// IsNullable returns whether the field is nullable or not.
+func (p *UnresolvedField) IsNullable() bool {
+	panic("unresolved field is a placeholder node, but IsNullable was called")
+}
+
+// Type returns the type of the field.
+func (p *UnresolvedField) Type() sql.Type {
+	panic("unresolved field is a placeholder node, but Type was called")
+}
+
+// Eval implements the Expression interface.
+func (p *UnresolvedField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
+	panic("unresolved field is a placeholder node, but Eval was called")
+}
+
+// TransformUp implements the Expression interface.
+func (p *UnresolvedField) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
+	s, err := p.Struct.TransformUp(f)
+	if err != nil {
+		return nil, err
+	}
+	return f(NewUnresolvedField(s, p.Name))
+}
+
+func (p *UnresolvedField) String() string {
+	return fmt.Sprintf("%s.%s", p.Struct, p.Name)
+}
diff --git a/sql/parse/parse.go b/sql/parse/parse.go
index 33dc6bde5..b4fc24f1b 100644
--- a/sql/parse/parse.go
+++ b/sql/parse/parse.go
@@ -748,6 +748,20 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) {
 		return expression.NewLiteral(nil, sql.Null), nil
 	case *sqlparser.ColName:
 		if !v.Qualifier.IsEmpty() {
+			// If we find something of the form A.B.C we're going to treat it
+			// as a struct field access.
+			// TODO: this should be handled better when GetFields support being
+			// qualified with the database.
+			if !v.Qualifier.Qualifier.IsEmpty() {
+				return expression.NewUnresolvedField(
+					expression.NewUnresolvedQualifiedColumn(
+						v.Qualifier.Qualifier.String(),
+						v.Qualifier.Name.String(),
+					),
+					v.Name.String(),
+				), nil
+			}
+
 			return expression.NewUnresolvedQualifiedColumn(
 				v.Qualifier.Name.String(),
 				v.Name.String(),
diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go
index 0fb3daa5e..c5466bf15 100644
--- a/sql/parse/parse_test.go
+++ b/sql/parse/parse_test.go
@@ -1085,6 +1085,15 @@ var fixtures = map[string]sql.Node{
 			),
 		),
 	),
+	`SELECT a.b.c FROM a`: plan.NewProject(
+		[]sql.Expression{
+			expression.NewUnresolvedField(
+				expression.NewUnresolvedQualifiedColumn("a", "b"),
+				"c",
+			),
+		},
+		plan.NewUnresolvedTable("a", ""),
+	),
 }
 
 func TestParse(t *testing.T) {
diff --git a/sql/type.go b/sql/type.go
index f0d7bf0eb..25d7d0e7f 100644
--- a/sql/type.go
+++ b/sql/type.go
@@ -880,6 +880,136 @@ func (t arrayT) Compare(a, b interface{}) (int, error) {
 	return 0, nil
 }
 
+// structT is the type of values made of named fields; schema lists the
+// fields and their types.
+type structT struct {
+	schema Schema
+}
+
+// Struct creates a new struct type with the given schema.
+func Struct(schema Schema) Type {
+	return structT{schema}
+}
+
+// String returns the type formatted as STRUCT(name type, ...), listing the
+// fields in schema order.
+func (t structT) String() string {
+	var fields = make([]string, len(t.schema))
+	for i, c := range t.schema {
+		fields[i] = fmt.Sprintf("%s %s", c.Name, c.Type)
+	}
+	return fmt.Sprintf("STRUCT(%s)", strings.Join(fields, ", "))
+}
+
+// Type returns the query type used for structs, which is JSON.
+func (t structT) Type() query.Type {
+	return sqltypes.TypeJSON
+}
+
+// SQL converts v to a struct and encodes the result with JSON.SQL. A nullT
+// value is encoded as NULL.
+func (t structT) SQL(v interface{}) (sqltypes.Value, error) {
+	if _, ok := v.(nullT); ok {
+		return sqltypes.NULL, nil
+	}
+
+	v, err := t.Convert(v)
+	if err != nil {
+		return sqltypes.Value{}, err
+	}
+
+	return JSON.SQL(v)
+}
+
+var (
+	errNotValidStruct    = errors.NewKind("can't convert struct because the schema does not match")
+	errCantConvertStruct = errors.NewKind("can't convert struct: %s")
+)
+
+// Convert turns v into a map[string]interface{} holding one entry per schema
+// column, with each value converted to the type of its column. Maps are used
+// directly, strings and byte slices are decoded as JSON, and any other value
+// is first encoded to JSON and then decoded. A missing column takes the
+// column default if it is nullable and makes the conversion fail otherwise.
+func (t structT) Convert(v interface{}) (interface{}, error) {
+	switch v := v.(type) {
+	case map[string]interface{}:
+		var result = make(map[string]interface{})
+		for _, col := range t.schema {
+			val, ok := v[col.Name]
+			if !ok {
+				if !col.Nullable {
+					return nil, errNotValidStruct.New()
+				}
+				result[col.Name] = col.Default
+				continue
+			}
+
+			val, err := col.Type.Convert(val)
+			if err != nil {
+				return nil, errCantConvertStruct.New(err)
+			}
+
+			result[col.Name] = val
+		}
+
+		return result, nil
+	case string:
+		return t.Convert([]byte(v))
+	case []byte:
+		var m = make(map[string]interface{})
+		if err := json.Unmarshal(v, &m); err != nil {
+			return nil, errCantConvertStruct.New(err)
+		}
+		return t.Convert(m)
+	default:
+		bs, err := json.Marshal(v)
+		if err != nil {
+			return nil, errCantConvertStruct.New(err)
+		}
+
+		var m = make(map[string]interface{})
+		if err := json.Unmarshal(bs, &m); err != nil {
+			return nil, errCantConvertStruct.New(err)
+		}
+
+		return t.Convert(m)
+	}
+}
+
+// Compare converts both values to structs and compares them field by field,
+// in schema order, returning the result of the first field that differs.
+func (t structT) Compare(a, b interface{}) (int, error) {
+	a, err := t.Convert(a)
+	if err != nil {
+		return 0, err
+	}
+
+	b, err = t.Convert(b)
+	if err != nil {
+		return 0, err
+	}
+
+	left := a.(map[string]interface{})
+	right := b.(map[string]interface{})
+
+	// Walk the schema rather than the map: map iteration order is random,
+	// so structs differing in more than one field could otherwise compare
+	// as "less" on one call and "greater" on the next.
+	for _, col := range t.schema {
+		cmp, err := col.Type.Compare(left[col.Name], right[col.Name])
+		if err != nil {
+			return 0, err
+		}
+
+		if cmp != 0 {
+			return cmp, nil
+		}
+	}
+
+	return 0, nil
+}
+
 // IsNumber checks if t is a number type
 func IsNumber(t Type) bool {
 	return IsInteger(t) || IsDecimal(t)
@@ -924,6 +1054,12 @@ func IsArray(t Type) bool {
 	return ok
 }
 
+// IsStruct returns whether the given type is a struct.
+func IsStruct(t Type) bool {
+	_, ok := t.(structT)
+	return ok
+}
+
 // NumColumns returns the number of columns in a type. This is one for all
 // types, except tuples.
 func NumColumns(t Type) int {
@@ -934,6 +1070,20 @@ func NumColumns(t Type) int {
 	return len(v)
 }
 
+// Field returns the field with the given name of the given type if it's a
+// struct. If it's not found or it's not a struct, it will return nil.
+func Field(t Type, name string) *Column {
+	if t, ok := t.(structT); ok {
+		for _, col := range t.schema {
+			if strings.EqualFold(name, col.Name) {
+				return col
+			}
+		}
+	}
+
+	return nil
+}
+
 // MySQLTypeName returns the MySQL display name for the given type.
 func MySQLTypeName(t Type) string {
 	switch t.Type() {
diff --git a/sql/type_test.go b/sql/type_test.go
index fb8de7f3b..b1de29412 100644
--- a/sql/type_test.go
+++ b/sql/type_test.go
@@ -5,6 +5,7 @@ import (
 	"time"
 
 	"github.com/stretchr/testify/require"
+	"gopkg.in/src-d/go-errors.v1"
 	"gopkg.in/src-d/go-vitess.v1/sqltypes"
 	"gopkg.in/src-d/go-vitess.v1/vt/proto/query"
 )
@@ -281,6 +282,81 @@ func TestUnderlyingType(t *testing.T) {
 	require.Equal(t, Text, UnderlyingType(Text))
 }
 
+// testStruct is a Go struct with the same fields as the schema used in
+// TestStruct.
+type testStruct struct {
+	A int
+	B string
+}
+
+// TestStruct checks the conversion of several kinds of input to a struct.
+func TestStruct(t *testing.T) {
+	s := Struct(Schema{
+		{Name: "A", Type: Int64, Nullable: false},
+		{Name: "B", Type: Text, Nullable: true, Default: "default"},
+	})
+
+	testCases := []struct {
+		name     string
+		input    interface{}
+		err      *errors.Kind
+		expected interface{}
+	}{
+		{
+			"struct input",
+			testStruct{A: 1, B: "bar"},
+			nil,
+			map[string]interface{}{"A": int64(1), "B": "bar"},
+		},
+		{
+			"map input",
+			map[string]interface{}{"A": 1, "B": "bar"},
+			nil,
+			map[string]interface{}{"A": int64(1), "B": "bar"},
+		},
+		{
+			"bytes input",
+			[]byte(`{"A": 1, "B": "bar"}`),
+			nil,
+			map[string]interface{}{"A": int64(1), "B": "bar"},
+		},
+		{
+			"string input",
+			`{"A": 1, "B": "bar"}`,
+			nil,
+			map[string]interface{}{"A": int64(1), "B": "bar"},
+		},
+		{
+			"not nullable field not present",
+			`{"B": "bar"}`,
+			errNotValidStruct,
+			nil,
+		},
+		{
+			"nullable field not present with default",
+			`{"A": 1}`,
+			nil,
+			map[string]interface{}{"A": int64(1), "B": "default"},
+		},
+	}
+
+	for _, tt := range testCases {
+		t.Run(tt.name, func(t *testing.T) {
+			// Bind the assertions to the subtest so that failures are
+			// reported against it and stop only this case.
+			require := require.New(t)
+			result, err := s.Convert(tt.input)
+			if tt.err != nil {
+				require.Error(err)
+				require.True(tt.err.Is(err))
+			} else {
+				require.NoError(err)
+				require.Equal(tt.expected, result)
+			}
+		})
+	}
+}
+
 func eq(t *testing.T, typ Type, a, b interface{}) {
 	t.Helper()
 	cmp, err := typ.Compare(a, b)