diff --git a/database/cassandra/cassandra.go b/database/cassandra/cassandra.go index 48b4a693b..87abd2e66 100644 --- a/database/cassandra/cassandra.go +++ b/database/cassandra/cassandra.go @@ -12,6 +12,7 @@ import ( "github.com/gocql/gocql" "github.com/golang-migrate/migrate/v4/database" + "github.com/hashicorp/go-multierror" ) func init() { @@ -240,13 +241,29 @@ func (c *Cassandra) Drop() error { return err } } - // Re-create the version table - return c.ensureVersionTable() + + return nil } -// Ensure version table exists -func (c *Cassandra) ensureVersionTable() error { - err := c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the Cassandra type. +func (c *Cassandra) ensureVersionTable() (err error) { + if err = c.Lock(); err != nil { + return err + } + + defer func() { + if e := c.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() + + err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() if err != nil { return err } diff --git a/database/clickhouse/clickhouse.go b/database/clickhouse/clickhouse.go index 6f98bd181..ebf5b17d6 100644 --- a/database/clickhouse/clickhouse.go +++ b/database/clickhouse/clickhouse.go @@ -11,6 +11,7 @@ import ( "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" + "github.com/hashicorp/go-multierror" ) var DefaultMigrationsTable = "schema_migrations" @@ -159,7 +160,25 @@ func (ch *ClickHouse) SetVersion(version int, dirty bool) error { return tx.Commit() } -func (ch *ClickHouse) ensureVersionTable() error { + +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the ClickHouse type. +func (ch *ClickHouse) ensureVersionTable() (err error) { + if err = ch.Lock(); err != nil { + return err + } + + defer func() { + if e := ch.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() + var ( table string query = "SHOW TABLES FROM " + ch.config.DatabaseName + " LIKE '" + ch.config.MigrationsTable + "'" @@ -207,7 +226,7 @@ func (ch *ClickHouse) Drop() error { return &database.Error{OrigErr: err, Query: []byte(query)} } } - return ch.ensureVersionTable() + return nil } func (ch *ClickHouse) Lock() error { return nil } diff --git a/database/cockroachdb/cockroachdb.go b/database/cockroachdb/cockroachdb.go index df32db0d8..41379384f 100644 --- a/database/cockroachdb/cockroachdb.go +++ b/database/cockroachdb/cockroachdb.go @@ -13,6 +13,7 @@ import ( import ( "github.com/cockroachdb/cockroach-go/crdb" + "github.com/hashicorp/go-multierror" "github.com/lib/pq" ) @@ -294,15 +295,29 @@ func (c *CockroachDb) Drop() error { return &database.Error{OrigErr: err, Query: []byte(query)} } } - if err := c.ensureVersionTable(); err != nil { - return err - } } return nil } -func (c *CockroachDb) ensureVersionTable() error { +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the CockroachDb type. +func (c *CockroachDb) ensureVersionTable() (err error) { + if err = c.Lock(); err != nil { + return err + } + + defer func() { + if e := c.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() + // check if migration table exists var count int query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 20c840e02..6d6a1907b 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -17,6 +17,7 @@ import ( import ( "github.com/go-sql-driver/mysql" + "github.com/hashicorp/go-multierror" ) import ( @@ -342,15 +343,29 @@ func (m *Mysql) Drop() error { return &database.Error{OrigErr: err, Query: []byte(query)} } } - if err := m.ensureVersionTable(); err != nil { - return err - } } return nil } -func (m *Mysql) ensureVersionTable() error { +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the Mysql type. +func (m *Mysql) ensureVersionTable() (err error) { + if err = m.Lock(); err != nil { + return err + } + + defer func() { + if e := m.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() + // check if migration table exists var result string query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"` diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 52af6f372..5ed84337b 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -325,14 +325,14 @@ func (p *Postgres) Drop() error { return &database.Error{OrigErr: err, Query: []byte(query)} } } - if err := p.ensureVersionTable(); err != nil { - return err - } } return nil } +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the Postgres type. func (p *Postgres) ensureVersionTable() (err error) { if err = p.Lock(); err != nil { return err diff --git a/database/ql/ql.go b/database/ql/ql.go index 86b2364dd..97a38bb25 100644 --- a/database/ql/ql.go +++ b/database/ql/ql.go @@ -3,6 +3,7 @@ package ql import ( "database/sql" "fmt" + "github.com/hashicorp/go-multierror" "io" "io/ioutil" "strings" @@ -59,7 +60,24 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { } return mx, nil } -func (m *Ql) ensureVersionTable() error { +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the Ql type. +func (m *Ql) ensureVersionTable() (err error) { + if err = m.Lock(); err != nil { + return err + } + + defer func() { + if e := m.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() + tx, err := m.db.Begin() if err != nil { return err @@ -132,9 +150,6 @@ func (m *Ql) Drop() error { return &database.Error{OrigErr: err, Query: []byte(query)} } } - if err := m.ensureVersionTable(); err != nil { - return err - } } return nil diff --git a/database/redshift/redshift.go b/database/redshift/redshift.go index 19f1b9f78..27bd8347f 100644 --- a/database/redshift/redshift.go +++ b/database/redshift/redshift.go @@ -14,6 +14,7 @@ import ( "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" + "github.com/hashicorp/go-multierror" "github.com/lib/pq" ) @@ -282,15 +283,29 @@ func (p *Redshift) Drop() error { return &database.Error{OrigErr: err, Query: []byte(query)} } } - if err := p.ensureVersionTable(); err != nil { - return err - } } return nil } -func (p *Redshift) ensureVersionTable() error { +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the Redshift type. +func (p *Redshift) ensureVersionTable() (err error) { + if err = p.Lock(); err != nil { + return err + } + + defer func() { + if e := p.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() + // check if migration table exists var count int query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` diff --git a/database/spanner/spanner.go b/database/spanner/spanner.go index 84ff25224..f5983433b 100644 --- a/database/spanner/spanner.go +++ b/database/spanner/spanner.go @@ -17,6 +17,7 @@ import ( "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" + "github.com/hashicorp/go-multierror" "google.golang.org/api/iterator" adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" ) @@ -255,14 +256,27 @@ func (s *Spanner) Drop() error { return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))} } - if err := s.ensureVersionTable(); err != nil { + return nil +} + +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the Spanner type. +func (s *Spanner) ensureVersionTable() (err error) { + if err = s.Lock(); err != nil { return err } - return nil -} + defer func() { + if e := s.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() -func (s *Spanner) ensureVersionTable() error { ctx := context.Background() tbl := s.config.MigrationsTable iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"}) diff --git a/database/sqlite3/sqlite3.go b/database/sqlite3/sqlite3.go index d65fe8070..3f33e7e17 100644 --- a/database/sqlite3/sqlite3.go +++ b/database/sqlite3/sqlite3.go @@ -10,6 +10,7 @@ import ( "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" + "github.com/hashicorp/go-multierror" _ "github.com/mattn/go-sqlite3" ) @@ -58,7 +59,23 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return mx, nil } -func (m *Sqlite) ensureVersionTable() error { +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the Sqlite type. +func (m *Sqlite) ensureVersionTable() (err error) { + if err = m.Lock(); err != nil { + return err + } + + defer func() { + if e := m.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool); @@ -125,9 +142,6 @@ func (m *Sqlite) Drop() error { return &database.Error{OrigErr: err, Query: []byte(query)} } } - if err := m.ensureVersionTable(); err != nil { - return err - } query := "VACUUM" _, err = m.db.Query(query) if err != nil {