Skip to content

Commit

Permalink
Add Initialize() to Driver interface, and add integration tests for D…
Browse files Browse the repository at this point in the history
…rop() between database implementations and migrate
  • Loading branch information
Lukas Jørgensen (LUJOR) committed Feb 19, 2019
1 parent 45d3ba3 commit 3fd5314
Show file tree
Hide file tree
Showing 23 changed files with 270 additions and 62 deletions.
14 changes: 9 additions & 5 deletions database/cassandra/cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,26 @@ func WithInstance(session *gocql.Session, config *Config) (database.Driver, erro
return nil, ErrClosedSession
}

if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}

c := &Cassandra{
session: session,
config: config,
}

if err := c.ensureVersionTable(); err != nil {
if err := c.Initialize(); err != nil {
return nil, err
}

return c, nil
}

func (c *Cassandra) Initialize() error {
if len(c.config.MigrationsTable) == 0 {
c.config.MigrationsTable = DefaultMigrationsTable
}

return c.ensureVersionTable()
}

func (c *Cassandra) Open(url string) (database.Driver, error) {
u, err := nurl.Parse(url)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions database/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) {
config: config,
}

if err := ch.init(); err != nil {
if err := ch.Initialize(); err != nil {
return nil, err
}

Expand Down Expand Up @@ -75,14 +75,14 @@ func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
},
}

if err := ch.init(); err != nil {
if err := ch.Initialize(); err != nil {
return nil, err
}

return ch, nil
}

func (ch *ClickHouse) init() error {
func (ch *ClickHouse) Initialize() error {
if len(ch.config.DatabaseName) == 0 {
if err := ch.conn.QueryRow("SELECT currentDatabase()").Scan(&ch.config.DatabaseName); err != nil {
return err
Expand Down
32 changes: 20 additions & 12 deletions database/cockroachdb/cockroachdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,29 +73,37 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {

config.DatabaseName = databaseName

if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
px := &CockroachDb{
db: instance,
config: config,
}

if len(config.LockTable) == 0 {
config.LockTable = DefaultLockTable
if err := px.Initialize(); err != nil {
return nil, err
}

px := &CockroachDb{
db: instance,
config: config,
return px, nil
}

func (c *CockroachDb) Initialize() error {
if len(c.config.MigrationsTable) == 0 {
c.config.MigrationsTable = DefaultMigrationsTable
}

if len(c.config.LockTable) == 0 {
c.config.LockTable = DefaultLockTable
}

// ensureVersionTable is a locking operation, so we need to ensureLockTable before we ensureVersionTable.
if err := px.ensureLockTable(); err != nil {
return nil, err
if err := c.ensureLockTable(); err != nil {
return err
}

if err := px.ensureVersionTable(); err != nil {
return nil, err
if err := c.ensureVersionTable(); err != nil {
return err
}

return px, nil
return nil
}

func (c *CockroachDb) Open(url string) (database.Driver, error) {
Expand Down
16 changes: 16 additions & 0 deletions database/cockroachdb/cockroachdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,22 @@ func Test(t *testing.T) {
}
dt.Test(t, d, []byte("SELECT 1"))
})
dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) {
createDB(t, ci)

ip, port, err := ci.Port(26257)
if err != nil {
t.Fatal(err)
}

addr := fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable", ip, port)
c := &CockroachDb{}
d, err := c.Open(addr)
if err != nil {
t.Fatalf("%v", err)
}
dt.TestMigrate(t, d, []byte("SELECT 1"))
})
}

func TestMultiStatement(t *testing.T) {
Expand Down
7 changes: 7 additions & 0 deletions database/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,15 @@ type Driver interface {
// Open returns a new driver instance configured with parameters
// coming from the URL string. Migrate will call this function
// only once per instance.
// This will also call Initialize().
Open(url string) (Driver, error)

// Initialize makes sure the database is ready for migrations, this
// might include creating some tables for migration/lock management
// or initializing some files.
// This assumes there is an open connection to a database.
Initialize() error

// Close closes the underlying database instance managed by the driver.
// Migrate will call this function only once per instance.
Close() error
Expand Down
18 changes: 12 additions & 6 deletions database/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,26 @@ func WithInstance(instance *mongo.Client, config *Config) (database.Driver, erro
if len(config.DatabaseName) == 0 {
return nil, ErrNoDatabaseName
}
if len(config.MigrationsCollection) == 0 {
config.MigrationsCollection = DefaultMigrationsCollection
}
mc := &Mongo{
client: instance,
db: instance.Database(config.DatabaseName),
config: config,
}

if err := mc.Initialize(); err != nil {
return nil, err
}

return mc, nil
}

func (m *Mongo) Initialize() error {
if len(m.config.MigrationsCollection) == 0 {
m.config.MigrationsCollection = DefaultMigrationsCollection
}
return nil
}

func (m *Mongo) Open(dsn string) (database.Driver, error) {
//connsting is experimental package, but it used for parse connection string in mongo.Connect function
uri, err := connstring.Parse(dsn)
Expand All @@ -77,9 +86,6 @@ func (m *Mongo) Open(dsn string) (database.Driver, error) {
return nil, err
}
migrationsCollection := purl.Query().Get("x-migrations-collection")
if len(migrationsCollection) == 0 {
migrationsCollection = DefaultMigrationsCollection
}

transactionMode, _ := strconv.ParseBool(purl.Query().Get("x-transaction-mode"))

Expand Down
29 changes: 29 additions & 0 deletions database/mongodb/mongodb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,35 @@ func Test(t *testing.T) {
dt.TestSetVersion(t, d)
dt.TestDrop(t, d)
})
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}

addr := mongoConnectionString(ip, port)
p := &Mongo{}
d, err := p.Open(addr)
if err != nil {
t.Fatalf("%v", err)
}
defer d.Close()
dt.TestNilVersion(t, d)
//TestLockAndUnlock(t, d) driver doesn't support lock on database level
dt.TestRun(t, d, bytes.NewReader([]byte(`[{"insert":"hello","documents":[{"wild":"world"}]}]`)))
dt.TestSetVersion(t, d)
dt.TestDrop(t, d)
// Reinitialize for new round of tests
err = d.Drop()
if err != nil {
t.Fatalf("%v", err)
}
err = d.Initialize()
if err != nil {
t.Fatalf("%v", err)
}
dt.TestMigrate(t, d, []byte(`[{"insert":"hello","documents":[{"wild":"world"}]}]`))
})
}

func TestWithAuth(t *testing.T) {
Expand Down
17 changes: 9 additions & 8 deletions database/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,6 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {

config.DatabaseName = databaseName.String

if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}

conn, err := instance.Conn(context.Background())
if err != nil {
return nil, err
Expand All @@ -91,13 +87,21 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
config: config,
}

if err := mx.ensureVersionTable(); err != nil {
if err := mx.Initialize(); err != nil {
return nil, err
}

return mx, nil
}

func (m *Mysql) Initialize() error {
if len(m.config.MigrationsTable) == 0 {
m.config.MigrationsTable = DefaultMigrationsTable
}

return m.ensureVersionTable()
}

// urlToMySQLConfig takes a net/url URL and returns a go-sql-driver/mysql Config.
// Manually sets username and password to avoid net/url from url-encoding the reserved URL characters
func urlToMySQLConfig(u nurl.URL) (*mysql.Config, error) {
Expand Down Expand Up @@ -128,9 +132,6 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
purl.RawQuery = q.Encode()

migrationsTable := purl.Query().Get("x-migrations-table")
if len(migrationsTable) == 0 {
migrationsTable = DefaultMigrationsTable
}

// use custom TLS?
ctls := purl.Query().Get("tls")
Expand Down
10 changes: 10 additions & 0 deletions database/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ func Test(t *testing.T) {
}
defer d.Close()
dt.Test(t, d, []byte("SELECT 1"))
// Reinitialize for new round of tests
err = d.Drop()
if err != nil {
t.Fatalf("%v", err)
}
err = d.Initialize()
if err != nil {
t.Fatalf("%v", err)
}
dt.TestMigrate(t, d, []byte("SELECT 1"))

// check ensureVersionTable
if err := d.(*Mysql).ensureVersionTable(); err != nil {
Expand Down
14 changes: 10 additions & 4 deletions database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
config: config,
}

if err := px.ensureVersionTable(); err != nil {
err = px.Initialize()
if err != nil {
return nil, err
}

Expand All @@ -117,9 +118,6 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
}

migrationsTable := purl.Query().Get("x-migrations-table")
if len(migrationsTable) == 0 {
migrationsTable = DefaultMigrationsTable
}

px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
Expand All @@ -132,6 +130,14 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
return px, nil
}

func (p *Postgres) Initialize() error {
if len(p.config.MigrationsTable) == 0 {
p.config.MigrationsTable = DefaultMigrationsTable
}

return p.ensureVersionTable()
}

func (p *Postgres) Close() error {
connErr := p.conn.Close()
dbErr := p.db.Close()
Expand Down
16 changes: 16 additions & 0 deletions database/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,22 @@ func Test(t *testing.T) {
defer d.Close()
dt.Test(t, d, []byte("SELECT 1"))
})

dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}

addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatalf("%v", err)
}
defer d.Close()
dt.TestMigrate(t, d, []byte("SELECT 1"))
})
}

func TestMultiStatement(t *testing.T) {
Expand Down
11 changes: 7 additions & 4 deletions database/ql/ql.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,22 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if err := instance.Ping(); err != nil {
return nil, err
}
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}

mx := &Ql{
db: instance,
config: config,
}
if err := mx.ensureVersionTable(); err != nil {
if err := mx.Initialize(); err != nil {
return nil, err
}
return mx, nil
}
func (m *Ql) Initialize() error {
if len(m.config.MigrationsTable) == 0 {
m.config.MigrationsTable = DefaultMigrationsTable
}
return m.ensureVersionTable()
}
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Ql type.
Expand Down
10 changes: 10 additions & 0 deletions database/ql/ql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ func Test(t *testing.T) {
}
}()
dt.Test(t, d, []byte("CREATE TABLE t (Qty int, Name string);"))
// Reinitialize for new round of tests
err = d.Drop()
if err != nil {
t.Fatalf("%v", err)
}
err = d.Initialize()
if err != nil {
t.Fatalf("%v", err)
}
dt.TestMigrate(t, d, []byte("CREATE TABLE t (Qty int, Name string);"))
driver, err := WithInstance(db, &Config{})
if err != nil {
t.Fatalf("%v", err)
Expand Down

0 comments on commit 3fd5314

Please sign in to comment.