dialect/sql/schema: initial work for incremental migration
This is a WIP PR and should be ignored for the moment.
It's based on PR #221 created by Erik Hollensbe (he should
get credit for his work before we land this).
a8m committed Apr 11, 2020
1 parent 8effe6d commit 1bf2416
Showing 4 changed files with 177 additions and 13 deletions.
6 changes: 6 additions & 0 deletions dialect/sql/builder_test.go
@@ -1164,6 +1164,12 @@ func TestBuilder(t *testing.T) {
input: DropIndex("name_index").Table("users"),
wantQuery: "DROP INDEX `name_index` ON `users`",
},
{
input: Select().
From(Table("pragma_table_info('t1')").Unquote()).
OrderBy("pk"),
wantQuery: "SELECT * FROM pragma_table_info('t1') ORDER BY `pk`",
},
}
for i, tt := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
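Note on the new builder test: `Unquote()` tells the builder to emit the FROM expression verbatim, so the table-valued pragma call is not wrapped in backticks (a quoted identifier would be invalid SQL here). A small standalone sketch of the same call, using the repository's builder package:

```go
package main

import (
	"fmt"

	"github.com/facebookincubator/ent/dialect/sql"
)

func main() {
	// Unquote() skips identifier quoting for the table expression, so the
	// pragma call is rendered as-is instead of as `pragma_table_info('t1')`.
	query, args := sql.Select().
		From(sql.Table("pragma_table_info('t1')").Unquote()).
		OrderBy("pk").
		Query()
	fmt.Println(query, args)
	// SELECT * FROM pragma_table_info('t1') ORDER BY `pk` []
}
```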
3 changes: 1 addition & 2 deletions dialect/sql/schema/schema.go
@@ -98,7 +98,6 @@ func (t *Table) column(name string) (*Column, bool) {
}

// index returns a table index by its name.
// faster than map lookup for most cases.
func (t *Table) index(name string) (*Index, bool) {
for _, idx := range t.Indexes {
if idx.Name == name {
@@ -186,7 +185,7 @@ func (c Column) FloatType() bool { return c.Type == field.TypeFloat32 || c.Type
// ScanDefault scans the default value string to its interface type.
func (c *Column) ScanDefault(value string) (err error) {
switch {
case value == Null: // ignore.
case strings.ToUpper(value) == Null: // ignore.
case c.IntType():
v := &sql.NullInt64{}
if err := v.Scan(value); err != nil {
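The schema.go change makes the NULL-default comparison case-insensitive: SQLite reports `dflt_value` exactly as it was written in CREATE TABLE, so spellings like `NULL`, `Null`, or `null` can all come back (the new sqlite_test.go rows further down return two of them). A minimal sketch of the check in isolation; the `Null` constant is assumed to be the literal string "NULL", which is what the uppercase comparison implies:

```go
package main

import (
	"fmt"
	"strings"
)

// Null mirrors the schema package constant; assumed to be the literal "NULL"
// (the ToUpper comparison only makes sense if the constant is uppercase).
const Null = "NULL"

// isNullDefault reports whether a scanned dflt_value means "no default set".
func isNullDefault(v string) bool {
	return strings.ToUpper(v) == Null
}

func main() {
	for _, v := range []string{"NULL", "Null", "null", "0", "'x'"} {
		fmt.Printf("%q -> %v\n", v, isNullDefault(v))
	}
}
```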
145 changes: 134 additions & 11 deletions dialect/sql/schema/sqlite.go
@@ -7,6 +7,7 @@ package schema
import (
"context"
"fmt"
"math"

"github.com/facebookincubator/ent/dialect"
"github.com/facebookincubator/ent/dialect/sql"
@@ -92,19 +93,15 @@ func (*SQLite) cType(c *Column) (t string) {
switch c.Type {
case field.TypeBool:
t = "bool"
case field.TypeInt8, field.TypeUint8, field.TypeInt, field.TypeInt16, field.TypeInt32, field.TypeUint, field.TypeUint16, field.TypeUint32:
case field.TypeInt8, field.TypeUint8, field.TypeInt16, field.TypeUint16, field.TypeInt32,
field.TypeUint32, field.TypeUint, field.TypeInt, field.TypeInt64, field.TypeUint64:
t = "integer"
case field.TypeInt64, field.TypeUint64:
t = "bigint"
case field.TypeBytes:
t = "blob"
case field.TypeString, field.TypeEnum:
size := c.Size
if size == 0 {
size = DefaultStringLen
}
// sqlite has no size limit on varchar.
t = fmt.Sprintf("varchar(%d)", size)
// SQLite does not impose any length restrictions on
// the length of strings, BLOBs or numeric values.
t = fmt.Sprintf("varchar(%d)", DefaultStringLen)
case field.TypeFloat32, field.TypeFloat64:
t = "real"
case field.TypeTime:
@@ -151,6 +148,132 @@ func (d *SQLite) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table
func (d *SQLite) fkExist(context.Context, dialect.Tx, string) (bool, error) { return true, nil }

// table always returns an error to indicate that the SQLite dialect doesn't support incremental migration.
func (d *SQLite) table(context.Context, dialect.Tx, string) (*Table, error) {
return nil, fmt.Errorf("sqlite dialect does not support incremental migration")
func (d *SQLite) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
rows := &sql.Rows{}
query, args := sql.Select("name", "type", "notnull", "dflt_value", "pk").
From(sql.Table(fmt.Sprintf("pragma_table_info('%s')", name)).Unquote()).
OrderBy("pk").
Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, fmt.Errorf("sqlite: reading table description %v", err)
}
// Call Close in case of failure (Close is idempotent).
defer rows.Close()
t := NewTable(name)
for rows.Next() {
c := &Column{}
if err := d.scanColumn(c, rows); err != nil {
return nil, fmt.Errorf("sqlite: %v", err)
}
if c.PrimaryKey() {
t.PrimaryKey = append(t.PrimaryKey, c)
}
t.AddColumn(c)
}
if err := rows.Close(); err != nil {
return nil, fmt.Errorf("sqlite: closing rows %v", err)
}
indexes, err := d.indexes(ctx, tx, name)
if err != nil {
return nil, err
}
// Add and link indexes to table columns.
for _, idx := range indexes {
t.AddIndex(idx.Name, idx.Unique, idx.columns)
}
return t, nil
}

// indexes loads the table indexes from the database.
func (d *SQLite) indexes(ctx context.Context, tx dialect.Tx, name string) (Indexes, error) {
rows := &sql.Rows{}
query, args := sql.Select("name", "unique").
From(sql.Table(fmt.Sprintf("pragma_index_list('%s')", name)).Unquote()).
Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, fmt.Errorf("reading table indexes %v", err)
}
defer rows.Close()
var idx Indexes
for rows.Next() {
i := &Index{}
if err := rows.Scan(&i.Name, &i.Unique); err != nil {
return nil, fmt.Errorf("scanning index description %v", err)
}
idx = append(idx, i)
}
if err := rows.Close(); err != nil {
return nil, fmt.Errorf("closing rows %v", err)
}
for i := range idx {
columns, err := d.indexColumns(ctx, tx, idx[i].Name)
if err != nil {
return nil, err
}
idx[i].columns = columns
}
return idx, nil
}

// indexColumns loads index columns from index info.
func (d *SQLite) indexColumns(ctx context.Context, tx dialect.Tx, name string) ([]string, error) {
rows := &sql.Rows{}
query, args := sql.Select("name").
From(sql.Table(fmt.Sprintf("pragma_index_info('%s')", name)).Unquote()).
OrderBy("seqno").
Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, fmt.Errorf("reading table indexes %v", err)
}
defer rows.Close()
var names []string
if err := sql.ScanSlice(rows, &names); err != nil {
return nil, err
}
return names, nil
}

// scanColumn scans the column information from SQLite column description.
func (d *SQLite) scanColumn(c *Column, rows *sql.Rows) error {
var (
pk sql.NullInt64
notnull sql.NullInt64
defaults sql.NullString
)
if err := rows.Scan(&c.Name, &c.typ, &notnull, &defaults, &pk); err != nil {
return fmt.Errorf("scanning column description: %v", err)
}
c.Nullable = notnull.Int64 == 0
if pk.Int64 > 0 {
c.Key = PrimaryKey
}
parts, _, _, err := parseColumn(c.typ)
if err != nil {
return err
}
switch parts[0] {
case "bool", "boolean":
c.Type = field.TypeBool
case "blob":
c.Size = math.MaxUint32
c.Type = field.TypeBytes
case "integer":
// All integer types have the same "type affinity".
c.Type = field.TypeInt
case "real", "float", "double":
c.Type = field.TypeFloat64
case "datetime":
c.Type = field.TypeTime
case "json":
c.Type = field.TypeJSON
case "uuid":
c.Type = field.TypeUUID
case "varchar", "text":
c.Size = DefaultStringLen
c.Type = field.TypeString
}
if defaults.Valid {
return c.ScanDefault(defaults.String)
}
return nil
}
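For context on the introspection above: `pragma_table_info`, `pragma_index_list`, and `pragma_index_info` are table-valued functions (SQLite 3.16+) and can be queried like regular tables. A self-contained sketch of the column-description query against a live database, assuming the mattn/go-sqlite3 driver; the migration itself runs the same SQL through `dialect.Tx` rather than `database/sql` directly:

```go
package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/mattn/go-sqlite3" // assumption: any SQLite3 driver will do.
)

func main() {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()
	if _, err := db.Exec("CREATE TABLE users (id integer PRIMARY KEY, name varchar(255), age integer NOT NULL DEFAULT 0)"); err != nil {
		log.Fatal(err)
	}
	// The same description query the migration builds: column name, declared
	// type, NOT NULL flag, default value and primary-key position.
	rows, err := db.Query("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()
	for rows.Next() {
		var (
			name, typ   string
			notnull, pk int
			dflt        sql.NullString
		)
		if err := rows.Scan(&name, &typ, &notnull, &dflt, &pk); err != nil {
			log.Fatal(err)
		}
		fmt.Println(name, typ, notnull, dflt.String, pk)
	}
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
}
```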
36 changes: 36 additions & 0 deletions dialect/sql/schema/sqlite_test.go
@@ -6,6 +6,7 @@ package schema

import (
"context"
"math"
"testing"

"github.com/facebookincubator/ent/dialect/sql"
@@ -119,6 +120,41 @@ func TestSQLite_Create(t *testing.T) {
mock.ExpectCommit()
},
},
{
name: "add column to table",
tables: []*Table{
{
Name: "users",
Columns: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "name", Type: field.TypeString, Nullable: true},
{Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32},
{Name: "uuid", Type: field.TypeUUID, Nullable: true},
{Name: "age", Type: field.TypeInt, Default: 0},
},
PrimaryKey: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
},
},
},
before: func(mock sqliteMock) {
mock.start()
mock.tableExists("users", true)
mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}).
AddRow("name", "varchar(255)", 0, nil, 0).
AddRow("text", "text", 0, "NULL", 0).
AddRow("uuid", "uuid", 0, "Null", 0).
AddRow("id", "integer", 1, "NULL", 1))
mock.ExpectQuery(escape("SELECT `name`, `unique` FROM pragma_index_list('users')")).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"name", "unique"}))
mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` integer NOT NULL DEFAULT 0")).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
},
},
{
name: "universal id for all tables",
tables: []*Table{
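The added test walks the incremental path end to end: the table already exists, its current columns come back from pragma_table_info, no indexes are reported, and the one missing column (`age`) is added with a single ALTER TABLE. A hedged sketch of the same two steps against a live database; the CREATE TABLE shape here is an assumption, while the ALTER statement is the one the mock expects verbatim:

```go
package main

import (
	"database/sql"
	"log"

	_ "github.com/mattn/go-sqlite3" // assumption: any SQLite3 driver will do.
)

func main() {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()
	// Assumed pre-migration shape: the users table without the `age` column.
	if _, err := db.Exec("CREATE TABLE `users` (`id` integer PRIMARY KEY AUTOINCREMENT, `name` varchar(255) NULL, `text` text NULL, `uuid` uuid NULL)"); err != nil {
		log.Fatal(err)
	}
	// The statement issued once the missing column is detected.
	if _, err := db.Exec("ALTER TABLE `users` ADD COLUMN `age` integer NOT NULL DEFAULT 0"); err != nil {
		log.Fatal(err)
	}
}
```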
