Skip to content

Commit

Permalink
Update Migrator ColumnType interface
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 18, 2022
1 parent 6752fb0 commit 262ad9b
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 73 deletions.
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI=
github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
Expand Down
165 changes: 92 additions & 73 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,61 +15,6 @@ type Migrator struct {
migrator.Migrator
}

type Column struct {
name string
nullable sql.NullString
datatype string
maxlen sql.NullInt64
precision sql.NullInt64
radix sql.NullInt64
scale sql.NullInt64
datetimeprecision sql.NullInt64
typlen sql.NullInt64
}

func (c Column) Name() string {
return c.name
}

func (c Column) DatabaseTypeName() string {
return c.datatype
}

func (c Column) Length() (length int64, ok bool) {
ok = c.typlen.Valid
if ok && c.typlen.Int64 > 0 {
length = c.typlen.Int64
} else {
ok = c.maxlen.Valid
if ok {
length = c.maxlen.Int64
} else {
length = 0
}
}
return
}

func (c Column) Nullable() (nullable bool, ok bool) {
if c.nullable.Valid {
nullable, ok = c.nullable.String == "YES", true
} else {
nullable, ok = false, false
}
return
}

func (c Column) DecimalSize() (precision int64, scale int64, ok bool) {
if ok = c.precision.Valid && c.scale.Valid && c.radix.Valid && c.radix.Int64 == 10; ok {
precision, scale = c.precision.Int64, c.scale.Int64
} else if ok = c.datetimeprecision.Valid; ok {
precision, scale = c.datetimeprecision.Int64, 0
} else {
precision, scale, ok = 0, 0, false
}
return
}

func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name)
return
Expand Down Expand Up @@ -309,38 +254,112 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) {
columnTypes = make([]gorm.ColumnType, 0)
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
currentSchema, table := m.CurrentSchema(stmt, stmt.Table)
columns, err := m.DB.Raw(
"SELECT column_name, is_nullable, udt_name, character_maximum_length, "+
"numeric_precision, numeric_precision_radix, numeric_scale, datetime_precision, 8 * typlen "+
"FROM information_schema.columns AS cols JOIN pg_type AS pgt ON cols.udt_name = pgt.typname "+
"WHERE table_catalog = ? AND table_schema = ? AND table_name = ?",
currentDatabase, currentSchema, table).Rows()
var (
currentDatabase = m.DB.Migrator().CurrentDatabase()
currentSchema, table = m.CurrentSchema(stmt, stmt.Table)
columns, err = m.DB.Raw(
"SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?",
currentDatabase, currentSchema, table).Rows()
rows, rowsErr = m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
)

if err != nil {
return err
}
defer columns.Close()

if rowsErr != nil {
return rowsErr
}
defer rows.Close()
rawColumnTypes, err := rows.ColumnTypes()

for columns.Next() {
var column Column
var (
column migrator.ColumnType
datetimePrecision sql.NullInt64
radixValue sql.NullInt64
typeLenValue sql.NullInt64
)

err = columns.Scan(
&column.name,
&column.nullable,
&column.datatype,
&column.maxlen,
&column.precision,
&column.radix,
&column.scale,
&column.datetimeprecision,
&column.typlen,
&column.NameValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.DecimalSizeValue,
&radixValue, &column.ScaleValue, &datetimePrecision, &typeLenValue, &column.DefaultValueValue, &column.CommentValue,
)
if err != nil {
return err
}

if typeLenValue.Valid && typeLenValue.Int64 > 0 {
column.LengthValue = typeLenValue
}

if strings.HasPrefix(column.DefaultValueValue.String, "nextval('") && strings.HasSuffix(column.DefaultValueValue.String, "seq'::regclass)") {
column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true}
column.DefaultValueValue = sql.NullString{}
}

if datetimePrecision.Valid {
column.DecimalSizeValue = datetimePrecision
}

for _, c := range rawColumnTypes {
if c.Name() == column.NameValue.String {
column.SQLColumnType = c
break
}
}
columnTypes = append(columnTypes, column)
}

columnTypeRows, err := m.DB.Raw("SELECT c.column_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
if err != nil {
return err
}
defer columnTypeRows.Close()

for columnTypeRows.Next() {
var name, columnType string
columnTypeRows.Scan(&name, &columnType)
for idx, c := range columnTypes {
mc := c.(migrator.ColumnType)
if mc.NameValue.String == name {
switch columnType {
case "PRIMARY KEY":
mc.PrimayKeyValue = sql.NullBool{Bool: true, Valid: true}
case "UNIQUE":
mc.UniqueValue = sql.NullBool{Bool: true, Valid: true}
}
columnTypes[idx] = mc
break
}
}
}

// Set column type
dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.relfilenode AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?)
WHERE a.attnum > 0 -- hide internal columns
AND NOT a.attisdropped -- hide deleted columns
AND b.relname = ?`, currentSchema, table).Rows()
if err != nil {
return err
}
defer dataTypeRows.Close()

for dataTypeRows.Next() {
var name, dataType string
columnTypeRows.Scan(&name, &dataType)
for idx, c := range columnTypes {
mc := c.(migrator.ColumnType)
if mc.NameValue.String == name {
mc.ColumnTypeValue = sql.NullString{String: dataType, Valid: true}
columnTypes[idx] = mc
break
}
}
}

return err
})
return
Expand Down

0 comments on commit 262ad9b

Please sign in to comment.