Skip to content

Commit

Permalink
Merge pull request #2846 from infrahq/dnephin/data-reaplce-gorm-migrator
Browse files Browse the repository at this point in the history
Replace gorm.Migrator calls with sql
  • Loading branch information
dnephin committed Aug 9, 2022
2 parents 8b8e96d + f1b3126 commit b7a5bd1
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 86 deletions.
54 changes: 33 additions & 21 deletions internal/server/data/migrations.go
Expand Up @@ -22,20 +22,22 @@ func migrations() []*migrator.Migration {
{
ID: "202204281130",
Migrate: func(tx *gorm.DB) error {
return tx.Migrator().DropColumn(&models.Settings{}, "signup_enabled")
stmt := `ALTER TABLE settings DROP COLUMN IF EXISTS signup_enabled`
if tx.Dialector.Name() == "sqlite" {
stmt = `ALTER TABLE settings DROP COLUMN signup_enabled`
}
return tx.Exec(stmt).Error
},
},
// #1657: get rid of identity kind
{
ID: "202204291613",
Migrate: func(tx *gorm.DB) error {
if tx.Migrator().HasColumn(&models.Identity{}, "kind") {
if err := tx.Migrator().DropColumn(&models.Identity{}, "kind"); err != nil {
return err
}
stmt := `ALTER TABLE identities DROP COLUMN IF EXISTS kind`
if tx.Dialector.Name() == "sqlite" {
stmt = `ALTER TABLE identities DROP COLUMN kind`
}

return nil
return tx.Exec(stmt).Error
},
},
// drop old Groups index; new index will be created automatically
Expand Down Expand Up @@ -101,11 +103,12 @@ func addKindToProviders() *migrator.Migration {
return &migrator.Migration{
ID: "202206151027",
Migrate: func(tx *gorm.DB) error {
if !tx.Migrator().HasColumn(&models.Provider{}, "kind") {
logging.Debugf("migrating provider table kind")
if err := tx.Migrator().AddColumn(&models.Provider{}, "kind"); err != nil {
return err
}
stmt := `ALTER TABLE providers ADD COLUMN IF NOT EXISTS kind text`
if tx.Dialector.Name() == "sqlite" {
stmt = `ALTER TABLE providers ADD COLUMN kind text`
}
if err := tx.Exec(stmt).Error; err != nil {
return err
}

db := tx.Begin()
Expand All @@ -122,7 +125,13 @@ func dropCertificateTables() *migrator.Migration {
return &migrator.Migration{
ID: "202206161733",
Migrate: func(tx *gorm.DB) error {
return tx.Migrator().DropTable("trusted_certificates", "root_certificates")
if err := tx.Exec(`DROP TABLE IF EXISTS trusted_certificates`).Error; err != nil {
return err
}
if err := tx.Exec(`DROP TABLE IF EXISTS root_certificates`).Error; err != nil {
return err
}
return nil
},
}
}
Expand All @@ -132,12 +141,12 @@ func addAuthURLAndScopeToProviders() *migrator.Migration {
return &migrator.Migration{
ID: "202206281027",
Migrate: func(tx *gorm.DB) error {
if !tx.Migrator().HasColumn(&models.Provider{}, "scopes") {
if !migrator.HasColumn(tx, "providers", "scopes") {
logging.Debugf("migrating provider table auth URL and scopes")
if err := tx.Migrator().AddColumn(&models.Provider{}, "auth_url"); err != nil {
if err := tx.Exec(`ALTER TABLE providers ADD COLUMN auth_url text`).Error; err != nil {
return err
}
if err := tx.Migrator().AddColumn(&models.Provider{}, "scopes"); err != nil {
if err := tx.Exec(`ALTER TABLE providers ADD COLUMN scopes text`).Error; err != nil {
return err
}

Expand Down Expand Up @@ -202,14 +211,17 @@ func setDestinationLastSeenAt() *migrator.Migration {
return &migrator.Migration{
ID: "202207041724",
Migrate: func(tx *gorm.DB) error {
if tx.Migrator().HasColumn(&models.Destination{}, "last_seen_at") {
if migrator.HasColumn(tx, "destinations", "last_seen_at") {
return nil
}

if err := tx.Migrator().AddColumn(&models.Destination{}, "last_seen_at"); err != nil {
stmt := `ALTER TABLE destinations ADD COLUMN last_seen_at timestamp with time zone`
if tx.Dialector.Name() == "sqlite" {
stmt = `ALTER TABLE destinations ADD COLUMN last_seen_at datetime`
}
if err := tx.Exec(stmt).Error; err != nil {
return err
}

return tx.Exec("UPDATE destinations SET last_seen_at = updated_at").Error
},
}
Expand All @@ -220,11 +232,11 @@ func dropDeletedProviderUsers() *migrator.Migration {
return &migrator.Migration{
ID: "202207270000",
Migrate: func(tx *gorm.DB) error {
if tx.Migrator().HasColumn(&models.ProviderUser{}, "deleted_at") {
if migrator.HasColumn(tx, "provider_users", "deleted_at") {
if err := tx.Exec("DELETE FROM provider_users WHERE deleted_at IS NOT NULL").Error; err != nil {
return fmt.Errorf("could not remove soft deleted provider users: %w", err)
}
return tx.Migrator().DropColumn(&models.ProviderUser{}, "deleted_at")
return tx.Exec(`ALTER TABLE provider_users DROP COLUMN deleted_at`).Error
}
return nil
},
Expand Down
21 changes: 5 additions & 16 deletions internal/server/data/migrations_test.go
Expand Up @@ -86,8 +86,7 @@ func TestMigrations(t *testing.T) {
{
label: testCaseLine("202204281130"),
expected: func(t *testing.T, tx *gorm.DB) {
hasCol := tx.Migrator().HasColumn("settings", "signup_enabled")
assert.Assert(t, !hasCol)
// dropped columns are tested by schema comparison
},
},
{
Expand Down Expand Up @@ -143,12 +142,12 @@ func TestMigrations(t *testing.T) {
label: testCaseLine("202206161733"),
setup: func(t *testing.T, db *gorm.DB) {
// integrity check
assert.Assert(t, tableExists(t, db, "trusted_certificates"))
assert.Assert(t, tableExists(t, db, "root_certificates"))
assert.Assert(t, migrator.HasTable(db, "trusted_certificates"))
assert.Assert(t, migrator.HasTable(db, "root_certificates"))
},
expected: func(t *testing.T, db *gorm.DB) {
assert.Assert(t, !tableExists(t, db, "trusted_certificates"))
assert.Assert(t, !tableExists(t, db, "root_certificates"))
assert.Assert(t, !migrator.HasTable(db, "trusted_certificates"))
assert.Assert(t, !migrator.HasTable(db, "root_certificates"))
},
},
{
Expand Down Expand Up @@ -467,16 +466,6 @@ type testCaseLabel struct {
Line string
}

func tableExists(t *testing.T, db *gorm.DB, name string) bool {
t.Helper()
var count int
err := db.Raw("SELECT count(id) FROM " + name).Row().Scan(&count)
if err != nil {
t.Logf("table exists error: %v", err)
}
return err == nil
}

func dumpSchema(t *testing.T, conn string) string {
t.Helper()
if _, err := exec.LookPath("pg_dump"); err != nil {
Expand Down
57 changes: 57 additions & 0 deletions internal/server/data/migrator/helpers.go
@@ -0,0 +1,57 @@
package migrator

import "gorm.io/gorm"

type Tx interface {
Exec(stmt string, args ...any) *gorm.DB
}

// HasTable returns true if the database has a table with name. Returns
// false if the table does not exist, or if there was a failure querying the
// database.
func HasTable(tx *gorm.DB, name string) bool {
var count int
stmt := `
SELECT count(*)
FROM information_schema.tables
WHERE table_schema = CURRENT_SCHEMA()
AND table_name = ? AND table_type = 'BASE TABLE'
`
if tx.Dialector.Name() == "sqlite" {
stmt = `SELECT count(*) FROM sqlite_master WHERE type = 'table' AND name = ?`
}

if err := tx.Raw(stmt, name).Scan(&count).Error; err != nil {
return false
}
return count != 0
}

// HasColumn returns true if the database table has the column. Returns false if
// the database table does not have the column, or if there was a failure querying
// the database.
func HasColumn(tx *gorm.DB, table string, column string) bool {
var count int

stmt := `
SELECT count(*)
FROM information_schema.columns
WHERE table_schema = CURRENT_SCHEMA()
AND table_name = ? AND column_name = ?
`

if tx.Dialector.Name() == "sqlite" {
stmt = `
SELECT count(*)
FROM sqlite_master
WHERE type = 'table' AND name = ?
AND sql LIKE ?
`
column = "%`" + column + "`%"
}

if err := tx.Raw(stmt, table, column).Scan(&count).Error; err != nil {
return false
}
return count != 0
}
36 changes: 18 additions & 18 deletions internal/server/data/migrator/migrator.go
Expand Up @@ -9,10 +9,10 @@ import (

const initSchemaMigrationID = "SCHEMA_INIT"

// Options define options for all migrations.
// Options used by the Migrator to perform database migrations.
type Options struct {
// UseTransaction makes Migrator execute migrations inside a single transaction.
// Keep in mind that not all databases support DDL commands inside transactions.
// UseTransaction indicates that Migrator should execute all migrations
// inside a single transaction.
UseTransaction bool

// InitSchema is used to create the database when no migrations table exists.
Expand All @@ -26,32 +26,24 @@ type Options struct {
LoadKey func(*gorm.DB) error
}

// Migration represents a database migration (a modification to be made on the database).
// Migration defines a database migration, and an optional rollback.
type Migration struct {
// ID is the migration identifier. Usually a timestamp like "201601021504".
// ID is the migration identifier. Usually a timestamp like "2016-01-02T15:04".
ID string
// Migrate is a function that will br executed while running this migration.
Migrate func(*gorm.DB) error
// Rollback will be executed on rollback. Can be nil.
Rollback func(*gorm.DB) error
}

// Migrator represents a collection of all migrations of a database schema.
// Migrator performs database migrations.
type Migrator struct {
db *gorm.DB
tx *gorm.DB
options Options
migrations []*Migration
}

// DefaultOptions can be used if you don't want to think about options.
var DefaultOptions = Options{
UseTransaction: false,
InitSchema: func(db *gorm.DB) error {
return nil
},
}

// New returns a new Migrator.
func New(db *gorm.DB, options Options, migrations []*Migration) *Migrator {
if options.LoadKey == nil {
Expand All @@ -66,7 +58,15 @@ func New(db *gorm.DB, options Options, migrations []*Migration) *Migrator {
}
}

// Migrate executes all migrations that did not run yet.
// Migrate runs all the migrations that have not yet been applied to the
// database. Migrate may follow one of three flows:
//
// 1. If the initial schema has not yet been applied then Migrate will run
// Options.InitSchema, and then exit.
// 2. If all the migrations have already been applied then Migrate will do
// nothing.
// 3. If there are migrations in the list that have not yet been applied then
// Migrate will run them in order.
func (g *Migrator) Migrate() error {
if g.options.InitSchema == nil && len(g.migrations) == 0 {
return fmt.Errorf("there are no migrations")
Expand Down Expand Up @@ -209,7 +209,7 @@ func (g *Migrator) runMigration(migration *Migration) error {

func (g *Migrator) createMigrationTableIfNotExists() error {
// TODO: replace gorm helper
if g.tx.Migrator().HasTable("migrations") {
if HasTable(g.tx, "migrations") {
return nil
}

Expand All @@ -220,7 +220,7 @@ func (g *Migrator) createMigrationTableIfNotExists() error {
// individually
func (g *Migrator) migrationRan(m *Migration) (bool, error) {
var count int64
err := g.tx.Raw(`select count(id) from migrations where id = ?`, m.ID).Scan(&count).Error
err := g.tx.Raw(`SELECT count(id) FROM migrations WHERE id = ?`, m.ID).Scan(&count).Error
return count > 0, err
}

Expand All @@ -235,7 +235,7 @@ func (g *Migrator) mustInitializeSchema() (bool, error) {

// If the ID doesn't exist, we also want the list of migrations to be empty
var count int64
err = g.tx.Raw(`SELECT count(id) from migrations`).Scan(&count).Error
err = g.tx.Raw(`SELECT count(id) FROM migrations`).Scan(&count).Error
return count == 0, err
}

Expand Down

0 comments on commit b7a5bd1

Please sign in to comment.