Skip to content

Commit

Permalink
SQLite3 support for migrations
Browse files Browse the repository at this point in the history
Work in progress until this message is removed

Signed-off-by: Erik Hollensbe <github@hollensbe.org>
  • Loading branch information
Erik Hollensbe committed Dec 19, 2019
1 parent 1bbe460 commit b4f476b
Show file tree
Hide file tree
Showing 7 changed files with 734 additions and 46 deletions.
23 changes: 17 additions & 6 deletions dialect/sql/builder.go
Expand Up @@ -254,7 +254,11 @@ func (t *TableAlter) DropColumn(c *ColumnBuilder) *TableAlter {

// AddForeignKey adds a foreign key constraint to the `ALTER TABLE` statement.
func (t *TableAlter) AddForeignKey(fk *ForeignKeyBuilder) *TableAlter {
t.Queries = append(t.Queries, &Wrapper{"ADD CONSTRAINT %s", fk})
if t.Dialect() == dialect.SQLite {
t.Queries = append(t.Queries, &Wrapper{"ADD COLUMN %s", fk})
} else {
t.Queries = append(t.Queries, &Wrapper{"ADD CONSTRAINT %s", fk})
}
return t
}

Expand Down Expand Up @@ -333,17 +337,24 @@ func (fk *ForeignKeyBuilder) OnUpdate(action string) *ForeignKeyBuilder {

// Query returns query representation of a foreign key constraint.
func (fk *ForeignKeyBuilder) Query() (string, []interface{}) {
if fk.symbol != "" {
if fk.Dialect() != dialect.SQLite && fk.symbol != "" {
fk.Ident(fk.symbol).Pad()
}
fk.WriteString("FOREIGN KEY")
fk.Nested(func(b *Builder) {
b.IdentComma(fk.columns...)
})

if fk.Dialect() == dialect.SQLite {
fk.Pad().Ident(fk.columns[0])
} else {
fk.WriteString("FOREIGN KEY")
fk.Nested(func(b *Builder) {
b.IdentComma(fk.columns...)
})
}

fk.Pad().Join(fk.ref)
for _, action := range fk.actions {
fk.Pad().WriteString(action)
}

return fk.String(), fk.args
}

Expand Down
13 changes: 11 additions & 2 deletions dialect/sql/schema/migrate.go
Expand Up @@ -7,6 +7,7 @@ package schema
import (
"context"
"crypto/md5"
"errors"
"fmt"
"math"
"sort"
Expand Down Expand Up @@ -154,7 +155,7 @@ func (m *Migrate) create(ctx context.Context, tx dialect.Tx, tables ...*Table) e
}
fks := make([]*ForeignKey, 0, len(t.ForeignKeys))
for _, fk := range t.ForeignKeys {
exist, err := m.fkExist(ctx, tx, fk.Symbol)
exist, err := m.fkExist(ctx, tx, t, fk)
if err != nil {
return err
}
Expand Down Expand Up @@ -235,6 +236,14 @@ type changes struct {
// changeSet returns a changes object to be applied on existing table.
// It fails if one of the changes is invalid.
func (m *Migrate) changeSet(curr, new *Table) (*changes, error) {
if curr == nil {
return nil, errors.New("current state could not be determined during change set generation")
}

if new == nil {
return nil, errors.New("determined state could not be determined during change set generation")
}

change := &changes{}
// pks.
if len(curr.PrimaryKey) != len(new.PrimaryKey) {
Expand Down Expand Up @@ -434,7 +443,7 @@ type sqlDialect interface {
init(context.Context, dialect.Tx) error
table(context.Context, dialect.Tx, string) (*Table, error)
tableExist(context.Context, dialect.Tx, string) (bool, error)
fkExist(context.Context, dialect.Tx, string) (bool, error)
fkExist(context.Context, dialect.Tx, *Table, *ForeignKey) (bool, error)
setRange(context.Context, dialect.Tx, string, int) error
dropIndex(context.Context, dialect.Tx, *Index, string) error
// table, column and index builder per dialect.
Expand Down
9 changes: 4 additions & 5 deletions dialect/sql/schema/mysql.go
Expand Up @@ -47,9 +47,9 @@ func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (boo
return exist(ctx, tx, query, args...)
}

func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, table *Table, fk *ForeignKey) (bool, error) {
query, args := sql.Select(sql.Count("*")).From(sql.Table("INFORMATION_SCHEMA.TABLE_CONSTRAINTS").Unquote()).
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("CONSTRAINT_TYPE", "FOREIGN KEY").And().EQ("CONSTRAINT_NAME", name)).Query()
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("CONSTRAINT_TYPE", "FOREIGN KEY").And().EQ("CONSTRAINT_NAME", fk.Symbol)).Query()
return exist(ctx, tx, query, args...)
}

Expand Down Expand Up @@ -240,9 +240,8 @@ func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error {
if nullable.Valid {
c.Nullable = nullable.String == "YES"
}
switch parts := strings.FieldsFunc(c.typ, func(r rune) bool {
return r == '(' || r == ')' || r == ' ' || r == ','
}); parts[0] {

switch parts := typeFields(c.typ); parts[0] {
case "int":
c.Type = field.TypeInt32
case "smallint":
Expand Down
35 changes: 4 additions & 31 deletions dialect/sql/schema/postgres.go
Expand Up @@ -54,10 +54,10 @@ func (d *Postgres) tableExist(ctx context.Context, tx dialect.Tx, name string) (
}

// tableExist checks if a foreign-key exists in the current schema.
func (d *Postgres) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
func (d *Postgres) fkExist(ctx context.Context, tx dialect.Tx, table *Table, fk *ForeignKey) (bool, error) {
query, args := sql.Dialect(dialect.Postgres).
Select(sql.Count("*")).From(sql.Table("INFORMATION_SCHEMA.TABLE_CONSTRAINTS").Unquote()).
Where(sql.EQ("table_schema", sql.Raw("CURRENT_SCHEMA()")).And().EQ("constraint_type", "FOREIGN KEY").And().EQ("constraint_name", name)).Query()
Where(sql.EQ("table_schema", sql.Raw("CURRENT_SCHEMA()")).And().EQ("constraint_type", "FOREIGN KEY").And().EQ("constraint_name", fk.Symbol)).Query()
return exist(ctx, tx, query, args...)
}

Expand Down Expand Up @@ -96,35 +96,8 @@ func (d *Postgres) table(ctx context.Context, tx dialect.Tx, name string) (*Tabl
if err != nil {
return nil, err
}
// Populate the index information to the table and its columns.
// We do it manually, because PK and uniqueness information does
// not exist when querying the INFORMATION_SCHEMA.COLUMNS above.
for _, idx := range idxs {
switch {
case idx.primary:
for _, name := range idx.columns {
c, ok := t.column(name)
if !ok {
return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name)
}
c.Key = PrimaryKey
t.PrimaryKey = append(t.PrimaryKey, c)
}
case idx.Unique && len(idx.columns) == 1:
name := idx.columns[0]
c, ok := t.column(name)
if !ok {
return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name)
}
c.Key = UniqueKey
c.Unique = true
c.indexes.append(idx)
fallthrough
default:
t.AddIndex(idx.Name, idx.Unique, idx.columns)
}
}
return t, nil

return t, processIndexes(idxs, t)
}

// indexesQuery holds a query format for retrieving
Expand Down
38 changes: 38 additions & 0 deletions dialect/sql/schema/schema.go
Expand Up @@ -436,3 +436,41 @@ func compare(v1, v2 int) int {
}
return 1
}

func typeFields(typ string) []string {
return strings.FieldsFunc(typ, func(r rune) bool {
return r == '(' || r == ')' || r == ' ' || r == ','
})
}

func processIndexes(idxs Indexes, t *Table) error {
// Populate the index information to the table and its columns.
// We do it manually, because PK and uniqueness information does
// not exist when querying the INFORMATION_SCHEMA.COLUMNS above.
for _, idx := range idxs {
switch {
case idx.primary:
for _, name := range idx.columns {
c, ok := t.column(name)
if !ok {
return fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name)
}
c.Key = PrimaryKey
t.PrimaryKey = append(t.PrimaryKey, c)
}
case idx.Unique && len(idx.columns) == 1:
name := idx.columns[0]
c, ok := t.column(name)
if !ok {
return fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name)
}
c.Key = UniqueKey
c.Unique = true
c.indexes.append(idx)
fallthrough
default:
t.AddIndex(idx.Name, idx.Unique, idx.columns)
}
}
return nil
}

0 comments on commit b4f476b

Please sign in to comment.