Skip to content

Commit

Permalink
Improve support for AutoMigrate
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 19, 2022
1 parent 262ad9b commit b5865f9
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 7 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ require (
github.com/jackc/pgx/v4 v4.14.1
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 // indirect
golang.org/x/text v0.3.7 // indirect
gorm.io/gorm v1.22.3
gorm.io/gorm v1.23.1
)
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dv
github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI=
github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
Expand Down Expand Up @@ -186,6 +184,6 @@ gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:a
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/gorm v1.22.3 h1:/JS6z+GStEQvJNW3t1FTwJwG/gZ+A7crFdRqtvG5ehA=
gorm.io/gorm v1.22.3/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0=
gorm.io/gorm v1.23.1 h1:aj5IlhDzEPsoIyOPtTRVI+SyaN1u6k613sbt4pwbxG0=
gorm.io/gorm v1.23.1/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
79 changes: 77 additions & 2 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package postgres
import (
"database/sql"
"fmt"
"regexp"
"strings"

"gorm.io/gorm"
Expand Down Expand Up @@ -231,6 +232,73 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
})
}

// AlterColumn alter value's `field` column' type based on schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
var (
columnTypes, _ = m.DB.Migrator().ColumnTypes(value)
fieldColumnType migrator.ColumnType
)
for _, columnType := range columnTypes {
if columnType.Name() == field.DBName {
fieldColumnType, _ = columnType.(migrator.ColumnType)
}
}

return m.DB.Connection(func(tx *gorm.DB) error {
fileType := clause.Expr{SQL: m.DataTypeOf(field)}
if fieldColumnType.DatabaseTypeName() != fileType.SQL {
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil {
return err
}
}

if null, _ := fieldColumnType.Nullable(); null == field.NotNull {
if field.NotNull {
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
return err
}
} else {
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
return err
}
}
}

if uniq, _ := fieldColumnType.Unique(); uniq != field.Unique {
idxName := clause.Column{Name: m.DB.Config.NamingStrategy.IndexName(stmt.Table, field.DBName)}
if err := tx.Exec("ALTER TABLE ? ADD CONSTRAINT ? UNIQUE(?)", m.CurrentTable(stmt), idxName, clause.Column{Name: field.DBName}).Error; err != nil {
return err
}
}

if v, _ := fieldColumnType.DefaultValue(); v != field.DefaultValue {
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
if field.DefaultValueInterface != nil {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil {
return err
}
} else if field.DefaultValue != "(-)" {
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
return err
}
} else {
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
return err
}
}
}
}
return nil
})
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}

func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
Expand Down Expand Up @@ -276,7 +344,10 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,

for columns.Next() {
var (
column migrator.ColumnType
column = migrator.ColumnType{
PrimaryKeyValue: sql.NullBool{Valid: true},
UniqueValue: sql.NullBool{Valid: true},
}
datetimePrecision sql.NullInt64
radixValue sql.NullInt64
typeLenValue sql.NullInt64
Expand All @@ -299,6 +370,10 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
column.DefaultValueValue = sql.NullString{}
}

if column.DefaultValueValue.Valid {
column.DefaultValueValue.String = regexp.MustCompile("'(.*)'::[\\w]+$").ReplaceAllString(column.DefaultValueValue.String, "$1")
}

if datetimePrecision.Valid {
column.DecimalSizeValue = datetimePrecision
}
Expand Down Expand Up @@ -326,7 +401,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
if mc.NameValue.String == name {
switch columnType {
case "PRIMARY KEY":
mc.PrimayKeyValue = sql.NullBool{Bool: true, Valid: true}
mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
case "UNIQUE":
mc.UniqueValue = sql.NullBool{Bool: true, Valid: true}
}
Expand Down

0 comments on commit b5865f9

Please sign in to comment.