Skip to content

Commit

Permalink
fix: migrator run with nil schema (#182)
Browse files Browse the repository at this point in the history
  • Loading branch information
black-06 committed May 19, 2023
1 parent b320a5c commit d616c6a
Showing 1 changed file with 126 additions and 114 deletions.
240 changes: 126 additions & 114 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem
func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
}
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
return m.DB.Raw(
Expand All @@ -90,33 +92,35 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {

func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}

createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX "
createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX "

if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" {
createIndexSQL += "CONCURRENTLY "
}
if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" {
createIndexSQL += "CONCURRENTLY "
}

createIndexSQL += "IF NOT EXISTS ? ON ?"
createIndexSQL += "IF NOT EXISTS ? ON ?"

if idx.Type != "" {
createIndexSQL += " USING " + idx.Type + "(?)"
} else {
createIndexSQL += " ?"
}
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type + "(?)"
} else {
createIndexSQL += " ?"
}

if idx.Where != "" {
createIndexSQL += " WHERE " + idx.Where
}
if idx.Where != "" {
createIndexSQL += " WHERE " + idx.Where
}

return m.DB.Exec(createIndexSQL, values...).Error
return m.DB.Exec(createIndexSQL, values...).Error
}
}

return fmt.Errorf("failed to create index with name %v", name)
Expand All @@ -134,8 +138,10 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error

func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
}

return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
Expand All @@ -153,13 +159,15 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) {
}
for _, value := range m.ReorderModels(values, false) {
if err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
for _, field := range stmt.Schema.FieldsByDBName {
if field.Comment != "" {
if err := m.DB.Exec(
"COMMENT ON COLUMN ?.? IS ?",
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
).Error; err != nil {
return err
if stmt.Schema != nil {
for _, field := range stmt.Schema.FieldsByDBName {
if field.Comment != "" {
if err := m.DB.Exec(
"COMMENT ON COLUMN ?.? IS ?",
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
).Error; err != nil {
return err
}
}
}
}
Expand Down Expand Up @@ -200,13 +208,15 @@ func (m Migrator) AddColumn(value interface{}, field string) error {
m.resetPreparedStmts()

return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
if field.Comment != "" {
if err := m.DB.Exec(
"COMMENT ON COLUMN ?.? IS ?",
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
).Error; err != nil {
return err
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(field); field != nil {
if field.Comment != "" {
if err := m.DB.Exec(
"COMMENT ON COLUMN ?.? IS ?",
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
).Error; err != nil {
return err
}
}
}
}
Expand Down Expand Up @@ -269,101 +279,103 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
// AlterColumn alter value's `field` column' type based on schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error {
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
var (
columnTypes, _ = m.DB.Migrator().ColumnTypes(value)
fieldColumnType *migrator.ColumnType
)
for _, columnType := range columnTypes {
if columnType.Name() == field.DBName {
fieldColumnType, _ = columnType.(*migrator.ColumnType)
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(field); field != nil {
var (
columnTypes, _ = m.DB.Migrator().ColumnTypes(value)
fieldColumnType *migrator.ColumnType
)
for _, columnType := range columnTypes {
if columnType.Name() == field.DBName {
fieldColumnType, _ = columnType.(*migrator.ColumnType)
}
}
}

fileType := clause.Expr{SQL: m.DataTypeOf(field)}
// check for typeName and SQL name
isSameType := true
if fieldColumnType.DatabaseTypeName() != fileType.SQL {
isSameType = false
// if different, also check for aliases
aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName())
for _, alias := range aliases {
if strings.HasPrefix(fileType.SQL, alias) {
isSameType = true
break
fileType := clause.Expr{SQL: m.DataTypeOf(field)}
// check for typeName and SQL name
isSameType := true
if fieldColumnType.DatabaseTypeName() != fileType.SQL {
isSameType = false
// if different, also check for aliases
aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName())
for _, alias := range aliases {
if strings.HasPrefix(fileType.SQL, alias) {
isSameType = true
break
}
}
}
}

// not same, migrate
if !isSameType {
filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement()
if field.AutoIncrement && filedColumnAutoIncrement { // update
serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType {
if err := m.UpdateSequence(m.DB, stmt, field, serialDatabaseType); err != nil {
// not same, migrate
if !isSameType {
filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement()
if field.AutoIncrement && filedColumnAutoIncrement { // update
serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType {
if err := m.UpdateSequence(m.DB, stmt, field, serialDatabaseType); err != nil {
return err
}
}
} else if field.AutoIncrement && !filedColumnAutoIncrement { // create
serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
if err := m.CreateSequence(m.DB, stmt, field, serialDatabaseType); err != nil {
return err
}
} else if !field.AutoIncrement && filedColumnAutoIncrement { // delete
if err := m.DeleteSequence(m.DB, stmt, field, fileType); err != nil {
return err
}
} else {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?"+m.genUsingExpression(fileType.SQL, fieldColumnType.DatabaseTypeName()),
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, clause.Column{Name: field.DBName}, fileType).Error; err != nil {
return err
}
}
} else if field.AutoIncrement && !filedColumnAutoIncrement { // create
serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
if err := m.CreateSequence(m.DB, stmt, field, serialDatabaseType); err != nil {
return err
}
} else if !field.AutoIncrement && filedColumnAutoIncrement { // delete
if err := m.DeleteSequence(m.DB, stmt, field, fileType); err != nil {
return err
}
} else {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?"+m.genUsingExpression(fileType.SQL, fieldColumnType.DatabaseTypeName()),
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, clause.Column{Name: field.DBName}, fileType).Error; err != nil {
return err
}
}
}

if null, _ := fieldColumnType.Nullable(); null == field.NotNull {
if field.NotNull {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
return err
}
} else {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
return err
if null, _ := fieldColumnType.Nullable(); null == field.NotNull {
if field.NotNull {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
return err
}
} else {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
return err
}
}
}
}

if uniq, _ := fieldColumnType.Unique(); !uniq && field.Unique {
idxName := clause.Column{Name: m.DB.Config.NamingStrategy.IndexName(stmt.Table, field.DBName)}
// Not a unique constraint but a unique index
if !m.HasIndex(stmt.Table, idxName.Name) {
if err := m.DB.Exec("ALTER TABLE ? ADD CONSTRAINT ? UNIQUE(?)", m.CurrentTable(stmt), idxName, clause.Column{Name: field.DBName}).Error; err != nil {
return err
if uniq, _ := fieldColumnType.Unique(); !uniq && field.Unique {
idxName := clause.Column{Name: m.DB.Config.NamingStrategy.IndexName(stmt.Table, field.DBName)}
// Not a unique constraint but a unique index
if !m.HasIndex(stmt.Table, idxName.Name) {
if err := m.DB.Exec("ALTER TABLE ? ADD CONSTRAINT ? UNIQUE(?)", m.CurrentTable(stmt), idxName, clause.Column{Name: field.DBName}).Error; err != nil {
return err
}
}
}
}

if v, ok := fieldColumnType.DefaultValue(); (field.DefaultValueInterface == nil && ok) || v != field.DefaultValue {
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
if field.DefaultValueInterface != nil {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil {
return err
}
} else if field.DefaultValue != "(-)" {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
return err
}
} else {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
return err
if v, ok := fieldColumnType.DefaultValue(); (field.DefaultValueInterface == nil && ok) || v != field.DefaultValue {
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
if field.DefaultValueInterface != nil {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil {
return err
}
} else if field.DefaultValue != "(-)" {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
return err
}
} else {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
return err
}
}
}
}
return nil
}
return nil
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
Expand Down

0 comments on commit d616c6a

Please sign in to comment.