Skip to content

Commit

Permalink
dialect/sql/schema: initial work for incremental migration
Browse files Browse the repository at this point in the history
This is a WIP PR and should be ignored this moment.
It's based on PR #221 created by Erik Hollensbe (He should
get his credit for his work before we land this).
  • Loading branch information
a8m committed Apr 12, 2020
1 parent 8effe6d commit 210cd1d
Show file tree
Hide file tree
Showing 10 changed files with 416 additions and 47 deletions.
4 changes: 2 additions & 2 deletions .golangci.yml
Expand Up @@ -57,9 +57,9 @@ issues:
text: "SQL string concatenation"
linters:
- gosec
- path: dialect/sql/schema/migrate.go
text: "weak cryptographic primitive"
- path: dialect/sql/schema
linters:
- dupl
- gosec
- path: entc/load/load.go
text: "packages.LoadSyntax is deprecated"
Expand Down
22 changes: 18 additions & 4 deletions dialect/sql/builder.go
Expand Up @@ -25,10 +25,11 @@ type Querier interface {
// ColumnBuilder is a builder for column definition in table creation.
type ColumnBuilder struct {
Builder
typ string // column type.
name string // column name.
attr string // extra attributes.
modify bool // modify existing.
typ string // column type.
name string // column name.
attr string // extra attributes.
modify bool // modify existing.
fk *ForeignKeyBuilder // foreign-key constraint.
}

// Column returns a new ColumnBuilder with the given name.
Expand All @@ -52,6 +53,12 @@ func (c *ColumnBuilder) Attr(attr string) *ColumnBuilder {
return c
}

// Constraint adds the CONSTRAINT clause to the ADD COLUMN statement in SQLite.
func (c *ColumnBuilder) Constraint(fk *ForeignKeyBuilder) *ColumnBuilder {
c.fk = fk
return c
}

// Query returns query representation of a Column.
func (c *ColumnBuilder) Query() (string, []interface{}) {
c.Ident(c.name)
Expand All @@ -64,6 +71,13 @@ func (c *ColumnBuilder) Query() (string, []interface{}) {
if c.attr != "" {
c.Pad().WriteString(c.attr)
}
if c.fk != nil {
c.WriteString(" CONSTRAINT " + c.fk.symbol)
c.Pad().Join(c.fk.ref)
for _, action := range c.fk.actions {
c.Pad().WriteString(action)
}
}
return c.String(), c.args
}

Expand Down
14 changes: 14 additions & 0 deletions dialect/sql/builder_test.go
Expand Up @@ -1164,6 +1164,20 @@ func TestBuilder(t *testing.T) {
input: DropIndex("name_index").Table("users"),
wantQuery: "DROP INDEX `name_index` ON `users`",
},
{
input: Select().
From(Table("pragma_table_info('t1')").Unquote()).
OrderBy("pk"),
wantQuery: "SELECT * FROM pragma_table_info('t1') ORDER BY `pk`",
},
{
input: AlterTable("users").
AddColumn(Column("spouse").Type("integer").
Constraint(ForeignKey("user_spouse").
Reference(Reference().Table("users").Columns("id")).
OnDelete("SET NULL"))),
wantQuery: "ALTER TABLE `users` ADD COLUMN `spouse` integer CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE SET NULL",
},
}
for i, tt := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
Expand Down
23 changes: 9 additions & 14 deletions dialect/sql/schema/migrate.go
Expand Up @@ -209,21 +209,14 @@ func (m *Migrate) apply(ctx context.Context, tx dialect.Tx, table string, change
}
}
}
b := sql.Dialect(m.Dialect()).AlterTable(table)
for _, c := range change.column.add {
b.AddColumn(m.addColumn(c))
}
for _, c := range change.column.modify {
b.ModifyColumns(m.alterColumn(c)...)
}
var drop []*Column
if m.dropColumns {
for _, c := range change.column.drop {
b.DropColumn(sql.Dialect(m.Dialect()).Column(c.Name))
}
drop = change.column.drop
}
queries := m.alterColumns(table, change.column.add, change.column.modify, drop)
// If there's actual action to execute on ALTER TABLE.
if len(b.Queries) != 0 {
query, args := b.Query()
for i := range queries {
query, args := queries[i].Query()
if err := tx.Exec(ctx, query, args, nil); err != nil {
return fmt.Errorf("alter table %q: %v", table, err)
}
Expand Down Expand Up @@ -532,6 +525,9 @@ func (m *Migrate) setupTable(t *Table) {
}
for _, fk := range t.ForeignKeys {
fk.Symbol = m.symbol(fk.Symbol)
for i := range fk.Columns {
fk.Columns[i].foreign = fk
}
}
}

Expand Down Expand Up @@ -590,9 +586,8 @@ type sqlDialect interface {
// table, column and index builder per dialect.
cType(*Column) string
tBuilder(*Table) *sql.TableBuilder
addColumn(*Column) *sql.ColumnBuilder
alterColumn(*Column) []*sql.ColumnBuilder
addIndex(*Index, string) *sql.IndexBuilder
alterColumns(table string, add, modify, drop []*Column) sql.Queries
}

type preparer interface {
Expand Down
23 changes: 18 additions & 5 deletions dialect/sql/schema/mysql.go
Expand Up @@ -243,11 +243,6 @@ func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder {
return b
}

// alterColumn returns the DSL query for modifying the given column.
func (d *MySQL) alterColumn(c *Column) []*sql.ColumnBuilder {
return []*sql.ColumnBuilder{d.addColumn(c)}
}

// addIndex returns the querying for adding an index to MySQL.
func (d *MySQL) addIndex(i *Index, table string) *sql.IndexBuilder {
return i.Builder(table)
Expand Down Expand Up @@ -465,6 +460,24 @@ func (d *MySQL) tableSchema() sql.Querier {
return sql.Raw("(SELECT DATABASE())")
}

// alterColumns returns the queries for applying the columns change-set.
func (d *MySQL) alterColumns(table string, add, modify, drop []*Column) sql.Queries {
b := sql.Dialect(dialect.MySQL).AlterTable(table)
for _, c := range add {
b.AddColumn(d.addColumn(c))
}
for _, c := range modify {
b.ModifyColumn(d.addColumn(c))
}
for _, c := range drop {
b.DropColumn(sql.Dialect(dialect.MySQL).Column(c.Name))
}
if len(b.Queries) == 0 {
return nil
}
return sql.Queries{b}
}

// parseColumn returns column parts, size and signedness by mysql type
func parseColumn(typ string) (parts []string, size int64, unsigned bool, err error) {
switch parts = strings.FieldsFunc(typ, func(r rune) bool {
Expand Down
8 changes: 4 additions & 4 deletions dialect/sql/schema/mysql_test.go
Expand Up @@ -388,7 +388,7 @@ func TestMySQL_Create(t *testing.T) {
},
},
{
name: "add bool column with default value to table",
name: "add bool column with default value",
tables: []*Table{
{
Name: "users",
Expand Down Expand Up @@ -420,7 +420,7 @@ func TestMySQL_Create(t *testing.T) {
},
},
{
name: "add string column with default value to table",
name: "add string column with default value",
tables: []*Table{
{
Name: "users",
Expand Down Expand Up @@ -452,7 +452,7 @@ func TestMySQL_Create(t *testing.T) {
},
},
{
name: "add column with unsupported default value to table",
name: "add column with unsupported default value",
tables: []*Table{
{
Name: "users",
Expand Down Expand Up @@ -484,7 +484,7 @@ func TestMySQL_Create(t *testing.T) {
},
},
{
name: "drop column to table",
name: "drop columns",
tables: []*Table{
{
Name: "users",
Expand Down
18 changes: 18 additions & 0 deletions dialect/sql/schema/postgres.go
Expand Up @@ -382,3 +382,21 @@ func (d *Postgres) renameIndex(t *Table, old, new *Index) sql.Querier {
func (d *Postgres) tableSchema() sql.Querier {
return sql.Raw("(CURRENT_SCHEMA())")
}

// alterColumns returns the queries for applying the columns change-set.
func (d *Postgres) alterColumns(table string, add, modify, drop []*Column) sql.Queries {
b := sql.Dialect(dialect.Postgres).AlterTable(table)
for _, c := range add {
b.AddColumn(d.addColumn(c))
}
for _, c := range modify {
b.ModifyColumns(d.alterColumn(c)...)
}
for _, c := range drop {
b.DropColumn(sql.Dialect(dialect.Postgres).Column(c.Name))
}
if len(b.Queries) == 0 {
return nil
}
return sql.Queries{b}
}
4 changes: 2 additions & 2 deletions dialect/sql/schema/schema.go
Expand Up @@ -98,7 +98,6 @@ func (t *Table) column(name string) (*Column, bool) {
}

// index returns a table index by its name.
// faster than map lookup for most cases.
func (t *Table) index(name string) (*Index, bool) {
for _, idx := range t.Indexes {
if idx.Name == name {
Expand Down Expand Up @@ -150,6 +149,7 @@ type Column struct {
Default interface{} // default value.
Enums []string // enum values.
indexes Indexes // linked indexes.
foreign *ForeignKey // linked foreign-key.
}

// UniqueKey returns boolean indicates if this column is a unique key.
Expand Down Expand Up @@ -186,7 +186,7 @@ func (c Column) FloatType() bool { return c.Type == field.TypeFloat32 || c.Type
// ScanDefault scans the default value string to its interface type.
func (c *Column) ScanDefault(value string) (err error) {
switch {
case value == Null: // ignore.
case strings.ToUpper(value) == Null: // ignore.
case c.IntType():
v := &sql.NullInt64{}
if err := v.Scan(value); err != nil {
Expand Down

0 comments on commit 210cd1d

Please sign in to comment.