diff --git a/Gopkg.lock b/Gopkg.lock index df3a092d..604c4cd3 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -54,12 +54,12 @@ version = "v1.3.0" [[projects]] - digest = "1:ee0845ea64262e3d1a6e2eab768fcb2008a0c8e571b7a3bebea554a1c031aeeb" + digest = "1:6a60be4b683dfa1d3e222b6ef96da4d3406280019731c6c1c23bd60cfb7928fe" name = "github.com/mattn/go-sqlite3" packages = ["."] pruneopts = "UT" - revision = "6c771bb9887719704b210e87e934f08be014bdb1" - version = "v1.6.0" + revision = "862b95943f99f3b40e317a79d41c27ac4b742011" + version = "v1.14.2" [[projects]] digest = "1:0028cb19b2e4c3112225cd871870f2d9cf49b9b4276531f03438a88e94be86fe" diff --git a/Gopkg.toml b/Gopkg.toml index d7072c22..e129f8c3 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -28,3 +28,7 @@ [prune] go-tests = true unused-packages = true + +[[constraint]] + name = "github.com/mattn/go-sqlite3" + version = "1.14.2" diff --git a/README.md b/README.md index 339a206e..27ca215b 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ REL is golang orm-ish database layer for layered architecture. It's testable and - Multi adapter. - Soft Deletion. - Pagination. - +- Schema Migration. ## Install diff --git a/adapter.go b/adapter.go index a5a1c45d..bda5b80e 100644 --- a/adapter.go +++ b/adapter.go @@ -18,4 +18,6 @@ type Adapter interface { Begin(ctx context.Context) (Adapter, error) Commit(ctx context.Context) error Rollback(ctx context.Context) error + + Apply(ctx context.Context, migration Migration) error } diff --git a/adapter/mysql/mysql.go b/adapter/mysql/mysql.go index 0a36a752..7471ca45 100644 --- a/adapter/mysql/mysql.go +++ b/adapter/mysql/mysql.go @@ -25,19 +25,26 @@ type Adapter struct { *sql.Adapter } -var _ rel.Adapter = (*Adapter)(nil) +var ( + _ rel.Adapter = (*Adapter)(nil) -// New is mysql adapter constructor. + // Config for mysql adapter. + Config = sql.Config{ + DropIndexOnTable: true, + Placeholder: "?", + EscapeChar: "`", + IncrementFunc: incrementFunc, + ErrorFunc: errorFunc, + MapColumnFunc: sql.MapColumn, + } +) + +// New mysql adapter using existing connection. func New(database *db.DB) *Adapter { return &Adapter{ Adapter: &sql.Adapter{ - Config: &sql.Config{ - Placeholder: "?", - EscapeChar: "`", - IncrementFunc: incrementFunc, - ErrorFunc: errorFunc, - }, - DB: database, + Config: Config, + DB: database, }, } } diff --git a/adapter/mysql/mysql_test.go b/adapter/mysql/mysql_test.go index 9a90182d..282708ad 100644 --- a/adapter/mysql/mysql_test.go +++ b/adapter/mysql/mysql_test.go @@ -15,58 +15,6 @@ import ( var ctx = context.TODO() -func init() { - adapter, err := Open(dsn()) - paranoid.Panic(err, "failed to open database connection") - - _, _, err = adapter.Exec(ctx, `DROP TABLE IF EXISTS extras;`, nil) - paranoid.Panic(err, "failed dropping extras table") - _, _, err = adapter.Exec(ctx, `DROP TABLE IF EXISTS addresses;`, nil) - paranoid.Panic(err, "failed dropping addresses table") - _, _, err = adapter.Exec(ctx, `DROP TABLE IF EXISTS users;`, nil) - paranoid.Panic(err, "failed dropping users table") - _, _, err = adapter.Exec(ctx, `DROP TABLE IF EXISTS composites;`, nil) - paranoid.Panic(err, "failed dropping composites table") - - _, _, err = adapter.Exec(ctx, `CREATE TABLE users ( - id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY, - name VARCHAR(30) NOT NULL DEFAULT '', - gender VARCHAR(10) NOT NULL DEFAULT '', - age INT NOT NULL DEFAULT 0, - note varchar(50), - created_at DATETIME, - updated_at DATETIME - );`, nil) - paranoid.Panic(err, "failed creating users table") - - _, _, err = adapter.Exec(ctx, `CREATE TABLE addresses ( - id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY, - user_id INT UNSIGNED, - name VARCHAR(60) NOT NULL DEFAULT '', - created_at DATETIME, - updated_at DATETIME, - FOREIGN KEY (user_id) REFERENCES users(id) - );`, nil) - paranoid.Panic(err, "failed creating addresses table") - - _, _, err = adapter.Exec(ctx, `CREATE TABLE extras ( - id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY, - slug VARCHAR(30) DEFAULT NULL UNIQUE, - user_id INT UNSIGNED, - SCORE INT, - CONSTRAINT extras_user_id_fk FOREIGN KEY (user_id) REFERENCES users(id) - );`, nil) - paranoid.Panic(err, "failed creating extras table") - - _, _, err = adapter.Exec(ctx, `CREATE TABLE composites ( - primary1 INT UNSIGNED, - primary2 INT UNSIGNED, - data VARCHAR(255) DEFAULT NULL, - PRIMARY KEY (primary1, primary2) - );`, nil) - paranoid.Panic(err, "failed creating composites table") -} - func dsn() string { if os.Getenv("MYSQL_DATABASE") != "" { return os.Getenv("MYSQL_DATABASE") + "?charset=utf8&parseTime=True&loc=Local" @@ -82,6 +30,14 @@ func TestAdapter_specs(t *testing.T) { repo := rel.New(adapter) + // Prepare tables + teardown := specs.Setup(t, repo) + defer teardown() + + // Migration Specs + // - Rename column is only supported by MySQL 8.0 + specs.Migrate(t, repo, specs.SkipRenameColumn) + // Query Specs specs.Query(t, repo) specs.QueryJoin(t, repo) diff --git a/adapter/postgres/postgres.go b/adapter/postgres/postgres.go index 0cca263d..5e954af4 100644 --- a/adapter/postgres/postgres.go +++ b/adapter/postgres/postgres.go @@ -15,6 +15,7 @@ package postgres import ( "context" db "database/sql" + "time" "github.com/Fs02/rel" "github.com/Fs02/rel/adapter/sql" @@ -25,20 +26,26 @@ type Adapter struct { *sql.Adapter } -var _ rel.Adapter = (*Adapter)(nil) +var ( + _ rel.Adapter = (*Adapter)(nil) + + // Config for postgres adapter. + Config = sql.Config{ + Placeholder: "$", + EscapeChar: "\"", + Ordinal: true, + InsertDefaultValues: true, + ErrorFunc: errorFunc, + MapColumnFunc: mapColumnFunc, + } +) -// New is postgres adapter constructor. +// New postgres adapter using existing connection. func New(database *db.DB) *Adapter { return &Adapter{ Adapter: &sql.Adapter{ - Config: &sql.Config{ - Placeholder: "$", - EscapeChar: "\"", - Ordinal: true, - InsertDefaultValues: true, - ErrorFunc: errorFunc, - }, - DB: database, + Config: Config, + DB: database, }, } } @@ -144,3 +151,33 @@ func errorFunc(err error) error { return err } } + +func mapColumnFunc(column *rel.Column) (string, int, int) { + var ( + typ string + m, n int + ) + + // postgres specific + column.Unsigned = false + if column.Default == "" { + column.Default = nil + } + + switch column.Type { + case rel.ID: + typ = "SERIAL NOT NULL PRIMARY KEY" + case rel.DateTime: + typ = "TIMESTAMPTZ" + if t, ok := column.Default.(time.Time); ok { + column.Default = t.Format("2006-01-02 15:04:05") + } + case rel.Int, rel.BigInt, rel.Text: + column.Limit = 0 + typ, m, n = sql.MapColumn(column) + default: + typ, m, n = sql.MapColumn(column) + } + + return typ, m, n +} diff --git a/adapter/postgres/postgres_test.go b/adapter/postgres/postgres_test.go index aa75d214..19abc18c 100644 --- a/adapter/postgres/postgres_test.go +++ b/adapter/postgres/postgres_test.go @@ -16,60 +16,8 @@ import ( var ctx = context.TODO() func init() { - adapter, err := Open(dsn()) - paranoid.Panic(err, "failed to open database connection") - defer adapter.Close() - - _, _, err = adapter.Exec(ctx, `DROP TABLE IF EXISTS extras;`, nil) - paranoid.Panic(err, "failed dropping extras table") - _, _, err = adapter.Exec(ctx, `DROP TABLE IF EXISTS addresses;`, nil) - paranoid.Panic(err, "failed dropping addresses table") - _, _, err = adapter.Exec(ctx, `DROP TABLE IF EXISTS users;`, nil) - paranoid.Panic(err, "failed dropping users table") - _, _, err = adapter.Exec(ctx, `DROP TABLE IF EXISTS composites;`, nil) - paranoid.Panic(err, "failed dropping composites table") - - _, _, err = adapter.Exec(ctx, `CREATE TABLE users ( - id SERIAL NOT NULL PRIMARY KEY, - slug VARCHAR(30) DEFAULT NULL, - name VARCHAR(30) NOT NULL DEFAULT '', - gender VARCHAR(10) NOT NULL DEFAULT '', - age INT NOT NULL DEFAULT 0, - note varchar(50), - created_at TIMESTAMPTZ, - updated_at TIMESTAMPTZ, - UNIQUE(slug) - );`, nil) - paranoid.Panic(err, "failed creating users table") - - _, _, err = adapter.Exec(ctx, `CREATE TABLE addresses ( - id SERIAL NOT NULL PRIMARY KEY, - user_id INTEGER REFERENCES users(id), - name VARCHAR(60) NOT NULL DEFAULT '', - created_at TIMESTAMPTZ, - updated_at TIMESTAMPTZ - );`, nil) - paranoid.Panic(err, "failed creating addresses table") - - _, _, err = adapter.Exec(ctx, `CREATE TABLE extras ( - id SERIAL NOT NULL PRIMARY KEY, - slug VARCHAR(30) DEFAULT NULL UNIQUE, - user_id INTEGER REFERENCES users(id), - score INTEGER DEFAULT 0 CHECK (score>=0 AND score<=100) - );`, nil) - paranoid.Panic(err, "failed creating extras table") - - _, _, err = adapter.Exec(ctx, `CREATE TABLE composites ( - primary1 SERIAL NOT NULL, - primary2 SERIAL NOT NULL, - data VARCHAR(255) DEFAULT NULL, - PRIMARY KEY (primary1, primary2) - );`, nil) - paranoid.Panic(err, "failed creating composites table") - // hack to make sure location it has the same location object as returned by pq driver. - time.Local, err = time.LoadLocation("Asia/Jakarta") - paranoid.Panic(err, "failed loading time location") + time.Local, _ = time.LoadLocation("Asia/Jakarta") } func dsn() string { @@ -77,7 +25,7 @@ func dsn() string { return os.Getenv("POSTGRESQL_DATABASE") + "?sslmode=disable&timezone=Asia/Jakarta" } - return "postgres://rel@localhost:9920/rel_test?sslmode=disable&timezone=Asia/Jakarta" + return "postgres://rel@localhost:5432/rel_test?sslmode=disable&timezone=Asia/Jakarta" } func TestAdapter_specs(t *testing.T) { @@ -87,6 +35,13 @@ func TestAdapter_specs(t *testing.T) { repo := rel.New(adapter) + // Prepare tables + teardown := specs.Setup(t, repo) + defer teardown() + + // Migration Specs + specs.Migrate(t, repo) + // Query Specs specs.Query(t, repo) specs.QueryJoin(t, repo) diff --git a/adapter/specs/migration.go b/adapter/specs/migration.go new file mode 100644 index 00000000..b5cfbf19 --- /dev/null +++ b/adapter/specs/migration.go @@ -0,0 +1,221 @@ +package specs + +import ( + "testing" + "time" + + "github.com/Fs02/rel" + "github.com/Fs02/rel/migrator" +) + +var m migrator.Migrator + +// Setup database for specs execution. +func Setup(t *testing.T, repo rel.Repository) func() { + m = migrator.New(repo) + m.Register(1, + func(schema *rel.Schema) { + schema.CreateTable("users", func(t *rel.Table) { + t.ID("id") + t.String("slug", rel.Limit(30)) + t.String("name", rel.Limit(30), rel.Default("")) + t.String("gender", rel.Limit(10), rel.Default("")) + t.Int("age", rel.Required(true), rel.Default(0)) + t.String("note", rel.Limit(50)) + t.DateTime("created_at") + t.DateTime("updated_at") + + t.Unique([]string{"slug"}) + }) + }, + func(schema *rel.Schema) { + schema.DropTable("users") + }, + ) + + m.Register(2, + func(schema *rel.Schema) { + schema.CreateTable("addresses", func(t *rel.Table) { + t.ID("id") + t.Int("user_id", rel.Unsigned(true)) + t.String("name", rel.Limit(60), rel.Required(true), rel.Default("")) + t.DateTime("created_at") + t.DateTime("updated_at") + + t.ForeignKey("user_id", "users", "id") + }) + }, + func(schema *rel.Schema) { + schema.DropTable("addresses") + }, + ) + + m.Register(3, + func(schema *rel.Schema) { + schema.CreateTable("extras", func(t *rel.Table) { + t.ID("id") + t.Int("user_id", rel.Unsigned(true)) + t.String("slug", rel.Limit(30)) + t.Int("score", rel.Default(0)) + + t.ForeignKey("user_id", "users", "id") + t.Unique([]string{"slug"}) + t.Fragment("CONSTRAINT extras_score_check CHECK (score>=0 AND score<=100)") + }) + }, + func(schema *rel.Schema) { + schema.DropTable("extras") + }, + ) + + m.Register(4, + func(schema *rel.Schema) { + schema.CreateTable("composites", func(t *rel.Table) { + t.Int("primary1") + t.Int("primary2") + t.String("data") + + t.PrimaryKeys([]string{"primary1", "primary2"}) + }) + }, + func(schema *rel.Schema) { + schema.DropTable("composites") + }, + ) + + m.Migrate(ctx) + + return func() { + for i := 0; i < 4; i++ { + m.Rollback(ctx) + } + } +} + +// Migrate specs. +func Migrate(t *testing.T, repo rel.Repository, flags ...Flag) { + m.Register(5, + func(schema *rel.Schema) { + schema.CreateTable("dummies", func(t *rel.Table) { + t.ID("id") + t.Bool("bool1") + t.Bool("bool2", rel.Default(true)) + t.Int("int1") + t.Int("int2", rel.Default(8), rel.Unsigned(true), rel.Limit(10)) + t.Int("int3", rel.Unique(true)) + t.BigInt("bigint1") + t.BigInt("bigint2", rel.Default(8), rel.Unsigned(true), rel.Limit(200)) + t.Float("float1") + t.Float("float2", rel.Default(10.00), rel.Precision(2)) + t.Decimal("decimal1") + t.Decimal("decimal2", rel.Default(10.00), rel.Precision(6), rel.Scale(2)) + t.String("string1") + t.String("string2", rel.Default("string"), rel.Limit(100)) + t.Text("text") + t.Date("date1") + t.Date("date2", rel.Default(time.Now())) + t.DateTime("datetime1") + t.DateTime("datetime2", rel.Default(time.Now())) + t.Time("time1") + t.Time("time2", rel.Default(time.Now())) + t.Timestamp("timestamp1") + t.Timestamp("timestamp2", rel.Default(time.Now())) + + t.Unique([]string{"int2"}) + t.Unique([]string{"bigint1", "bigint2"}) + }) + }, + func(schema *rel.Schema) { + schema.DropTable("dummies") + }, + ) + defer m.Rollback(ctx) + + m.Register(6, + func(schema *rel.Schema) { + schema.AlterTable("dummies", func(t *rel.AlterTable) { + t.Bool("new_column") + }) + schema.AddColumn("dummies", "new_column1", rel.Int, rel.Unsigned(true)) + }, + func(schema *rel.Schema) { + if SkipDropColumn.enabled(flags) { + schema.AlterTable("dummies", func(t *rel.AlterTable) { + t.DropColumn("new_column") + }) + schema.DropColumn("dummies", "new_column1") + } + }, + ) + defer m.Rollback(ctx) + + if SkipRenameColumn.enabled(flags) { + m.Register(7, + func(schema *rel.Schema) { + schema.AlterTable("dummies", func(t *rel.AlterTable) { + t.RenameColumn("text", "teks") + t.RenameColumn("date2", "date3") + }) + schema.RenameColumn("dummies", "decimal1", "decimal0") + }, + func(schema *rel.Schema) { + schema.AlterTable("dummies", func(t *rel.AlterTable) { + t.RenameColumn("teks", "text") + t.RenameColumn("date3", "date2") + }) + schema.RenameColumn("dummies", "decimal0", "decimal1") + }, + ) + defer m.Rollback(ctx) + } + + m.Register(8, + func(schema *rel.Schema) { + schema.CreateIndex("dummies", "int1_idx", []string{"int1"}) + schema.CreateIndex("dummies", "string1_string2_idx", []string{"string1", "string2"}) + }, + func(schema *rel.Schema) { + schema.DropIndex("dummies", "int1_idx") + schema.DropIndex("dummies", "string1_string2_idx") + }, + ) + defer m.Rollback(ctx) + + m.Register(9, + func(schema *rel.Schema) { + schema.RenameTable("dummies", "new_dummies") + }, + func(schema *rel.Schema) { + schema.RenameTable("new_dummies", "dummies") + }, + ) + defer m.Rollback(ctx) + + m.Register(10, + func(schema *rel.Schema) { + schema.CreateTableIfNotExists("dummies2", func(t *rel.Table) { + t.ID("id") + }) + }, + func(schema *rel.Schema) { + schema.DropTableIfExists("dummies2") + }, + ) + defer m.Rollback(ctx) + + m.Register(11, + func(schema *rel.Schema) { + schema.CreateTableIfNotExists("dummies2", func(t *rel.Table) { + t.ID("id") + t.Int("field1") + t.Int("field2") + }) + }, + func(schema *rel.Schema) { + schema.DropTableIfExists("dummies2") + }, + ) + defer m.Rollback(ctx) + + m.Migrate(ctx) +} diff --git a/adapter/specs/specs.go b/adapter/specs/specs.go index 956be058..753519b8 100644 --- a/adapter/specs/specs.go +++ b/adapter/specs/specs.go @@ -14,6 +14,20 @@ import ( var ctx = context.TODO() +// Flag for configuration. +type Flag int + +func (f Flag) enabled(flags []Flag) bool { + return len(flags) > 0 && f&flags[0] == 0 +} + +const ( + // SkipDropColumn spec. + SkipDropColumn Flag = 1 << iota + // SkipRenameColumn spec. + SkipRenameColumn +) + // User defines users schema. type User struct { ID int64 @@ -53,11 +67,10 @@ type Composite struct { } var ( - config = &sql.Config{ + config = sql.Config{ Placeholder: "?", EscapeChar: "`", } - builder = sql.NewBuilder(config) ) func assertConstraint(t *testing.T, err error, ctype rel.ConstraintType, key string) { diff --git a/adapter/sql/adapter.go b/adapter/sql/adapter.go index 59159210..542798e3 100644 --- a/adapter/sql/adapter.go +++ b/adapter/sql/adapter.go @@ -10,20 +10,10 @@ import ( "github.com/Fs02/rel" ) -// Config holds configuration for adapter. -type Config struct { - Placeholder string - Ordinal bool - InsertDefaultValues bool - EscapeChar string - ErrorFunc func(error) error - IncrementFunc func(Adapter) int -} - // Adapter definition for database database. type Adapter struct { Instrumenter rel.Instrumenter - Config *Config + Config Config DB *sql.DB Tx *sql.Tx savepoint int @@ -32,42 +22,42 @@ type Adapter struct { var _ rel.Adapter = (*Adapter)(nil) // Close database connection. -func (adapter *Adapter) Close() error { - return adapter.DB.Close() +func (a *Adapter) Close() error { + return a.DB.Close() } // Instrumentation set instrumenter for this adapter. -func (adapter *Adapter) Instrumentation(instrumenter rel.Instrumenter) { - adapter.Instrumenter = instrumenter +func (a *Adapter) Instrumentation(instrumenter rel.Instrumenter) { + a.Instrumenter = instrumenter } // Instrument call instrumenter, if no instrumenter is set, this will be a no op. -func (adapter *Adapter) Instrument(ctx context.Context, op string, message string) func(err error) { - if adapter.Instrumenter != nil { - return adapter.Instrumenter(ctx, op, message) +func (a *Adapter) Instrument(ctx context.Context, op string, message string) func(err error) { + if a.Instrumenter != nil { + return a.Instrumenter(ctx, op, message) } return func(err error) {} } // Ping database. -func (adapter *Adapter) Ping(ctx context.Context) error { - return adapter.DB.PingContext(ctx) +func (a *Adapter) Ping(ctx context.Context) error { + return a.DB.PingContext(ctx) } // Aggregate record using given query. -func (adapter *Adapter) Aggregate(ctx context.Context, query rel.Query, mode string, field string) (int, error) { +func (a *Adapter) Aggregate(ctx context.Context, query rel.Query, mode string, field string) (int, error) { var ( err error out sql.NullInt64 - statement, args = NewBuilder(adapter.Config).Aggregate(query, mode, field) + statement, args = NewBuilder(a.Config).Aggregate(query, mode, field) ) - finish := adapter.Instrument(ctx, "adapter-aggregate", statement) - if adapter.Tx != nil { - err = adapter.Tx.QueryRowContext(ctx, statement, args...).Scan(&out) + finish := a.Instrument(ctx, "adapter-aggregate", statement) + if a.Tx != nil { + err = a.Tx.QueryRowContext(ctx, statement, args...).Scan(&out) } else { - err = adapter.DB.QueryRowContext(ctx, statement, args...).Scan(&out) + err = a.DB.QueryRowContext(ctx, statement, args...).Scan(&out) } finish(err) @@ -75,34 +65,34 @@ func (adapter *Adapter) Aggregate(ctx context.Context, query rel.Query, mode str } // Query performs query operation. -func (adapter *Adapter) Query(ctx context.Context, query rel.Query) (rel.Cursor, error) { +func (a *Adapter) Query(ctx context.Context, query rel.Query) (rel.Cursor, error) { var ( - statement, args = NewBuilder(adapter.Config).Find(query) + statement, args = NewBuilder(a.Config).Find(query) ) - finish := adapter.Instrument(ctx, "adapter-query", statement) - rows, err := adapter.query(ctx, statement, args) + finish := a.Instrument(ctx, "adapter-query", statement) + rows, err := a.query(ctx, statement, args) finish(err) - return &Cursor{rows}, adapter.Config.ErrorFunc(err) + return &Cursor{rows}, a.Config.ErrorFunc(err) } -func (adapter *Adapter) query(ctx context.Context, statement string, args []interface{}) (*sql.Rows, error) { - if adapter.Tx != nil { - return adapter.Tx.QueryContext(ctx, statement, args...) +func (a *Adapter) query(ctx context.Context, statement string, args []interface{}) (*sql.Rows, error) { + if a.Tx != nil { + return a.Tx.QueryContext(ctx, statement, args...) } - return adapter.DB.QueryContext(ctx, statement, args...) + return a.DB.QueryContext(ctx, statement, args...) } // Exec performs exec operation. -func (adapter *Adapter) Exec(ctx context.Context, statement string, args []interface{}) (int64, int64, error) { - finish := adapter.Instrument(ctx, "adapter-exec", statement) - res, err := adapter.exec(ctx, statement, args) +func (a *Adapter) Exec(ctx context.Context, statement string, args []interface{}) (int64, int64, error) { + finish := a.Instrument(ctx, "adapter-exec", statement) + res, err := a.exec(ctx, statement, args) finish(err) if err != nil { - return 0, 0, adapter.Config.ErrorFunc(err) + return 0, 0, a.Config.ErrorFunc(err) } lastID, _ := res.LastInsertId() @@ -111,28 +101,28 @@ func (adapter *Adapter) Exec(ctx context.Context, statement string, args []inter return lastID, rowCount, nil } -func (adapter *Adapter) exec(ctx context.Context, statement string, args []interface{}) (sql.Result, error) { - if adapter.Tx != nil { - return adapter.Tx.ExecContext(ctx, statement, args...) +func (a *Adapter) exec(ctx context.Context, statement string, args []interface{}) (sql.Result, error) { + if a.Tx != nil { + return a.Tx.ExecContext(ctx, statement, args...) } - return adapter.DB.ExecContext(ctx, statement, args...) + return a.DB.ExecContext(ctx, statement, args...) } // Insert inserts a record to database and returns its id. -func (adapter *Adapter) Insert(ctx context.Context, query rel.Query, primaryField string, mutates map[string]rel.Mutate) (interface{}, error) { +func (a *Adapter) Insert(ctx context.Context, query rel.Query, primaryField string, mutates map[string]rel.Mutate) (interface{}, error) { var ( - statement, args = NewBuilder(adapter.Config).Insert(query.Table, mutates) - id, _, err = adapter.Exec(ctx, statement, args) + statement, args = NewBuilder(a.Config).Insert(query.Table, mutates) + id, _, err = a.Exec(ctx, statement, args) ) return id, err } // InsertAll inserts all record to database and returns its ids. -func (adapter *Adapter) InsertAll(ctx context.Context, query rel.Query, primaryField string, fields []string, bulkMutates []map[string]rel.Mutate) ([]interface{}, error) { - statement, args := NewBuilder(adapter.Config).InsertAll(query.Table, fields, bulkMutates) - id, _, err := adapter.Exec(ctx, statement, args) +func (a *Adapter) InsertAll(ctx context.Context, query rel.Query, primaryField string, fields []string, bulkMutates []map[string]rel.Mutate) ([]interface{}, error) { + statement, args := NewBuilder(a.Config).InsertAll(query.Table, fields, bulkMutates) + id, _, err := a.Exec(ctx, statement, args) if err != nil { return nil, err } @@ -142,8 +132,8 @@ func (adapter *Adapter) InsertAll(ctx context.Context, query rel.Query, primaryF inc = 1 ) - if adapter.Config.IncrementFunc != nil { - inc = adapter.Config.IncrementFunc(*adapter) + if a.Config.IncrementFunc != nil { + inc = a.Config.IncrementFunc(*a) } if inc < 0 { @@ -169,93 +159,113 @@ func (adapter *Adapter) InsertAll(ctx context.Context, query rel.Query, primaryF } // Update updates a record in database. -func (adapter *Adapter) Update(ctx context.Context, query rel.Query, mutates map[string]rel.Mutate) (int, error) { +func (a *Adapter) Update(ctx context.Context, query rel.Query, mutates map[string]rel.Mutate) (int, error) { var ( - statement, args = NewBuilder(adapter.Config).Update(query.Table, mutates, query.WhereQuery) - _, updatedCount, err = adapter.Exec(ctx, statement, args) + statement, args = NewBuilder(a.Config).Update(query.Table, mutates, query.WhereQuery) + _, updatedCount, err = a.Exec(ctx, statement, args) ) return int(updatedCount), err } // Delete deletes all results that match the query. -func (adapter *Adapter) Delete(ctx context.Context, query rel.Query) (int, error) { +func (a *Adapter) Delete(ctx context.Context, query rel.Query) (int, error) { var ( - statement, args = NewBuilder(adapter.Config).Delete(query.Table, query.WhereQuery) - _, deletedCount, err = adapter.Exec(ctx, statement, args) + statement, args = NewBuilder(a.Config).Delete(query.Table, query.WhereQuery) + _, deletedCount, err = a.Exec(ctx, statement, args) ) return int(deletedCount), err } // Begin begins a new transaction. -func (adapter *Adapter) Begin(ctx context.Context) (rel.Adapter, error) { +func (a *Adapter) Begin(ctx context.Context) (rel.Adapter, error) { var ( tx *sql.Tx savepoint int err error ) - finish := adapter.Instrument(ctx, "adapter-begin", "begin transaction") + finish := a.Instrument(ctx, "adapter-begin", "begin transaction") - if adapter.Tx != nil { - tx = adapter.Tx - savepoint = adapter.savepoint + 1 - _, _, err = adapter.Exec(ctx, "SAVEPOINT s"+strconv.Itoa(savepoint)+";", []interface{}{}) + if a.Tx != nil { + tx = a.Tx + savepoint = a.savepoint + 1 + _, _, err = a.Exec(ctx, "SAVEPOINT s"+strconv.Itoa(savepoint)+";", []interface{}{}) } else { - tx, err = adapter.DB.BeginTx(ctx, nil) + tx, err = a.DB.BeginTx(ctx, nil) } finish(err) return &Adapter{ - Instrumenter: adapter.Instrumenter, - Config: adapter.Config, + Instrumenter: a.Instrumenter, + Config: a.Config, Tx: tx, savepoint: savepoint, }, err } // Commit commits current transaction. -func (adapter *Adapter) Commit(ctx context.Context) error { +func (a *Adapter) Commit(ctx context.Context) error { var err error - finish := adapter.Instrument(ctx, "adapter-commit", "commit transaction") + finish := a.Instrument(ctx, "adapter-commit", "commit transaction") - if adapter.Tx == nil { + if a.Tx == nil { err = errors.New("unable to commit outside transaction") - } else if adapter.savepoint > 0 { - _, _, err = adapter.Exec(ctx, "RELEASE SAVEPOINT s"+strconv.Itoa(adapter.savepoint)+";", []interface{}{}) + } else if a.savepoint > 0 { + _, _, err = a.Exec(ctx, "RELEASE SAVEPOINT s"+strconv.Itoa(a.savepoint)+";", []interface{}{}) } else { - err = adapter.Tx.Commit() + err = a.Tx.Commit() } finish(err) - return adapter.Config.ErrorFunc(err) + return a.Config.ErrorFunc(err) } // Rollback revert current transaction. -func (adapter *Adapter) Rollback(ctx context.Context) error { +func (a *Adapter) Rollback(ctx context.Context) error { var err error - finish := adapter.Instrument(ctx, "adapter-rollback", "rollback transaction") + finish := a.Instrument(ctx, "adapter-rollback", "rollback transaction") - if adapter.Tx == nil { + if a.Tx == nil { err = errors.New("unable to rollback outside transaction") - } else if adapter.savepoint > 0 { - _, _, err = adapter.Exec(ctx, "ROLLBACK TO SAVEPOINT s"+strconv.Itoa(adapter.savepoint)+";", []interface{}{}) + } else if a.savepoint > 0 { + _, _, err = a.Exec(ctx, "ROLLBACK TO SAVEPOINT s"+strconv.Itoa(a.savepoint)+";", []interface{}{}) } else { - err = adapter.Tx.Rollback() + err = a.Tx.Rollback() } finish(err) - return adapter.Config.ErrorFunc(err) + return a.Config.ErrorFunc(err) +} + +// Apply table. +func (a *Adapter) Apply(ctx context.Context, migration rel.Migration) error { + var ( + statement string + builder = NewBuilder(a.Config) + ) + + switch v := migration.(type) { + case rel.Table: + statement = builder.Table(v) + case rel.Index: + statement = builder.Index(v) + case rel.Raw: + statement = string(v) + } + + _, _, err := a.Exec(ctx, statement, nil) + return err } // New initialize adapter without db. -func New(config *Config) *Adapter { +func New(config Config) *Adapter { adapter := &Adapter{ Config: config, } diff --git a/adapter/sql/adapter_test.go b/adapter/sql/adapter_test.go index facf5b59..43964c8d 100644 --- a/adapter/sql/adapter_test.go +++ b/adapter/sql/adapter_test.go @@ -14,12 +14,13 @@ import ( func open(t *testing.T) *Adapter { var ( err error - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", InsertDefaultValues: true, ErrorFunc: func(err error) error { return err }, IncrementFunc: func(Adapter) int { return -1 }, + MapColumnFunc: MapColumn, } adapter = New(config) ) @@ -43,7 +44,7 @@ type Name struct { } func TestNew(t *testing.T) { - assert.NotNil(t, New(nil)) + assert.NotNil(t, New(Config{})) } func TestAdapter_Ping(t *testing.T) { @@ -329,3 +330,36 @@ func TestAdapter_Exec_error(t *testing.T) { _, _, err := adapter.Exec(context.TODO(), "error", nil) assert.NotNil(t, err) } + +func TestAdapter_Apply(t *testing.T) { + var ( + ctx = context.TODO() + adapter = open(t) + ) + + defer adapter.Close() + + t.Run("Table", func(t *testing.T) { + adapter.Apply(ctx, rel.Table{ + Name: "tests", + Optional: true, + Definitions: []rel.TableDefinition{ + rel.Column{Name: "ID", Type: rel.ID}, + rel.Column{Name: "username", Type: rel.String}, + }, + }) + }) + + t.Run("Index", func(t *testing.T) { + adapter.Apply(ctx, rel.Index{ + Name: "username_idx", + Optional: true, + Table: "tests", + Columns: []string{"username"}, + }) + }) + + t.Run("Raw", func(t *testing.T) { + adapter.Apply(ctx, rel.Raw("SELECT 1;")) + }) +} diff --git a/adapter/sql/buffer.go b/adapter/sql/buffer.go index 5d165ab9..191b625d 100644 --- a/adapter/sql/buffer.go +++ b/adapter/sql/buffer.go @@ -1,12 +1,12 @@ package sql import ( - "bytes" + "strings" ) // Buffer used to strings buffer and argument of the query. type Buffer struct { - bytes.Buffer + strings.Builder Arguments []interface{} } @@ -17,6 +17,6 @@ func (b *Buffer) Append(args ...interface{}) { // Reset buffer. func (b *Buffer) Reset() { - b.Buffer.Reset() + b.Builder.Reset() b.Arguments = nil } diff --git a/adapter/sql/builder.go b/adapter/sql/builder.go index 8bb6043d..1813f135 100644 --- a/adapter/sql/builder.go +++ b/adapter/sql/builder.go @@ -1,6 +1,7 @@ package sql import ( + "encoding/json" "strconv" "strings" "sync" @@ -15,11 +16,262 @@ var fieldCache sync.Map // Builder defines information of query b. type Builder struct { - config *Config + config Config returnField string count int } +// Table generates query for table creation and modification. +func (b *Builder) Table(table rel.Table) string { + var buffer Buffer + + switch table.Op { + case rel.SchemaCreate: + b.createTable(&buffer, table) + case rel.SchemaAlter: + b.alterTable(&buffer, table) + case rel.SchemaRename: + buffer.WriteString("ALTER TABLE ") + buffer.WriteString(Escape(b.config, table.Name)) + buffer.WriteString(" RENAME TO ") + buffer.WriteString(Escape(b.config, table.Rename)) + buffer.WriteByte(';') + case rel.SchemaDrop: + buffer.WriteString("DROP TABLE ") + + if table.Optional { + buffer.WriteString("IF EXISTS ") + } + + buffer.WriteString(Escape(b.config, table.Name)) + buffer.WriteByte(';') + } + + return buffer.String() +} + +func (b *Builder) createTable(buffer *Buffer, table rel.Table) { + buffer.WriteString("CREATE TABLE ") + + if table.Optional { + buffer.WriteString("IF NOT EXISTS ") + } + + buffer.WriteString(Escape(b.config, table.Name)) + buffer.WriteString(" (") + + for i, def := range table.Definitions { + if i > 0 { + buffer.WriteString(", ") + } + switch v := def.(type) { + case rel.Column: + b.column(buffer, v) + case rel.Key: + b.key(buffer, v) + case rel.Raw: + buffer.WriteString(string(v)) + } + } + + buffer.WriteByte(')') + b.options(buffer, table.Options) + buffer.WriteByte(';') +} + +func (b *Builder) alterTable(buffer *Buffer, table rel.Table) { + for _, def := range table.Definitions { + buffer.WriteString("ALTER TABLE ") + buffer.WriteString(Escape(b.config, table.Name)) + buffer.WriteByte(' ') + + switch v := def.(type) { + case rel.Column: + switch v.Op { + case rel.SchemaCreate: + buffer.WriteString("ADD COLUMN ") + b.column(buffer, v) + case rel.SchemaRename: + // Add Change + buffer.WriteString("RENAME COLUMN ") + buffer.WriteString(Escape(b.config, v.Name)) + buffer.WriteString(" TO ") + buffer.WriteString(Escape(b.config, v.Rename)) + case rel.SchemaDrop: + buffer.WriteString("DROP COLUMN ") + buffer.WriteString(Escape(b.config, v.Name)) + } + case rel.Key: + // TODO: Rename and Drop, PR welcomed. + switch v.Op { + case rel.SchemaCreate: + buffer.WriteString("ADD ") + b.key(buffer, v) + } + } + + b.options(buffer, table.Options) + buffer.WriteByte(';') + } +} + +func (b *Builder) column(buffer *Buffer, column rel.Column) { + var ( + typ, m, n = b.config.MapColumnFunc(&column) + ) + + buffer.WriteString(Escape(b.config, column.Name)) + buffer.WriteByte(' ') + buffer.WriteString(typ) + + if m != 0 { + buffer.WriteByte('(') + buffer.WriteString(strconv.Itoa(m)) + + if n != 0 { + buffer.WriteByte(',') + buffer.WriteString(strconv.Itoa(n)) + } + + buffer.WriteByte(')') + } + + if column.Unsigned { + buffer.WriteString(" UNSIGNED") + } + + if column.Unique { + buffer.WriteString(" UNIQUE") + } + + if column.Required { + buffer.WriteString(" NOT NULL") + } + + if column.Default != nil { + buffer.WriteString(" DEFAULT ") + switch v := column.Default.(type) { + case string: + // TODO: single quote only required by postgres. + buffer.WriteByte('\'') + buffer.WriteString(v) + buffer.WriteByte('\'') + default: + // TODO: improve + bytes, _ := json.Marshal(column.Default) + buffer.Write(bytes) + } + } + + b.options(buffer, column.Options) +} + +func (b *Builder) key(buffer *Buffer, key rel.Key) { + var ( + typ = string(key.Type) + ) + + buffer.WriteString(typ) + + if key.Name != "" { + buffer.WriteByte(' ') + buffer.WriteString(Escape(b.config, key.Name)) + } + + buffer.WriteString(" (") + for i, col := range key.Columns { + if i > 0 { + buffer.WriteString(", ") + } + buffer.WriteString(Escape(b.config, col)) + } + buffer.WriteString(")") + + if key.Type == rel.ForeignKey { + buffer.WriteString(" REFERENCES ") + buffer.WriteString(Escape(b.config, key.Reference.Table)) + + buffer.WriteString(" (") + for i, col := range key.Reference.Columns { + if i > 0 { + buffer.WriteString(", ") + } + buffer.WriteString(Escape(b.config, col)) + } + buffer.WriteString(")") + + if onDelete := key.Reference.OnDelete; onDelete != "" { + buffer.WriteString(" ON DELETE ") + buffer.WriteString(onDelete) + } + + if onUpdate := key.Reference.OnUpdate; onUpdate != "" { + buffer.WriteString(" ON UPDATE ") + buffer.WriteString(onUpdate) + } + } + + b.options(buffer, key.Options) +} + +// Index generates query for index. +func (b *Builder) Index(index rel.Index) string { + var buffer Buffer + + switch index.Op { + case rel.SchemaCreate: + buffer.WriteString("CREATE ") + if index.Unique { + buffer.WriteString("UNIQUE ") + } + buffer.WriteString("INDEX ") + + if index.Optional { + buffer.WriteString("IF NOT EXISTS ") + } + + buffer.WriteString(Escape(b.config, index.Name)) + buffer.WriteString(" ON ") + buffer.WriteString(Escape(b.config, index.Table)) + + buffer.WriteString(" (") + for i, col := range index.Columns { + if i > 0 { + buffer.WriteString(", ") + } + buffer.WriteString(Escape(b.config, col)) + } + buffer.WriteString(")") + case rel.SchemaDrop: + buffer.WriteString("DROP INDEX ") + + if index.Optional { + buffer.WriteString("IF EXISTS ") + } + + buffer.WriteString(Escape(b.config, index.Name)) + + if b.config.DropIndexOnTable { + buffer.WriteString(" ON ") + buffer.WriteString(Escape(b.config, index.Table)) + } + } + + b.options(&buffer, index.Options) + buffer.WriteByte(';') + + return buffer.String() +} + +func (b *Builder) options(buffer *Buffer, options string) { + if options == "" { + return + } + + buffer.WriteByte(' ') + buffer.WriteString(options) +} + // Find generates query for select. func (b *Builder) Find(query rel.Query) (string, []interface{}) { if query.SQLQuery.Statement != "" { @@ -47,13 +299,13 @@ func (b *Builder) Aggregate(query rel.Query, mode string, field string) (string, buffer.WriteString("SELECT ") buffer.WriteString(mode) buffer.WriteByte('(') - buffer.WriteString(b.escape(field)) + buffer.WriteString(Escape(b.config, field)) buffer.WriteString(") AS ") buffer.WriteString(mode) for _, f := range query.GroupQuery.Fields { buffer.WriteByte(',') - buffer.WriteString(b.escape(f)) + buffer.WriteString(Escape(b.config, f)) } b.query(&buffer, query) @@ -90,7 +342,7 @@ func (b *Builder) Insert(table string, mutates map[string]rel.Mutate) (string, [ ) buffer.WriteString("INSERT INTO ") - buffer.WriteString(b.escape(table)) + buffer.WriteString(Escape(b.config, table)) if count == 0 && b.config.InsertDefaultValues { buffer.WriteString(" DEFAULT VALUES") @@ -219,14 +471,14 @@ func (b *Builder) Update(table string, mutates map[string]rel.Mutate, filter rel for field, mut := range mutates { switch mut.Type { case rel.ChangeSetOp: - buffer.WriteString(b.escape(field)) + buffer.WriteString(Escape(b.config, field)) buffer.WriteByte('=') buffer.WriteString(b.ph()) buffer.Append(mut.Value) case rel.ChangeIncOp: - buffer.WriteString(b.escape(field)) + buffer.WriteString(Escape(b.config, field)) buffer.WriteByte('=') - buffer.WriteString(b.escape(field)) + buffer.WriteString(Escape(b.config, field)) buffer.WriteByte('+') buffer.WriteString(b.ph()) buffer.Append(mut.Value) @@ -284,7 +536,7 @@ func (b *Builder) fields(buffer *Buffer, distinct bool, fields []string) { l := len(fields) - 1 for i, f := range fields { - buffer.WriteString(b.escape(f)) + buffer.WriteString(Escape(b.config, f)) if i < l { buffer.WriteByte(',') @@ -306,8 +558,8 @@ func (b *Builder) join(buffer *Buffer, table string, joins []rel.JoinQuery) { for _, join := range joins { var ( - from = b.escape(join.From) - to = b.escape(join.To) + from = Escape(b.config, join.From) + to = Escape(b.config, join.To) ) // TODO: move this to core functionality, and infer join condition using assoc data. @@ -350,7 +602,7 @@ func (b *Builder) groupBy(buffer *Buffer, fields []string) { l := len(fields) - 1 for i, f := range fields { - buffer.WriteString(b.escape(f)) + buffer.WriteString(Escape(b.config, f)) if i < l { buffer.WriteByte(',') @@ -379,7 +631,7 @@ func (b *Builder) orderBy(buffer *Buffer, orders []rel.SortQuery) { buffer.WriteString(" ORDER BY") for i, order := range orders { buffer.WriteByte(' ') - buffer.WriteString(b.escape(order.Field)) + buffer.WriteString(Escape(b.config, order.Field)) if order.Asc() { buffer.WriteString(" ASC") @@ -422,21 +674,21 @@ func (b *Builder) filter(buffer *Buffer, filter rel.FilterQuery) { rel.FilterGteOp: b.buildComparison(buffer, filter) case rel.FilterNilOp: - buffer.WriteString(b.escape(filter.Field)) + buffer.WriteString(Escape(b.config, filter.Field)) buffer.WriteString(" IS NULL") case rel.FilterNotNilOp: - buffer.WriteString(b.escape(filter.Field)) + buffer.WriteString(Escape(b.config, filter.Field)) buffer.WriteString(" IS NOT NULL") case rel.FilterInOp, rel.FilterNinOp: b.buildInclusion(buffer, filter) case rel.FilterLikeOp: - buffer.WriteString(b.escape(filter.Field)) + buffer.WriteString(Escape(b.config, filter.Field)) buffer.WriteString(" LIKE ") buffer.WriteString(b.ph()) buffer.Append(filter.Value) case rel.FilterNotLikeOp: - buffer.WriteString(b.escape(filter.Field)) + buffer.WriteString(Escape(b.config, filter.Field)) buffer.WriteString(" NOT LIKE ") buffer.WriteString(b.ph()) buffer.Append(filter.Value) @@ -471,7 +723,7 @@ func (b *Builder) build(buffer *Buffer, op string, inner []rel.FilterQuery) { } func (b *Builder) buildComparison(buffer *Buffer, filter rel.FilterQuery) { - buffer.WriteString(b.escape(filter.Field)) + buffer.WriteString(Escape(b.config, filter.Field)) switch filter.Type { case rel.FilterEqOp: @@ -497,7 +749,7 @@ func (b *Builder) buildInclusion(buffer *Buffer, filter rel.FilterQuery) { values = filter.Value.([]interface{}) ) - buffer.WriteString(b.escape(filter.Field)) + buffer.WriteString(Escape(b.config, filter.Field)) if filter.Type == rel.FilterInOp { buffer.WriteString(" IN (") @@ -523,38 +775,6 @@ func (b *Builder) ph() string { return b.config.Placeholder } -type fieldCacheKey struct { - field string - escape string -} - -func (b *Builder) escape(field string) string { - if b.config.EscapeChar == "" || field == "*" { - return field - } - - key := fieldCacheKey{field: field, escape: b.config.EscapeChar} - escapedField, ok := fieldCache.Load(key) - if ok { - return escapedField.(string) - } - - if len(field) > 0 && field[0] == UnescapeCharacter { - escapedField = field[1:] - } else if start, end := strings.IndexRune(field, '('), strings.IndexRune(field, ')'); start >= 0 && end >= 0 && end > start { - escapedField = field[:start+1] + b.escape(field[start+1:end]) + field[end:] - } else if strings.HasSuffix(field, "*") { - escapedField = b.config.EscapeChar + strings.Replace(field, ".", b.config.EscapeChar+".", 1) - } else { - escapedField = b.config.EscapeChar + - strings.Replace(field, ".", b.config.EscapeChar+"."+b.config.EscapeChar, 1) + - b.config.EscapeChar - } - - fieldCache.Store(key, escapedField) - return escapedField.(string) -} - // Returning append returning to insert rel. func (b *Builder) Returning(field string) *Builder { b.returnField = field @@ -562,7 +782,7 @@ func (b *Builder) Returning(field string) *Builder { } // NewBuilder create new SQL builder. -func NewBuilder(config *Config) *Builder { +func NewBuilder(config Config) *Builder { return &Builder{ config: config, } diff --git a/adapter/sql/builder_test.go b/adapter/sql/builder_test.go index 4ab0bb45..5ced9556 100644 --- a/adapter/sql/builder_test.go +++ b/adapter/sql/builder_test.go @@ -3,6 +3,7 @@ package sql import ( "fmt" "testing" + "time" "github.com/Fs02/rel" "github.com/Fs02/rel/sort" @@ -12,7 +13,7 @@ import ( func BenchmarkBuilder_Find(b *testing.B) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -32,9 +33,214 @@ func BenchmarkBuilder_Find(b *testing.B) { } } +func TestBuilder_Table(t *testing.T) { + var ( + config = Config{ + Placeholder: "?", + EscapeChar: "`", + MapColumnFunc: MapColumn, + } + ) + + tests := []struct { + result string + table rel.Table + }{ + { + result: "CREATE TABLE `products` (`id` INT UNSIGNED AUTO_INCREMENT PRIMARY KEY, `name` VARCHAR(255), `description` TEXT);", + table: rel.Table{ + Op: rel.SchemaCreate, + Name: "products", + Definitions: []rel.TableDefinition{ + rel.Column{Name: "id", Type: rel.ID}, + rel.Column{Name: "name", Type: rel.String}, + rel.Column{Name: "description", Type: rel.Text}, + }, + }, + }, + { + result: "CREATE TABLE `columns` (`bool` BOOL NOT NULL DEFAULT false, `int` INT(11) UNSIGNED, `bigint` BIGINT(20) UNSIGNED, `float` FLOAT(24) UNSIGNED, `decimal` DECIMAL(6,2) UNSIGNED, `string` VARCHAR(144) UNIQUE, `text` TEXT(1000), `date` DATE, `datetime` DATETIME, `time` TIME, `timestamp` TIMESTAMP DEFAULT '2020-01-01 01:00:00', `blob` blob, PRIMARY KEY (`int`), FOREIGN KEY (`int`, `string`) REFERENCES `products` (`id`, `name`) ON DELETE CASCADE ON UPDATE CASCADE, UNIQUE `date_unique` (`date`)) Engine=InnoDB;", + table: rel.Table{ + Op: rel.SchemaCreate, + Name: "columns", + Definitions: []rel.TableDefinition{ + rel.Column{Name: "bool", Type: rel.Bool, Required: true, Default: false}, + rel.Column{Name: "int", Type: rel.Int, Limit: 11, Unsigned: true}, + rel.Column{Name: "bigint", Type: rel.BigInt, Limit: 20, Unsigned: true}, + rel.Column{Name: "float", Type: rel.Float, Precision: 24, Unsigned: true}, + rel.Column{Name: "decimal", Type: rel.Decimal, Precision: 6, Scale: 2, Unsigned: true}, + rel.Column{Name: "string", Type: rel.String, Limit: 144, Unique: true}, + rel.Column{Name: "text", Type: rel.Text, Limit: 1000}, + rel.Column{Name: "date", Type: rel.Date}, + rel.Column{Name: "datetime", Type: rel.DateTime}, + rel.Column{Name: "time", Type: rel.Time}, + rel.Column{Name: "timestamp", Type: rel.Timestamp, Default: time.Date(2020, 1, 1, 1, 0, 0, 0, time.UTC)}, + rel.Column{Name: "blob", Type: "blob"}, + rel.Key{Columns: []string{"int"}, Type: rel.PrimaryKey}, + rel.Key{Columns: []string{"int", "string"}, Type: rel.ForeignKey, Reference: rel.ForeignKeyReference{Table: "products", Columns: []string{"id", "name"}, OnDelete: "CASCADE", OnUpdate: "CASCADE"}}, + rel.Key{Columns: []string{"date"}, Name: "date_unique", Type: rel.UniqueKey}, + }, + Options: "Engine=InnoDB", + }, + }, + { + result: "CREATE TABLE IF NOT EXISTS `products` (`id` INT UNSIGNED AUTO_INCREMENT PRIMARY KEY, `raw` BOOL);", + table: rel.Table{ + Op: rel.SchemaCreate, + Name: "products", + Optional: true, + Definitions: []rel.TableDefinition{ + rel.Column{Name: "id", Type: rel.ID}, + rel.Raw("`raw` BOOL"), + }, + }, + }, + { + result: "ALTER TABLE `columns` ADD COLUMN `verified` BOOL;ALTER TABLE `columns` RENAME COLUMN `string` TO `name`;ALTER TABLE `columns` ;ALTER TABLE `columns` DROP COLUMN `blob`;", + table: rel.Table{ + Op: rel.SchemaAlter, + Name: "columns", + Definitions: []rel.TableDefinition{ + rel.Column{Name: "verified", Type: rel.Bool, Op: rel.SchemaCreate}, + rel.Column{Name: "string", Rename: "name", Op: rel.SchemaRename}, + rel.Column{Name: "bool", Type: rel.Int, Op: rel.SchemaAlter}, + rel.Column{Name: "blob", Op: rel.SchemaDrop}, + }, + }, + }, + { + result: "ALTER TABLE `transactions` ADD FOREIGN KEY (`user_id`) REFERENCES `products` (`id`, `name`) ON DELETE CASCADE ON UPDATE CASCADE;", + table: rel.Table{ + Op: rel.SchemaAlter, + Name: "transactions", + Definitions: []rel.TableDefinition{ + rel.Key{Columns: []string{"user_id"}, Type: rel.ForeignKey, Reference: rel.ForeignKeyReference{Table: "products", Columns: []string{"id", "name"}, OnDelete: "CASCADE", OnUpdate: "CASCADE"}}, + }, + }, + }, + { + result: "ALTER TABLE `table` RENAME TO `table1`;", + table: rel.Table{ + Op: rel.SchemaRename, + Name: "table", + Rename: "table1", + }, + }, + { + result: "DROP TABLE `table`;", + table: rel.Table{ + Op: rel.SchemaDrop, + Name: "table", + }, + }, + { + result: "DROP TABLE IF EXISTS `table`;", + table: rel.Table{ + Op: rel.SchemaDrop, + Name: "table", + Optional: true, + }, + }, + } + + for _, test := range tests { + t.Run(test.result, func(t *testing.T) { + var ( + builder = NewBuilder(config) + result = builder.Table(test.table) + ) + + assert.Equal(t, test.result, result) + }) + } +} + +func TestBuilder_Index(t *testing.T) { + var ( + config = Config{ + Placeholder: "?", + EscapeChar: "`", + MapColumnFunc: MapColumn, + DropIndexOnTable: true, + } + ) + + tests := []struct { + result string + index rel.Index + }{ + { + result: "CREATE INDEX `index` ON `table` (`column1`);", + index: rel.Index{ + Op: rel.SchemaCreate, + Table: "table", + Name: "index", + Columns: []string{"column1"}, + }, + }, + { + result: "CREATE UNIQUE INDEX `index` ON `table` (`column1`);", + index: rel.Index{ + Op: rel.SchemaCreate, + Table: "table", + Name: "index", + Unique: true, + Columns: []string{"column1"}, + }, + }, + { + result: "CREATE INDEX `index` ON `table` (`column1`, `column2`);", + index: rel.Index{ + Op: rel.SchemaCreate, + Table: "table", + Name: "index", + Columns: []string{"column1", "column2"}, + }, + }, + { + result: "CREATE INDEX IF NOT EXISTS `index` ON `table` (`column1`);", + index: rel.Index{ + Op: rel.SchemaCreate, + Table: "table", + Name: "index", + Optional: true, + Columns: []string{"column1"}, + }, + }, + { + result: "DROP INDEX `index` ON `table`;", + index: rel.Index{ + Op: rel.SchemaDrop, + Name: "index", + Table: "table", + }, + }, + { + result: "DROP INDEX IF EXISTS `index` ON `table`;", + index: rel.Index{ + Op: rel.SchemaDrop, + Name: "index", + Table: "table", + Optional: true, + }, + }, + } + + for _, test := range tests { + t.Run(test.result, func(t *testing.T) { + var ( + builder = NewBuilder(config) + result = builder.Index(test.index) + ) + + assert.Equal(t, test.result, result) + }) + } +} + func TestBuilder_Find(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -118,7 +324,7 @@ func TestBuilder_Find(t *testing.T) { func TestBuilder_Find_ordinal(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "$", EscapeChar: "\"", Ordinal: true, @@ -199,7 +405,7 @@ func TestBuilder_Find_ordinal(t *testing.T) { func TestBuilder_Find_SQLQuery(t *testing.T) { var ( - config = &Config{} + config = Config{} builder = NewBuilder(config) query = rel.Build("", rel.SQL("SELECT * FROM `users` WHERE id=?;", 1)) qs, args = builder.Find(query) @@ -211,7 +417,7 @@ func TestBuilder_Find_SQLQuery(t *testing.T) { func BenchmarkBuilder_Aggregate(b *testing.B) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -225,7 +431,7 @@ func BenchmarkBuilder_Aggregate(b *testing.B) { func TestBuilder_Aggregate(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -248,7 +454,7 @@ func TestBuilder_Aggregate(t *testing.T) { func BenchmarkBuilder_Insert(b *testing.B) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -267,7 +473,7 @@ func BenchmarkBuilder_Insert(b *testing.B) { func TestBuilder_Insert(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -289,7 +495,7 @@ func TestBuilder_Insert(t *testing.T) { func TestBuilder_Insert_ordinal(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "$", EscapeChar: "\"", Ordinal: true, @@ -313,7 +519,7 @@ func TestBuilder_Insert_ordinal(t *testing.T) { func TestBuilder_Insert_defaultValuesDisabled(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", InsertDefaultValues: false, @@ -329,7 +535,7 @@ func TestBuilder_Insert_defaultValuesDisabled(t *testing.T) { func TestBuilder_Insert_defaultValuesEnabled(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", InsertDefaultValues: true, EscapeChar: "`", @@ -345,7 +551,7 @@ func TestBuilder_Insert_defaultValuesEnabled(t *testing.T) { func BenchmarkBuilder_InsertAll(b *testing.B) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -371,7 +577,7 @@ func BenchmarkBuilder_InsertAll(b *testing.B) { func TestBuilder_InsertAll(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -402,7 +608,7 @@ func TestBuilder_InsertAll(t *testing.T) { func TestBuilder_InsertAll_ordinal(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "$", EscapeChar: "\"", Ordinal: true, @@ -436,7 +642,7 @@ func TestBuilder_InsertAll_ordinal(t *testing.T) { func TestBuilder_Update(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -459,7 +665,7 @@ func TestBuilder_Update(t *testing.T) { func TestBuilder_Update_ordinal(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "$", EscapeChar: "\"", Ordinal: true, @@ -485,7 +691,7 @@ func TestBuilder_Update_ordinal(t *testing.T) { func TestBuilder_Update_incDecAndFragment(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -503,7 +709,7 @@ func TestBuilder_Update_incDecAndFragment(t *testing.T) { func TestBuilder_Delete(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -521,7 +727,7 @@ func TestBuilder_Delete(t *testing.T) { func TestBuilder_Delete_ordinal(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "$", EscapeChar: "\"", Ordinal: true, @@ -541,7 +747,7 @@ func TestBuilder_Delete_ordinal(t *testing.T) { func TestBuilder_Select(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -603,7 +809,7 @@ func TestBuilder_Select(t *testing.T) { func TestBuilder_From(t *testing.T) { var ( buffer Buffer - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -616,7 +822,7 @@ func TestBuilder_From(t *testing.T) { func TestBuilder_Join(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -666,7 +872,7 @@ func TestBuilder_Join(t *testing.T) { func TestBuilder_Where(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -706,7 +912,7 @@ func TestBuilder_Where(t *testing.T) { func TestBuilder_Where_ordinal(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "$", EscapeChar: "\"", Ordinal: true, @@ -749,7 +955,7 @@ func TestBuilder_Where_ordinal(t *testing.T) { func TestBuilder_GroupBy(t *testing.T) { var ( buffer Buffer - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -766,7 +972,7 @@ func TestBuilder_GroupBy(t *testing.T) { func TestBuilder_Having(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -806,7 +1012,7 @@ func TestBuilder_Having(t *testing.T) { func TestBuilder_Having_ordinal(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "$", EscapeChar: "\"", Ordinal: true, @@ -849,7 +1055,7 @@ func TestBuilder_Having_ordinal(t *testing.T) { func TestBuilder_OrderBy(t *testing.T) { var ( buffer Buffer - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -867,7 +1073,7 @@ func TestBuilder_OrderBy(t *testing.T) { func TestBuilder_LimitOffset(t *testing.T) { var ( buffer Buffer - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -884,7 +1090,7 @@ func TestBuilder_LimitOffset(t *testing.T) { func TestBuilder_Filter(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } @@ -1069,7 +1275,7 @@ func TestBuilder_Filter(t *testing.T) { func TestBuilder_Filter_ordinal(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "$", EscapeChar: "\"", Ordinal: true, @@ -1256,7 +1462,7 @@ func TestBuilder_Filter_ordinal(t *testing.T) { func TestBuilder_Lock(t *testing.T) { var ( - config = &Config{ + config = Config{ Placeholder: "?", EscapeChar: "`", } diff --git a/adapter/sql/config.go b/adapter/sql/config.go new file mode 100644 index 00000000..95f8068a --- /dev/null +++ b/adapter/sql/config.go @@ -0,0 +1,76 @@ +package sql + +import ( + "time" + + "github.com/Fs02/rel" +) + +// Config holds configuration for adapter. +type Config struct { + Placeholder string + Ordinal bool + InsertDefaultValues bool + DropIndexOnTable bool + EscapeChar string + ErrorFunc func(error) error + IncrementFunc func(Adapter) int + IndexToSQL func(config Config, buffer *Buffer, index rel.Index) bool + MapColumnFunc func(column *rel.Column) (string, int, int) +} + +// MapColumn func. +func MapColumn(column *rel.Column) (string, int, int) { + var ( + typ string + m, n int + timeLayout = "2006-01-02 15:04:05" + ) + + switch column.Type { + case rel.ID: + typ = "INT UNSIGNED AUTO_INCREMENT PRIMARY KEY" + case rel.Bool: + typ = "BOOL" + case rel.Int: + typ = "INT" + m = column.Limit + case rel.BigInt: + typ = "BIGINT" + m = column.Limit + case rel.Float: + typ = "FLOAT" + m = column.Precision + case rel.Decimal: + typ = "DECIMAL" + m = column.Precision + n = column.Scale + case rel.String: + typ = "VARCHAR" + m = column.Limit + if m == 0 { + m = 255 + } + case rel.Text: + typ = "TEXT" + m = column.Limit + case rel.Date: + typ = "DATE" + timeLayout = "2006-01-02" + case rel.DateTime: + typ = "DATETIME" + case rel.Time: + typ = "TIME" + timeLayout = "15:04:05" + case rel.Timestamp: + typ = "TIMESTAMP" + default: + typ = string(column.Type) + } + + if t, ok := column.Default.(time.Time); ok { + column.Default = t.Format(timeLayout) + } + + return typ, m, n +} diff --git a/adapter/sql/util.go b/adapter/sql/util.go index f4ac1aad..a9b39ea2 100644 --- a/adapter/sql/util.go +++ b/adapter/sql/util.go @@ -18,6 +18,39 @@ func ExtractString(s, left, right string) string { return s[start+len(left) : end] } +type fieldCacheKey struct { + field string + escape string +} + +// Escape field or table name. +func Escape(config Config, field string) string { + if config.EscapeChar == "" || field == "*" { + return field + } + + key := fieldCacheKey{field: field, escape: config.EscapeChar} + escapedField, ok := fieldCache.Load(key) + if ok { + return escapedField.(string) + } + + if len(field) > 0 && field[0] == UnescapeCharacter { + escapedField = field[1:] + } else if start, end := strings.IndexRune(field, '('), strings.IndexRune(field, ')'); start >= 0 && end >= 0 && end > start { + escapedField = field[:start+1] + Escape(config, field[start+1:end]) + field[end:] + } else if strings.HasSuffix(field, "*") { + escapedField = config.EscapeChar + strings.Replace(field, ".", config.EscapeChar+".", 1) + } else { + escapedField = config.EscapeChar + + strings.Replace(field, ".", config.EscapeChar+"."+config.EscapeChar, 1) + + config.EscapeChar + } + + fieldCache.Store(key, escapedField) + return escapedField.(string) +} + func toInt64(i interface{}) int64 { var result int64 diff --git a/adapter/sqlite3/sqlite3.go b/adapter/sqlite3/sqlite3.go index e8eed1f5..31091588 100644 --- a/adapter/sqlite3/sqlite3.go +++ b/adapter/sqlite3/sqlite3.go @@ -25,25 +25,31 @@ type Adapter struct { *sql.Adapter } -var _ rel.Adapter = (*Adapter)(nil) +var ( + _ rel.Adapter = (*Adapter)(nil) -// New is mysql adapter constructor. + // Config for mysql adapter. + Config = sql.Config{ + Placeholder: "?", + EscapeChar: "`", + InsertDefaultValues: true, + IncrementFunc: incrementFunc, + ErrorFunc: errorFunc, + MapColumnFunc: mapColumnFunc, + } +) + +// New sqlite adapter using existing connection. func New(database *db.DB) *Adapter { return &Adapter{ Adapter: &sql.Adapter{ - Config: &sql.Config{ - Placeholder: "?", - EscapeChar: "`", - InsertDefaultValues: true, - IncrementFunc: incrementFunc, - ErrorFunc: errorFunc, - }, - DB: database, + Config: Config, + DB: database, }, } } -// Open mysql connection using dsn. +// Open sqlite connection using dsn. func Open(dsn string) (*Adapter, error) { var database, err = db.Open("sqlite3", dsn) return New(database), err @@ -87,3 +93,29 @@ func errorFunc(err error) error { return err } } + +func mapColumnFunc(column *rel.Column) (string, int, int) { + var ( + typ string + m, n int + unsigned = column.Unsigned + ) + + column.Unsigned = false + + switch column.Type { + case rel.ID: + typ = "INTEGER PRIMARY KEY" + case rel.Int: + typ = "INTEGER" + m = column.Limit + default: + typ, m, n = sql.MapColumn(column) + } + + if unsigned { + typ = "UNSIGNED " + typ + } + + return typ, m, n +} diff --git a/adapter/sqlite3/sqlite3_test.go b/adapter/sqlite3/sqlite3_test.go index 4934f113..2c66c829 100644 --- a/adapter/sqlite3/sqlite3_test.go +++ b/adapter/sqlite3/sqlite3_test.go @@ -14,62 +14,6 @@ import ( var ctx = context.TODO() -func init() { - adapter, err := Open(dsn()) - paranoid.Panic(err, "failed to open database connection") - defer adapter.Close() - - _, _, err = adapter.Exec(ctx, `DROP TABLE IF EXISTS extras;`, nil) - paranoid.Panic(err, "failed when dropping extras table") - _, _, err = adapter.Exec(ctx, `DROP TABLE IF EXISTS addresses;`, nil) - paranoid.Panic(err, "failed when dropping addresses table") - _, _, err = adapter.Exec(ctx, `DROP TABLE IF EXISTS users;`, nil) - paranoid.Panic(err, "failed when dropping users table") - _, _, err = adapter.Exec(ctx, `DROP TABLE IF EXISTS composites;`, nil) - paranoid.Panic(err, "failed when dropping users table") - - _, _, err = adapter.Exec(ctx, `CREATE TABLE users ( - id INTEGER PRIMARY KEY, - slug VARCHAR(30) DEFAULT NULL, - name VARCHAR(30) NOT NULL DEFAULT '', - gender VARCHAR(10) NOT NULL DEFAULT '', - age INTEGER NOT NULL DEFAULT 0, - note varchar(50), - created_at DATETIME, - updated_at DATETIME, - UNIQUE (slug) - );`, nil) - paranoid.Panic(err, "failed when creating users table") - - _, _, err = adapter.Exec(ctx, `CREATE TABLE addresses ( - id INTEGER PRIMARY KEY, - user_id INTEGER, - name VARCHAR(60) NOT NULL DEFAULT '', - created_at DATETIME, - updated_at DATETIME, - FOREIGN KEY (user_id) REFERENCES users(id) - );`, nil) - paranoid.Panic(err, "failed when creating addresses table") - - _, _, err = adapter.Exec(ctx, `CREATE TABLE extras ( - id INTEGER PRIMARY KEY, - slug VARCHAR(30) DEFAULT NULL UNIQUE, - user_id INTEGER, - score INTEGER DEFAULT 0, - FOREIGN KEY (user_id) REFERENCES users(id), - CONSTRAINT extras_score_check CHECK (score>=0 AND score<=100) - );`, nil) - paranoid.Panic(err, "failed when creating extras table") - - _, _, err = adapter.Exec(ctx, `CREATE TABLE composites ( - primary1 INTEGER, - primary2 INTEGER, - data VARCHAR(255) DEFAULT NULL, - PRIMARY KEY (primary1, primary2) - );`, nil) - paranoid.Panic(err, "failed when creating extras table") -} - func dsn() string { if os.Getenv("SQLITE3_DATABASE") != "" { return os.Getenv("SQLITE3_DATABASE") + "?_foreign_keys=1&_loc=Local" @@ -85,6 +29,13 @@ func TestAdapter_specs(t *testing.T) { repo := rel.New(adapter) + // Prepare tables + teardown := specs.Setup(t, repo) + defer teardown() + + // Migration Specs + specs.Migrate(t, repo, specs.SkipDropColumn) + // Query Specs specs.Query(t, repo) specs.QueryJoin(t, repo) diff --git a/adapter_test.go b/adapter_test.go index c8a87048..7f2a2cda 100644 --- a/adapter_test.go +++ b/adapter_test.go @@ -76,6 +76,11 @@ func (ta *testAdapter) Rollback(ctx context.Context) error { return args.Error(0) } +func (ta *testAdapter) Apply(ctx context.Context, migration Migration) error { + args := ta.Called(migration) + return args.Error(0) +} + func (ta *testAdapter) Result(result interface{}) *testAdapter { ta.result = result return ta diff --git a/collection.go b/collection.go index cb022abf..d8913cdc 100644 --- a/collection.go +++ b/collection.go @@ -18,7 +18,6 @@ type Collection struct { rv reflect.Value rt reflect.Type data documentData - index map[interface{}]int swapper func(i, j int) } diff --git a/column.go b/column.go new file mode 100644 index 00000000..a903520a --- /dev/null +++ b/column.go @@ -0,0 +1,81 @@ +package rel + +// ColumnType definition. +type ColumnType string + +const ( + // ID ColumnType. + ID ColumnType = "ID" + // Bool ColumnType. + Bool ColumnType = "BOOL" + // Int ColumnType. + Int ColumnType = "INT" + // BigInt ColumnType. + BigInt ColumnType = "BIGINT" + // Float ColumnType. + Float ColumnType = "FLOAT" + // Decimal ColumnType. + Decimal ColumnType = "DECIMAL" + // String ColumnType. + String ColumnType = "STRING" + // Text ColumnType. + Text ColumnType = "TEXT" + // Date ColumnType. + Date ColumnType = "DATE" + // DateTime ColumnType. + DateTime ColumnType = "DATETIME" + // Time ColumnType. + Time ColumnType = "TIME" + // Timestamp ColumnType. + Timestamp ColumnType = "TIMESTAMP" +) + +// Column definition. +type Column struct { + Op SchemaOp + Name string + Type ColumnType + Rename string + Unique bool + Required bool + Unsigned bool + Limit int + Precision int + Scale int + Default interface{} + Options string +} + +func (Column) internalTableDefinition() {} + +func createColumn(name string, typ ColumnType, options []ColumnOption) Column { + column := Column{ + Op: SchemaCreate, + Name: name, + Type: typ, + } + + applyColumnOptions(&column, options) + return column +} + +func renameColumn(name string, newName string, options []ColumnOption) Column { + column := Column{ + Op: SchemaRename, + Name: name, + Rename: newName, + } + + applyColumnOptions(&column, options) + return column +} + +func dropColumn(name string, options []ColumnOption) Column { + column := Column{ + Op: SchemaDrop, + Name: name, + } + + applyColumnOptions(&column, options) + return column +} diff --git a/column_test.go b/column_test.go new file mode 100644 index 00000000..53c58bf7 --- /dev/null +++ b/column_test.go @@ -0,0 +1,91 @@ +package rel + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreateColumn(t *testing.T) { + var ( + options = []ColumnOption{ + Unique(true), + Required(true), + Unsigned(true), + Limit(1000), + Precision(5), + Scale(2), + Default(0), + Options("options"), + } + column = createColumn("add", Decimal, options) + ) + + assert.Equal(t, Column{ + Name: "add", + Type: Decimal, + Unique: true, + Required: true, + Unsigned: true, + Limit: 1000, + Precision: 5, + Scale: 2, + Default: 0, + Options: "options", + }, column) +} + +func TestRenameColumn(t *testing.T) { + var ( + options = []ColumnOption{ + Required(true), + Unsigned(true), + Limit(1000), + Precision(5), + Scale(2), + Default(0), + Options("options"), + } + column = renameColumn("add", "rename", options) + ) + + assert.Equal(t, Column{ + Op: SchemaRename, + Name: "add", + Rename: "rename", + Required: true, + Unsigned: true, + Limit: 1000, + Precision: 5, + Scale: 2, + Default: 0, + Options: "options", + }, column) +} + +func TestDropColumn(t *testing.T) { + var ( + options = []ColumnOption{ + Required(true), + Unsigned(true), + Limit(1000), + Precision(5), + Scale(2), + Default(0), + Options("options"), + } + column = dropColumn("drop", options) + ) + + assert.Equal(t, Column{ + Op: SchemaDrop, + Name: "drop", + Required: true, + Unsigned: true, + Limit: 1000, + Precision: 5, + Scale: 2, + Default: 0, + Options: "options", + }, column) +} diff --git a/context_wrapper.go b/context_wrapper.go index 664ca28f..ae1dee64 100644 --- a/context_wrapper.go +++ b/context_wrapper.go @@ -6,10 +6,6 @@ import ( type contextKey int8 -type contextData struct { - adapter Adapter -} - type contextWrapper struct { ctx context.Context adapter Adapter diff --git a/docs/README.md b/docs/README.md index 40e4e52a..a64b0fe4 100644 --- a/docs/README.md +++ b/docs/README.md @@ -14,6 +14,7 @@ REL is golang orm-ish database layer for layered architecture. It's testable and - Multi adapter. - Soft Deletion. - Pagination. +- Schema Migration. ## Install diff --git a/document.go b/document.go index 0242f3ae..c48a8060 100644 --- a/document.go +++ b/document.go @@ -33,8 +33,6 @@ const ( var ( tablesCache sync.Map primariesCache sync.Map - fieldsCache sync.Map - typesCache sync.Map documentDataCache sync.Map rtTime = reflect.TypeOf(time.Time{}) rtTable = reflect.TypeOf((*table)(nil)).Elem() diff --git a/go.mod b/go.mod index 4836ea45..26f403ec 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a github.com/kr/pretty v0.1.0 // indirect github.com/lib/pq v1.3.0 - github.com/mattn/go-sqlite3 v1.6.0 + github.com/mattn/go-sqlite3 v1.14.2 github.com/stretchr/objx v0.2.0 // indirect github.com/stretchr/testify v1.4.0 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect diff --git a/go.sum b/go.sum index e2a1fee9..235f62c4 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/Fs02/go-paranoid v0.0.0-20190122110906-018c1ac5124a h1:wYjvXrzEmkEe3kNQXUd2Nzt/EO28kqebKsUWjXH9Opk= github.com/Fs02/go-paranoid v0.0.0-20190122110906-018c1ac5124a/go.mod h1:mUYWV9DG75bJ33LZlW1Je3MW64017zkfUFCf+QnCJs0= +github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= +github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y= github.com/azer/snakecase v0.0.0-20161028114325-c818dddafb5c h1:7zL0ljVI6ads5EFvx+Oq+uompnFBMJqtbuHvyobbJ1Q= github.com/azer/snakecase v0.0.0-20161028114325-c818dddafb5c/go.mod h1:iApMeoHF0YlMPzCwqH/d59E3w2s8SeO4rGK+iGClS8Y= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= @@ -17,8 +19,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU= github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/mattn/go-sqlite3 v1.6.0 h1:TDwTWbeII+88Qy55nWlof0DclgAtI4LqGujkYMzmQII= -github.com/mattn/go-sqlite3 v1.6.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mattn/go-sqlite3 v1.14.2 h1:A2EQLwjYf/hfYaM20FVjs1UewCTTFR7RmjEHkLjldIA= +github.com/mattn/go-sqlite3 v1.14.2/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= @@ -29,6 +31,13 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/index.go b/index.go new file mode 100644 index 00000000..6193cb3e --- /dev/null +++ b/index.go @@ -0,0 +1,62 @@ +package rel + +// Index definition. +type Index struct { + Op SchemaOp + Table string + Name string + Unique bool + Columns []string + Optional bool + Options string +} + +func (Index) internalMigration() {} + +func createIndex(table string, name string, columns []string, options []IndexOption) Index { + index := Index{ + Op: SchemaCreate, + Table: table, + Name: name, + Columns: columns, + } + + applyIndexOptions(&index, options) + return index +} + +func createUniqueIndex(table string, name string, columns []string, options []IndexOption) Index { + index := createIndex(table, name, columns, options) + index.Unique = true + return index +} + +func dropIndex(table string, name string, options []IndexOption) Index { + index := Index{ + Op: SchemaDrop, + Table: table, + Name: name, + } + + applyIndexOptions(&index, options) + return index +} + +// IndexOption interface. +// Available options are: Comment, Options. +type IndexOption interface { + applyIndex(index *Index) +} + +func applyIndexOptions(index *Index, options []IndexOption) { + for i := range options { + options[i].applyIndex(index) + } +} + +// Name option for defining custom index name. +type Name string + +func (n Name) applyKey(key *Key) { + key.Name = string(n) +} diff --git a/index_test.go b/index_test.go new file mode 100644 index 00000000..f2a9601d --- /dev/null +++ b/index_test.go @@ -0,0 +1,58 @@ +package rel + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreateIndex(t *testing.T) { + var ( + options = []IndexOption{ + Options("options"), + Optional(true), + } + index = createIndex("table", "add_idx", []string{"add"}, options) + ) + + assert.Equal(t, Index{ + Table: "table", + Name: "add_idx", + Columns: []string{"add"}, + Optional: true, + Options: "options", + }, index) +} + +func TestCreateUniqueIndex(t *testing.T) { + var ( + options = []IndexOption{ + Options("options"), + } + index = createUniqueIndex("table", "add_idx", []string{"add"}, options) + ) + + assert.Equal(t, Index{ + Table: "table", + Name: "add_idx", + Unique: true, + Columns: []string{"add"}, + Options: "options", + }, index) +} + +func TestDropIndex(t *testing.T) { + var ( + options = []IndexOption{ + Options("options"), + } + index = dropIndex("table", "drop", options) + ) + + assert.Equal(t, Index{ + Op: SchemaDrop, + Table: "table", + Name: "drop", + Options: "options", + }, index) +} diff --git a/key.go b/key.go new file mode 100644 index 00000000..9d151571 --- /dev/null +++ b/key.go @@ -0,0 +1,66 @@ +package rel + +// KeyType definition. +type KeyType string + +const ( + // PrimaryKey KeyType. + PrimaryKey KeyType = "PRIMARY KEY" + // ForeignKey KeyType. + ForeignKey KeyType = "FOREIGN KEY" + // UniqueKey KeyType. + UniqueKey = "UNIQUE" +) + +// ForeignKeyReference definition. +type ForeignKeyReference struct { + Table string + Columns []string + OnDelete string + OnUpdate string +} + +// Key definition. +type Key struct { + Op SchemaOp + Name string + Type KeyType + Columns []string + Rename string + Reference ForeignKeyReference + Options string +} + +func (Key) internalTableDefinition() {} + +func createKeys(columns []string, typ KeyType, options []KeyOption) Key { + key := Key{ + Op: SchemaCreate, + Columns: columns, + Type: typ, + } + + applyKeyOptions(&key, options) + return key +} + +func createPrimaryKeys(columns []string, options []KeyOption) Key { + return createKeys(columns, PrimaryKey, options) +} + +func createForeignKey(column string, refTable string, refColumn string, options []KeyOption) Key { + key := Key{ + Op: SchemaCreate, + Type: ForeignKey, + Columns: []string{column}, + Reference: ForeignKeyReference{ + Table: refTable, + Columns: []string{refColumn}, + }, + } + + applyKeyOptions(&key, options) + return key +} + +// TODO: Rename and Drop, PR welcomed. diff --git a/key_test.go b/key_test.go new file mode 100644 index 00000000..9e79de76 --- /dev/null +++ b/key_test.go @@ -0,0 +1,32 @@ +package rel + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreateForeignKey(t *testing.T) { + var ( + options = []KeyOption{ + OnDelete("cascade"), + OnUpdate("cascade"), + Name("fk"), + Options("options"), + } + index = createForeignKey("table_id", "table", "id", options) + ) + + assert.Equal(t, Key{ + Type: ForeignKey, + Name: "fk", + Columns: []string{"table_id"}, + Reference: ForeignKeyReference{ + Table: "table", + Columns: []string{"id"}, + OnDelete: "cascade", + OnUpdate: "cascade", + }, + Options: "options", + }, index) +} diff --git a/migrator/migrator.go b/migrator/migrator.go new file mode 100644 index 00000000..0b27d916 --- /dev/null +++ b/migrator/migrator.go @@ -0,0 +1,165 @@ +package migrator + +import ( + "context" + "fmt" + "sort" + "time" + + "github.com/Fs02/rel" +) + +const versionTable = "rel_schema_versions" + +type version struct { + ID int + Version int + CreatedAt time.Time + UpdatedAt time.Time + + up rel.Schema + down rel.Schema + applied bool +} + +func (version) Table() string { + return versionTable +} + +type versions []version + +func (v versions) Len() int { + return len(v) +} + +func (v versions) Less(i, j int) bool { + return v[i].Version < v[j].Version +} + +func (v versions) Swap(i, j int) { + v[i], v[j] = v[j], v[i] +} + +// Migrator is a migration manager that handles migration logic. +type Migrator struct { + repo rel.Repository + versions versions + versionTableExists bool +} + +// Register a migration. +func (m *Migrator) Register(v int, up func(schema *rel.Schema), down func(schema *rel.Schema)) { + var upSchema, downSchema rel.Schema + + up(&upSchema) + down(&downSchema) + + m.versions = append(m.versions, version{Version: v, up: upSchema, down: downSchema}) +} + +func (m Migrator) buildVersionTableDefinition() rel.Table { + var schema rel.Schema + schema.CreateTableIfNotExists(versionTable, func(t *rel.Table) { + t.ID("id") + t.BigInt("version", rel.Unsigned(true), rel.Unique(true)) + t.DateTime("created_at") + t.DateTime("updated_at") + }) + + return schema.Migrations[0].(rel.Table) +} + +func (m *Migrator) sync(ctx context.Context) { + var ( + versions versions + vi int + adapter = m.repo.Adapter(ctx).(rel.Adapter) + ) + + if !m.versionTableExists { + check(adapter.Apply(ctx, m.buildVersionTableDefinition())) + m.versionTableExists = true + } + + m.repo.MustFindAll(ctx, &versions, rel.NewSortAsc("version")) + sort.Sort(m.versions) + + for i := range m.versions { + if vi < len(versions) && m.versions[i].Version == versions[vi].Version { + m.versions[i].ID = versions[vi].ID + m.versions[i].applied = true + vi++ + } else { + m.versions[i].applied = false + } + } + + if vi != len(versions) { + panic(fmt.Sprint("rel: missing local migration: ", versions[vi].Version)) + } +} + +// Migrate to the latest schema version. +func (m *Migrator) Migrate(ctx context.Context) { + m.sync(ctx) + + for _, v := range m.versions { + if v.applied { + continue + } + + err := m.repo.Transaction(ctx, func(ctx context.Context) error { + m.repo.MustInsert(ctx, &version{Version: v.Version}) + m.run(ctx, v.up.Migrations) + return nil + }) + + check(err) + } +} + +// Rollback migration 1 step. +func (m *Migrator) Rollback(ctx context.Context) { + m.sync(ctx) + + for i := range m.versions { + v := m.versions[len(m.versions)-i-1] + if !v.applied { + continue + } + + err := m.repo.Transaction(ctx, func(ctx context.Context) error { + m.repo.MustDelete(ctx, &v) + m.run(ctx, v.down.Migrations) + return nil + }) + + check(err) + + // only rollback one version. + return + } +} + +func (m *Migrator) run(ctx context.Context, migrations []rel.Migration) { + adapter := m.repo.Adapter(ctx).(rel.Adapter) + for _, migration := range migrations { + if fn, ok := migration.(rel.Do); ok { + check(fn(m.repo)) + } else { + check(adapter.Apply(ctx, migration)) + } + } + +} + +// New migrationr. +func New(repo rel.Repository) Migrator { + return Migrator{repo: repo} +} + +func check(err error) { + if err != nil { + panic(err) + } +} diff --git a/migrator/migrator_test.go b/migrator/migrator_test.go new file mode 100644 index 00000000..6f3fb00c --- /dev/null +++ b/migrator/migrator_test.go @@ -0,0 +1,212 @@ +package migrator + +import ( + "context" + "errors" + "testing" + + "github.com/Fs02/rel" + "github.com/Fs02/rel/reltest" + "github.com/stretchr/testify/assert" +) + +func TestMigrator(t *testing.T) { + var ( + ctx = context.TODO() + repo = reltest.New() + migrator = New(repo) + ) + + t.Run("Register", func(t *testing.T) { + migrator.Register(20200829084000, + func(schema *rel.Schema) { + schema.CreateTable("users", func(t *rel.Table) { + t.ID("id") + }) + }, + func(schema *rel.Schema) { + schema.DropTable("users") + }, + ) + + migrator.Register(20200828100000, + func(schema *rel.Schema) { + schema.CreateTable("tags", func(t *rel.Table) { + t.ID("id") + }) + + schema.Do(func(repo rel.Repository) error { + assert.NotNil(t, repo) + return nil + }) + }, + func(schema *rel.Schema) { + schema.DropTable("tags") + }, + ) + + migrator.Register(20200829115100, + func(schema *rel.Schema) { + schema.CreateTable("books", func(t *rel.Table) { + t.ID("id") + }) + }, + func(schema *rel.Schema) { + schema.DropTable("books") + }, + ) + + assert.Len(t, migrator.versions, 3) + assert.Equal(t, 20200829084000, migrator.versions[0].Version) + assert.Equal(t, 20200828100000, migrator.versions[1].Version) + assert.Equal(t, 20200829115100, migrator.versions[2].Version) + }) + + t.Run("Migrate", func(t *testing.T) { + repo.ExpectFindAll(rel.NewSortAsc("version")). + Result(versions{{ID: 1, Version: 20200829115100}}) + + repo.ExpectTransaction(func(repo *reltest.Repository) { + repo.ExpectInsert().For(&version{Version: 20200828100000}) + }) + + repo.ExpectTransaction(func(repo *reltest.Repository) { + repo.ExpectInsert().For(&version{Version: 20200829084000}) + }) + + migrator.Migrate(ctx) + }) + + t.Run("Rollback", func(t *testing.T) { + repo.ExpectFindAll(rel.NewSortAsc("version")). + Result(versions{ + {ID: 1, Version: 20200828100000}, + {ID: 2, Version: 20200829084000}, + }) + + assert.Equal(t, 20200829084000, migrator.versions[1].Version) + + repo.ExpectTransaction(func(repo *reltest.Repository) { + repo.ExpectDelete().For(&migrator.versions[1]) + }) + + migrator.Rollback(ctx) + }) +} + +func TestMigrator_Sync(t *testing.T) { + var ( + ctx = context.TODO() + repo = reltest.New() + nfn = func(schema *rel.Schema) {} + ) + + tests := []struct { + name string + applied versions + synced versions + isPanic bool + }{ + { + name: "all migrated", + applied: versions{ + {ID: 1, Version: 1}, + {ID: 2, Version: 2}, + {ID: 3, Version: 3}, + }, + synced: versions{ + {ID: 1, Version: 1, applied: true}, + {ID: 2, Version: 2, applied: true}, + {ID: 3, Version: 3, applied: true}, + }, + }, + { + name: "not migrated", + applied: versions{}, + synced: versions{ + {ID: 0, Version: 1, applied: false}, + {ID: 0, Version: 2, applied: false}, + {ID: 0, Version: 3, applied: false}, + }, + }, + { + name: "first not migrated", + applied: versions{ + {ID: 2, Version: 2}, + {ID: 3, Version: 3}, + }, + synced: versions{ + {ID: 0, Version: 1, applied: false}, + {ID: 2, Version: 2, applied: true}, + {ID: 3, Version: 3, applied: true}, + }, + }, + { + name: "middle not migrated", + applied: versions{ + {ID: 1, Version: 1}, + {ID: 3, Version: 3}, + }, + synced: versions{ + {ID: 1, Version: 1, applied: true}, + {ID: 0, Version: 2, applied: false}, + {ID: 3, Version: 3, applied: true}, + }, + }, + { + name: "last not migrated", + applied: versions{ + {ID: 1, Version: 1}, + {ID: 2, Version: 2}, + }, + synced: versions{ + {ID: 1, Version: 1, applied: true}, + {ID: 2, Version: 2, applied: true}, + {ID: 0, Version: 3, applied: false}, + }, + }, + { + name: "broken migration", + applied: versions{ + {ID: 1, Version: 1}, + {ID: 2, Version: 2}, + {ID: 3, Version: 3}, + {ID: 4, Version: 4}, + }, + synced: versions{ + {ID: 1, Version: 1, applied: true}, + {ID: 2, Version: 2, applied: true}, + {ID: 3, Version: 3, applied: true}, + }, + isPanic: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + migrator := New(repo) + migrator.Register(3, nfn, nfn) + migrator.Register(2, nfn, nfn) + migrator.Register(1, nfn, nfn) + + repo.ExpectFindAll(rel.NewSortAsc("version")).Result(test.applied) + + if test.isPanic { + assert.Panics(t, func() { + migrator.sync(ctx) + }) + } else { + assert.NotPanics(t, func() { + migrator.sync(ctx) + }) + + assert.Equal(t, test.synced, migrator.versions) + } + }) + } +} +func TestCheck(t *testing.T) { + assert.Panics(t, func() { + check(errors.New("error")) + }) +} diff --git a/query.go b/query.go index a8d8d354..5aaffe8e 100644 --- a/query.go +++ b/query.go @@ -336,7 +336,9 @@ func (o Offset) Build(query *Query) { query.OffsetQuery = o } -// Limit query. +// Limit options. +// When passed as query, it limits returned result from database. +// When passed as column option, it sets the maximum size of the string/text/binary/integer columns. type Limit int // Build query. @@ -344,6 +346,10 @@ func (l Limit) Build(query *Query) { query.LimitQuery = l } +func (l Limit) applyColumn(column *Column) { + column.Limit = int(l) +} + // Lock query. // This query will be ignored if used outside of transaction. type Lock string diff --git a/reltest/nop_adapter.go b/reltest/nop_adapter.go index d7dc555c..b7ea8cc6 100644 --- a/reltest/nop_adapter.go +++ b/reltest/nop_adapter.go @@ -7,7 +7,6 @@ import ( ) type nopAdapter struct { - count int } func (na *nopAdapter) Instrumentation(instrumenter rel.Instrumenter) { @@ -61,6 +60,10 @@ func (na *nopAdapter) Update(ctx context.Context, query rel.Query, mutates map[s return 1, nil } +func (na *nopAdapter) Apply(ctx context.Context, migration rel.Migration) error { + return nil +} + type nopCursor struct { count int } diff --git a/reltest/nop_adapter_test.go b/reltest/nop_adapter_test.go new file mode 100644 index 00000000..00237bc3 --- /dev/null +++ b/reltest/nop_adapter_test.go @@ -0,0 +1,18 @@ +package reltest + +import ( + "context" + "testing" + + "github.com/Fs02/rel" + "github.com/stretchr/testify/assert" +) + +func TestNopAdapter_Apply(t *testing.T) { + var ( + ctx = context.TODO() + adapter = &nopAdapter{} + ) + + assert.Nil(t, adapter.Apply(ctx, rel.Table{})) +} diff --git a/reltest/repository.go b/reltest/repository.go index 17fb8685..3fe770d5 100644 --- a/reltest/repository.go +++ b/reltest/repository.go @@ -20,7 +20,7 @@ var _ rel.Repository = (*Repository)(nil) // Adapter provides a mock function with given fields: func (r *Repository) Adapter(ctx context.Context) rel.Adapter { - return nil + return r.repo.Adapter(ctx) } // Instrumentation provides a mock function with given fields: instrumenter diff --git a/reltest/repository_test.go b/reltest/repository_test.go index 77ecbc0c..d95c6b65 100644 --- a/reltest/repository_test.go +++ b/reltest/repository_test.go @@ -45,7 +45,7 @@ func TestRepository_Adapter(t *testing.T) { repo = New() ) - assert.Nil(t, repo.Adapter(ctx)) + assert.NotNil(t, repo.Adapter(ctx)) } func TestRepository_Instrumentation(t *testing.T) { diff --git a/repository_test.go b/repository_test.go index d45b14c7..8173dcea 100644 --- a/repository_test.go +++ b/repository_test.go @@ -18,8 +18,6 @@ func init() { } } -var repo = repository{} - func createCursor(row int) *testCursor { cur := &testCursor{} diff --git a/schema.go b/schema.go new file mode 100644 index 00000000..2a0a1a82 --- /dev/null +++ b/schema.go @@ -0,0 +1,111 @@ +package rel + +// SchemaOp type. +type SchemaOp uint8 + +const ( + // SchemaCreate operation. + SchemaCreate SchemaOp = iota + // SchemaAlter operation. + SchemaAlter + // SchemaRename operation. + SchemaRename + // SchemaDrop operation. + SchemaDrop +) + +// Migration definition. +type Migration interface { + internalMigration() +} + +// Schema builder. +type Schema struct { + Migrations []Migration +} + +func (s *Schema) add(migration Migration) { + s.Migrations = append(s.Migrations, migration) +} + +// CreateTable with name and its definition. +func (s *Schema) CreateTable(name string, fn func(t *Table), options ...TableOption) { + table := createTable(name, options) + fn(&table) + s.add(table) +} + +// CreateTableIfNotExists with name and its definition. +func (s *Schema) CreateTableIfNotExists(name string, fn func(t *Table), options ...TableOption) { + table := createTableIfNotExists(name, options) + fn(&table) + s.add(table) +} + +// AlterTable with name and its definition. +func (s *Schema) AlterTable(name string, fn func(t *AlterTable), options ...TableOption) { + table := alterTable(name, options) + fn(&table) + s.add(table.Table) +} + +// RenameTable by name. +func (s *Schema) RenameTable(name string, newName string, options ...TableOption) { + s.add(renameTable(name, newName, options)) +} + +// DropTable by name. +func (s *Schema) DropTable(name string, options ...TableOption) { + s.add(dropTable(name, options)) +} + +// DropTableIfExists by name. +func (s *Schema) DropTableIfExists(name string, options ...TableOption) { + s.add(dropTableIfExists(name, options)) +} + +// AddColumn with name and type. +func (s *Schema) AddColumn(table string, name string, typ ColumnType, options ...ColumnOption) { + at := alterTable(table, nil) + at.Column(name, typ, options...) + s.add(at.Table) +} + +// RenameColumn by name. +func (s *Schema) RenameColumn(table string, name string, newName string, options ...ColumnOption) { + at := alterTable(table, nil) + at.RenameColumn(name, newName, options...) + s.add(at.Table) +} + +// DropColumn by name. +func (s *Schema) DropColumn(table string, name string, options ...ColumnOption) { + at := alterTable(table, nil) + at.DropColumn(name, options...) + s.add(at.Table) +} + +// CreateIndex for columns on a table. +func (s *Schema) CreateIndex(table string, name string, column []string, options ...IndexOption) { + s.add(createIndex(table, name, column, options)) +} + +// CreateUniqueIndex for columns on a table. +func (s *Schema) CreateUniqueIndex(table string, name string, column []string, options ...IndexOption) { + s.add(createUniqueIndex(table, name, column, options)) +} + +// DropIndex by name. +func (s *Schema) DropIndex(table string, name string, options ...IndexOption) { + s.add(dropIndex(table, name, options)) +} + +// Exec queries. +func (s *Schema) Exec(raw Raw) { + s.add(raw) +} + +// Do migration using golang codes. +func (s *Schema) Do(fn Do) { + s.add(fn) +} diff --git a/schema_options.go b/schema_options.go new file mode 100644 index 00000000..dcd26170 --- /dev/null +++ b/schema_options.go @@ -0,0 +1,146 @@ +package rel + +// TableOption interface. +// Available options are: Comment, Options. +type TableOption interface { + applyTable(table *Table) +} + +func applyTableOptions(table *Table, options []TableOption) { + for i := range options { + options[i].applyTable(table) + } +} + +// ColumnOption interface. +// Available options are: Nil, Unsigned, Limit, Precision, Scale, Default, Comment, Options. +type ColumnOption interface { + applyColumn(column *Column) +} + +func applyColumnOptions(column *Column, options []ColumnOption) { + for i := range options { + options[i].applyColumn(column) + } +} + +// KeyOption interface. +// Available options are: Comment, Options. +type KeyOption interface { + applyKey(key *Key) +} + +func applyKeyOptions(key *Key, options []KeyOption) { + for i := range options { + options[i].applyKey(key) + } +} + +// Unique set column as unique. +type Unique bool + +func (r Unique) applyColumn(column *Column) { + column.Unique = bool(r) +} + +func (r Unique) applyIndex(index *Index) { + index.Unique = bool(r) +} + +// Required disallows nil values in the column. +type Required bool + +func (r Required) applyColumn(column *Column) { + column.Required = bool(r) +} + +// Unsigned sets integer column to be unsigned. +type Unsigned bool + +func (u Unsigned) applyColumn(column *Column) { + column.Unsigned = bool(u) +} + +// Precision defines the precision for the decimal fields, representing the total number of digits in the number. +type Precision int + +func (p Precision) applyColumn(column *Column) { + column.Precision = int(p) +} + +// Scale Defines the scale for the decimal fields, representing the number of digits after the decimal point. +type Scale int + +func (s Scale) applyColumn(column *Column) { + column.Scale = int(s) +} + +type defaultValue struct { + value interface{} +} + +func (d defaultValue) applyColumn(column *Column) { + column.Default = d.value +} + +// Default allows to set a default value on the column.). +func Default(def interface{}) ColumnOption { + return defaultValue{value: def} +} + +// OnDelete option for foreign key. +type OnDelete string + +func (od OnDelete) applyKey(key *Key) { + key.Reference.OnDelete = string(od) +} + +// OnUpdate option for foreign key. +type OnUpdate string + +func (ou OnUpdate) applyKey(key *Key) { + key.Reference.OnUpdate = string(ou) +} + +// Options options for table, column and index. +type Options string + +func (o Options) applyTable(table *Table) { + table.Options = string(o) +} + +func (o Options) applyColumn(column *Column) { + column.Options = string(o) +} + +func (o Options) applyIndex(index *Index) { + index.Options = string(o) +} + +func (o Options) applyKey(key *Key) { + key.Options = string(o) +} + +// Optional option. +// when used with create table, will create table only if it's not exists. +// when used with drop table, will drop table only if it's exists. +type Optional bool + +func (o Optional) applyTable(table *Table) { + table.Optional = bool(o) +} + +func (o Optional) applyIndex(index *Index) { + index.Optional = bool(o) +} + +// Raw string +type Raw string + +func (r Raw) internalMigration() {} +func (r Raw) internalTableDefinition() {} + +// Do used internally for schema migration. +type Do func(Repository) error + +func (d Do) internalMigration() {} diff --git a/schema_test.go b/schema_test.go new file mode 100644 index 00000000..d77c450b --- /dev/null +++ b/schema_test.go @@ -0,0 +1,199 @@ +package rel + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSchema_CreateTable(t *testing.T) { + var schema Schema + + schema.CreateTable("products", func(t *Table) { + t.ID("id") + t.String("name") + t.Text("description") + }) + + assert.Equal(t, Table{ + Op: SchemaCreate, + Name: "products", + Definitions: []TableDefinition{ + Column{Name: "id", Type: ID}, + Column{Name: "name", Type: String}, + Column{Name: "description", Type: Text}, + }, + }, schema.Migrations[0]) + + schema.CreateTableIfNotExists("products", func(t *Table) { + t.ID("id") + }) + + assert.Equal(t, Table{ + Op: SchemaCreate, + Name: "products", + Optional: true, + Definitions: []TableDefinition{ + Column{Name: "id", Type: ID}, + }, + }, schema.Migrations[1]) +} + +func TestSchema_AlterTable(t *testing.T) { + var schema Schema + + schema.AlterTable("users", func(t *AlterTable) { + t.Bool("verified") + t.RenameColumn("name", "fullname") + }) + + assert.Equal(t, Table{ + Op: SchemaAlter, + Name: "users", + Definitions: []TableDefinition{ + Column{Name: "verified", Type: Bool, Op: SchemaCreate}, + Column{Name: "name", Rename: "fullname", Op: SchemaRename}, + }, + }, schema.Migrations[0]) +} + +func TestSchema_RenameTable(t *testing.T) { + var schema Schema + + schema.RenameTable("trxs", "transactions") + + assert.Equal(t, Table{ + Op: SchemaRename, + Name: "trxs", + Rename: "transactions", + }, schema.Migrations[0]) +} + +func TestSchema_DropTable(t *testing.T) { + var schema Schema + + schema.DropTable("logs") + + assert.Equal(t, Table{ + Op: SchemaDrop, + Name: "logs", + }, schema.Migrations[0]) + + schema.DropTableIfExists("logs") + + assert.Equal(t, Table{ + Op: SchemaDrop, + Name: "logs", + Optional: true, + }, schema.Migrations[1]) +} + +func TestSchema_AddColumn(t *testing.T) { + var schema Schema + + schema.AddColumn("products", "description", String) + + assert.Equal(t, Table{ + Op: SchemaAlter, + Name: "products", + Definitions: []TableDefinition{ + Column{Name: "description", Type: String, Op: SchemaCreate}, + }, + }, schema.Migrations[0]) +} + +func TestSchema_RenameColumn(t *testing.T) { + var schema Schema + + schema.RenameColumn("users", "name", "fullname") + + assert.Equal(t, Table{ + Op: SchemaAlter, + Name: "users", + Definitions: []TableDefinition{ + Column{Name: "name", Rename: "fullname", Op: SchemaRename}, + }, + }, schema.Migrations[0]) +} + +func TestSchema_DropColumn(t *testing.T) { + var schema Schema + + schema.DropColumn("users", "verified") + + assert.Equal(t, Table{ + Op: SchemaAlter, + Name: "users", + Definitions: []TableDefinition{ + Column{Name: "verified", Op: SchemaDrop}, + }, + }, schema.Migrations[0]) +} + +func TestSchema_CreateIndex(t *testing.T) { + var schema Schema + + schema.CreateIndex("products", "sale_idx", []string{"sale"}) + + assert.Equal(t, Index{ + Table: "products", + Name: "sale_idx", + Columns: []string{"sale"}, + Op: SchemaCreate, + }, schema.Migrations[0]) +} + +func TestSchema_CreateIndex_unique(t *testing.T) { + var schema Schema + + schema.CreateIndex("products", "sale_idx", []string{"sale"}, Unique(true)) + + assert.Equal(t, Index{ + Table: "products", + Name: "sale_idx", + Unique: true, + Columns: []string{"sale"}, + Op: SchemaCreate, + }, schema.Migrations[0]) +} + +func TestSchema_CreateUniqueIndex(t *testing.T) { + var schema Schema + + schema.CreateUniqueIndex("products", "sale_idx", []string{"sale"}) + assert.Equal(t, Index{ + Table: "products", + Name: "sale_idx", + Unique: true, + Columns: []string{"sale"}, + Op: SchemaCreate, + }, schema.Migrations[0]) +} + +func TestSchema_DropIndex(t *testing.T) { + var schema Schema + + schema.DropIndex("products", "sale") + + assert.Equal(t, Index{ + Table: "products", + Name: "sale", + Op: SchemaDrop, + }, schema.Migrations[0]) +} + +func TestSchema_Exec(t *testing.T) { + var schema Schema + + schema.Exec("RAW SQL") + assert.Equal(t, Raw("RAW SQL"), schema.Migrations[0]) +} + +func TestSchema_Do(t *testing.T) { + var ( + schema Schema + ) + + schema.Do(func(repo Repository) error { return nil }) + assert.NotNil(t, schema.Migrations[0]) +} diff --git a/table.go b/table.go new file mode 100644 index 00000000..54c5c10a --- /dev/null +++ b/table.go @@ -0,0 +1,177 @@ +package rel + +// TableDefinition interface. +type TableDefinition interface { + internalTableDefinition() +} + +// Table definition. +type Table struct { + Op SchemaOp + Name string + Rename string + Definitions []TableDefinition + Optional bool + Options string +} + +// Column defines a column with name and type. +func (t *Table) Column(name string, typ ColumnType, options ...ColumnOption) { + t.Definitions = append(t.Definitions, createColumn(name, typ, options)) +} + +// ID defines a column with name and ID type. +// the resulting database type will depends on database. +func (t *Table) ID(name string, options ...ColumnOption) { + t.Column(name, ID, options...) +} + +// Bool defines a column with name and Bool type. +func (t *Table) Bool(name string, options ...ColumnOption) { + t.Column(name, Bool, options...) +} + +// Int defines a column with name and Int type. +func (t *Table) Int(name string, options ...ColumnOption) { + t.Column(name, Int, options...) +} + +// BigInt defines a column with name and BigInt type. +func (t *Table) BigInt(name string, options ...ColumnOption) { + t.Column(name, BigInt, options...) +} + +// Float defines a column with name and Float type. +func (t *Table) Float(name string, options ...ColumnOption) { + t.Column(name, Float, options...) +} + +// Decimal defines a column with name and Decimal type. +func (t *Table) Decimal(name string, options ...ColumnOption) { + t.Column(name, Decimal, options...) +} + +// String defines a column with name and String type. +func (t *Table) String(name string, options ...ColumnOption) { + t.Column(name, String, options...) +} + +// Text defines a column with name and Text type. +func (t *Table) Text(name string, options ...ColumnOption) { + t.Column(name, Text, options...) +} + +// Date defines a column with name and Date type. +func (t *Table) Date(name string, options ...ColumnOption) { + t.Column(name, Date, options...) +} + +// DateTime defines a column with name and DateTime type. +func (t *Table) DateTime(name string, options ...ColumnOption) { + t.Column(name, DateTime, options...) +} + +// Time defines a column with name and Time type. +func (t *Table) Time(name string, options ...ColumnOption) { + t.Column(name, Time, options...) +} + +// Timestamp defines a column with name and Timestamp type. +func (t *Table) Timestamp(name string, options ...ColumnOption) { + t.Column(name, Timestamp, options...) +} + +// PrimaryKey defines a primary key for table. +func (t *Table) PrimaryKey(column string, options ...KeyOption) { + t.PrimaryKeys([]string{column}, options...) +} + +// PrimaryKeys defines composite primary keys for table. +func (t *Table) PrimaryKeys(columns []string, options ...KeyOption) { + t.Definitions = append(t.Definitions, createPrimaryKeys(columns, options)) +} + +// ForeignKey defines foreign key index. +func (t *Table) ForeignKey(column string, refTable string, refColumn string, options ...KeyOption) { + t.Definitions = append(t.Definitions, createForeignKey(column, refTable, refColumn, options)) +} + +// Unique defines an unique key for columns. +func (t *Table) Unique(columns []string, options ...KeyOption) { + t.Definitions = append(t.Definitions, createKeys(columns, UniqueKey, options)) +} + +// Fragment defines anything using sql fragment. +func (t *Table) Fragment(fragment string) { + t.Definitions = append(t.Definitions, Raw(fragment)) +} + +func (t Table) internalMigration() {} + +// AlterTable Migrator. +type AlterTable struct { + Table +} + +// RenameColumn to a new name. +func (at *AlterTable) RenameColumn(name string, newName string, options ...ColumnOption) { + at.Definitions = append(at.Definitions, renameColumn(name, newName, options)) +} + +// DropColumn from this table. +func (at *AlterTable) DropColumn(name string, options ...ColumnOption) { + at.Definitions = append(at.Definitions, dropColumn(name, options)) +} + +func createTable(name string, options []TableOption) Table { + table := Table{ + Op: SchemaCreate, + Name: name, + } + + applyTableOptions(&table, options) + return table +} + +func createTableIfNotExists(name string, options []TableOption) Table { + table := createTable(name, options) + table.Optional = true + return table +} + +func alterTable(name string, options []TableOption) AlterTable { + table := Table{ + Op: SchemaAlter, + Name: name, + } + + applyTableOptions(&table, options) + return AlterTable{Table: table} +} + +func renameTable(name string, newName string, options []TableOption) Table { + table := Table{ + Op: SchemaRename, + Name: name, + Rename: newName, + } + + applyTableOptions(&table, options) + return table +} + +func dropTable(name string, options []TableOption) Table { + table := Table{ + Op: SchemaDrop, + Name: name, + } + + applyTableOptions(&table, options) + return table +} + +func dropTableIfExists(name string, options []TableOption) Table { + table := dropTable(name, options) + table.Optional = true + return table +} diff --git a/table_test.go b/table_test.go new file mode 100644 index 00000000..f051de50 --- /dev/null +++ b/table_test.go @@ -0,0 +1,177 @@ +package rel + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTable(t *testing.T) { + var table Table + + t.Run("Column", func(t *testing.T) { + table.Column("column", String) + assert.Equal(t, Column{ + Name: "column", + Type: String, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("Bool", func(t *testing.T) { + table.Bool("boolean") + assert.Equal(t, Column{ + Name: "boolean", + Type: Bool, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("Int", func(t *testing.T) { + table.Int("integer") + assert.Equal(t, Column{ + Name: "integer", + Type: Int, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("BigInt", func(t *testing.T) { + table.BigInt("bigint") + assert.Equal(t, Column{ + Name: "bigint", + Type: BigInt, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("Float", func(t *testing.T) { + table.Float("float") + assert.Equal(t, Column{ + Name: "float", + Type: Float, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("Decimal", func(t *testing.T) { + table.Decimal("decimal") + assert.Equal(t, Column{ + Name: "decimal", + Type: Decimal, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("String", func(t *testing.T) { + table.String("string") + assert.Equal(t, Column{ + Name: "string", + Type: String, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("Text", func(t *testing.T) { + table.Text("text") + assert.Equal(t, Column{ + Name: "text", + Type: Text, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("Date", func(t *testing.T) { + table.Date("date") + assert.Equal(t, Column{ + Name: "date", + Type: Date, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("DateTime", func(t *testing.T) { + table.DateTime("datetime") + assert.Equal(t, Column{ + Name: "datetime", + Type: DateTime, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("Time", func(t *testing.T) { + table.Time("time") + assert.Equal(t, Column{ + Name: "time", + Type: Time, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("Timestamp", func(t *testing.T) { + table.Timestamp("timestamp") + assert.Equal(t, Column{ + Name: "timestamp", + Type: Timestamp, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("PrimaryKey", func(t *testing.T) { + table.PrimaryKey("id") + assert.Equal(t, Key{ + Columns: []string{"id"}, + Type: PrimaryKey, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("ForeignKey", func(t *testing.T) { + table.ForeignKey("user_id", "users", "id") + assert.Equal(t, Key{ + Columns: []string{"user_id"}, + Type: ForeignKey, + Reference: ForeignKeyReference{ + Table: "users", + Columns: []string{"id"}, + }, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("Unique", func(t *testing.T) { + table.Unique([]string{"username"}) + assert.Equal(t, Key{ + Columns: []string{"username"}, + Type: UniqueKey, + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("Fragment", func(t *testing.T) { + table.Fragment("SQL") + assert.Equal(t, Raw("SQL"), table.Definitions[len(table.Definitions)-1]) + }) +} + +func TestAlterTable(t *testing.T) { + var table AlterTable + + t.Run("RenameColumn", func(t *testing.T) { + table.RenameColumn("column", "new_column") + assert.Equal(t, Column{ + Op: SchemaRename, + Name: "column", + Rename: "new_column", + }, table.Definitions[len(table.Definitions)-1]) + }) + + t.Run("DropColumn", func(t *testing.T) { + table.DropColumn("column") + assert.Equal(t, Column{ + Op: SchemaDrop, + Name: "column", + }, table.Definitions[len(table.Definitions)-1]) + }) +} + +func TestCreateTable(t *testing.T) { + var ( + options = []TableOption{ + Options("options"), + Optional(true), + } + table = createTable("table", options) + ) + + assert.Equal(t, Table{ + Name: "table", + Optional: true, + Options: "options", + }, table) +} diff --git a/util.go b/util.go index f10fb098..400e5e5e 100644 --- a/util.go +++ b/util.go @@ -37,7 +37,7 @@ func isZero(value interface{}) bool { case nil: zero = true case bool: - zero = v == false + zero = !v case string: zero = v == "" case int: