Skip to content

Commit

Permalink
SQLite3 support for migrations
Browse files Browse the repository at this point in the history
Work in progress until this message is removed

Signed-off-by: Erik Hollensbe <github@hollensbe.org>
  • Loading branch information
Erik Hollensbe committed Dec 6, 2019
1 parent bb05160 commit 5eb1f21
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 33 deletions.
9 changes: 9 additions & 0 deletions dialect/sql/schema/migrate.go
Expand Up @@ -7,6 +7,7 @@ package schema
import (
"context"
"crypto/md5"
"errors"
"fmt"
"math"
"sort"
Expand Down Expand Up @@ -236,6 +237,14 @@ type changes struct {
// changeSet returns a changes object to be applied on existing table.
// It fails if one of the changes is invalid.
func (m *Migrate) changeSet(curr, new *Table) (*changes, error) {
if curr == nil {
return nil, errors.New("current state could not be determined during change set generation")
}

if new == nil {
return nil, errors.New("determined state could not be determined during change set generation")
}

change := &changes{}
// pks.
if len(curr.PrimaryKey) != len(new.PrimaryKey) {
Expand Down
6 changes: 3 additions & 3 deletions dialect/sql/schema/mysql.go
Expand Up @@ -239,9 +239,9 @@ func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error {
if nullable.Valid {
c.Nullable = nullable.String == "YES"
}
switch parts := strings.FieldsFunc(c.typ, func(r rune) bool {
return r == '(' || r == ')' || r == ' ' || r == ','
}); parts[0] {

parts := typeFields(c.typ)
switch parts[0] {
case "int":
c.Type = field.TypeInt32
case "smallint":
Expand Down
31 changes: 2 additions & 29 deletions dialect/sql/schema/postgres.go
Expand Up @@ -96,35 +96,8 @@ func (d *Postgres) table(ctx context.Context, tx dialect.Tx, name string) (*Tabl
if err != nil {
return nil, err
}
// Populate the index information to the table and its columns.
// We do it manually, because PK and uniqueness information does
// not exist when querying the INFORMATION_SCHEMA.COLUMNS above.
for _, idx := range idxs {
switch {
case idx.primary:
for _, name := range idx.columns {
c, ok := t.column(name)
if !ok {
return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name)
}
c.Key = PrimaryKey
t.PrimaryKey = append(t.PrimaryKey, c)
}
case idx.Unique && len(idx.columns) == 1:
name := idx.columns[0]
c, ok := t.column(name)
if !ok {
return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name)
}
c.Key = UniqueKey
c.Unique = true
c.indexes.append(idx)
fallthrough
default:
t.AddIndex(idx.Name, idx.Unique, idx.columns)
}
}
return t, nil

return t, processIndexes(idxs, t)
}

// indexesQuery holds a query format for retrieving
Expand Down
38 changes: 38 additions & 0 deletions dialect/sql/schema/schema.go
Expand Up @@ -431,3 +431,41 @@ func compare(v1, v2 int) int {
}
return 1
}

func typeFields(typ string) []string {
return strings.FieldsFunc(typ, func(r rune) bool {
return r == '(' || r == ')' || r == ' ' || r == ','
})
}

func processIndexes(idxs Indexes, t *Table) error {
// Populate the index information to the table and its columns.
// We do it manually, because PK and uniqueness information does
// not exist when querying the INFORMATION_SCHEMA.COLUMNS above.
for _, idx := range idxs {
switch {
case idx.primary:
for _, name := range idx.columns {
c, ok := t.column(name)
if !ok {
return fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name)
}
c.Key = PrimaryKey
t.PrimaryKey = append(t.PrimaryKey, c)
}
case idx.Unique && len(idx.columns) == 1:
name := idx.columns[0]
c, ok := t.column(name)
if !ok {
return fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name)
}
c.Key = UniqueKey
c.Unique = true
c.indexes.append(idx)
fallthrough
default:
t.AddIndex(idx.Name, idx.Unique, idx.columns)
}
}
return nil
}
164 changes: 163 additions & 1 deletion dialect/sql/schema/sqlite.go
Expand Up @@ -7,10 +7,14 @@ package schema
import (
"context"
"fmt"
"math"
"strconv"
"strings"

"github.com/facebookincubator/ent/dialect"
"github.com/facebookincubator/ent/dialect/sql"
"github.com/facebookincubator/ent/schema/field"
"github.com/pkg/errors"
)

// SQLite is an SQLite migration driver.
Expand Down Expand Up @@ -119,6 +123,59 @@ func (*SQLite) cType(c *Column) (t string) {
return t
}

func (d *SQLite) typeField(c *Column, str string) error {
parts := typeFields(str)
switch parts[0] {
case "int", "integer":
c.Type = field.TypeInt32
case "smallint":
c.Type = field.TypeInt16
case "bigint":
c.Type = field.TypeInt64
case "tinyint":
size, err := strconv.Atoi(parts[1])
if err != nil {
return fmt.Errorf("converting varchar size to int: %v", err)
}
switch {
// XXX this is a throw back from old mysql bools I know; but should we
// keep it anyway for consistency? I think sqlite3 has similar conventions.
case size == 1:
c.Type = field.TypeBool
default:
c.Type = field.TypeInt8
}
case "double":
c.Type = field.TypeFloat64
case "timestamp", "datetime":
c.Type = field.TypeTime
case "blob":
c.Size = math.MaxUint32
c.Type = field.TypeBytes
case "varchar":
c.Type = field.TypeString
size, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return fmt.Errorf("converting varchar size to int: %v", err)
}
c.Size = size
case "json":
c.Type = field.TypeJSON
case "uuid":
c.Type = field.TypeUUID
case "enum":
c.Type = field.TypeEnum
c.Enums = make([]string, len(parts)-1)
for i, e := range parts[1:] {
c.Enums[i] = strings.Trim(e, "'")
}
default:
return fmt.Errorf("unknown column type %q", parts[0])
}

return nil
}

// addColumn returns the DSL query for adding the given column to a table.
func (d *SQLite) addColumn(c *Column) *sql.ColumnBuilder {
b := sql.Column(c.Name).Type(d.cType(c)).Attr(c.Attr)
Expand Down Expand Up @@ -148,4 +205,109 @@ func (d *SQLite) dropIndex(i *Index, _ string) *sql.DropIndexBuilder {

// fkExist returns always true to disable foreign-keys creation after the table was created.
func (d *SQLite) fkExist(context.Context, dialect.Tx, string) (bool, error) { return true, nil }
func (d *SQLite) table(context.Context, dialect.Tx, string) (*Table, error) { return nil, nil }

func (d *SQLite) indexes(ctx context.Context, tx dialect.Tx, table string) (Indexes, error) {
idxs := Indexes{}
rows := &sql.Rows{}

if err := tx.Query(ctx, fmt.Sprintf("pragma index_list('%s')", table), []interface{}{}, rows); err != nil {
return nil, fmt.Errorf("sqlite3: reading table index description %v", err)
}
defer rows.Close()

for rows.Next() {
var (
seq int
name string
unique int
origin string
partial int
)
if err := rows.Scan(&seq, &name, &unique, &origin, &partial); err != nil {
return nil, errors.Wrap(err, "while querying indexes")
}
idxs = append(idxs, &Index{
Name: name,
Unique: unique == 1,
primary: origin == "pk",
})
}
rows.Close()

// second loop to gather column info

for _, idx := range idxs {
if err := tx.Query(ctx, fmt.Sprintf("pragma index_info('%s')", idx.Name), []interface{}{}, rows); err != nil {
return nil, fmt.Errorf("sqlite3: reading index description %v", err)
}
defer rows.Close()

for rows.Next() {
var (
seq int
columnID int
columnName string
)
if err := rows.Scan(&seq, &columnID, &columnName); err != nil {
return nil, errors.Wrap(err, "while querying indexes")
}

idx.columns = append(idx.columns, columnName)
}
rows.Close()
}

return idxs, nil
}

func (d *SQLite) scanColumn(c *Column, rows *sql.Rows) error {
var (
id int
name string
typ string
notnullInt int
dflt sql.NullString
pkInt int
)

if err := rows.Scan(&id, &name, &typ, &notnullInt, &dflt, &pkInt); err != nil {
return err
}

c.Name = name
c.Nullable = notnullInt == 0
if pkInt > 0 {
c.Key = PrimaryKey
}

return d.typeField(c, typ)
}

func (d *SQLite) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
rows := &sql.Rows{}
if err := tx.Query(ctx, fmt.Sprintf("pragma table_info('%s')", name), []interface{}{}, rows); err != nil {
return nil, fmt.Errorf("sqlite3: reading table description %v", err)
}
// call `Close` in cases of failures (`Close` is idempotent).
defer rows.Close()
t := NewTable(name)
for rows.Next() {
c := &Column{}
if err := d.scanColumn(c, rows); err != nil {
return nil, err
}
t.AddColumn(c)
if c.Key == PrimaryKey {
t.PrimaryKey = append(t.PrimaryKey, c)
}
}
if err := rows.Close(); err != nil {
return nil, fmt.Errorf("closing rows %v", err)
}
idxs, err := d.indexes(ctx, tx, name)
if err != nil {
return nil, err
}

return t, processIndexes(idxs, t)
}
31 changes: 31 additions & 0 deletions dialect/sql/schema/sqlite_test.go
Expand Up @@ -8,13 +8,44 @@ import (
"context"
"testing"

_ "github.com/mattn/go-sqlite3"

stdsql "database/sql"

"github.com/facebookincubator/ent/dialect/sql"
"github.com/facebookincubator/ent/schema/field"

"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)

func TestSQLite_DoubleCreate(t *testing.T) {
tables := []*Table{
{
Name: "users",
PrimaryKey: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
},
Columns: []*Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "name", Type: field.TypeString, Nullable: true},
{Name: "age", Type: field.TypeInt},
{Name: "doc", Type: field.TypeJSON, Nullable: true},
{Name: "uuid", Type: field.TypeUUID, Nullable: true},
},
},
}

for i := 0; i < 2; i++ {
db, err := stdsql.Open("sqlite3", "file:test?mode=memory&cache=shared&_fk=1")
require.NoError(t, err)
migrate, err := NewMigrate(sql.OpenDB("sqlite3", db))
require.NoError(t, err)
err = migrate.Create(context.Background(), tables...)
require.NoError(t, err)
}
}

func TestSQLite_Create(t *testing.T) {
tests := []struct {
name string
Expand Down

0 comments on commit 5eb1f21

Please sign in to comment.