Skip to content

Commit

Permalink
Fix auto migration with transaction, close go-gorm/gorm#5175
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Apr 3, 2022
1 parent 8d60c48 commit 55755e3
Showing 1 changed file with 61 additions and 50 deletions.
111 changes: 61 additions & 50 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
}
createIndexSQL += "INDEX "

if strings.EqualFold(strings.TrimSpace(idx.Option), "CONCURRENTLY") {
if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" {
createIndexSQL += "CONCURRENTLY "
}

Expand Down Expand Up @@ -328,23 +328,15 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
columns, err = m.DB.Raw(
"SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?",
currentDatabase, currentSchema, table).Rows()
rows, rowsErr = m.GetRows(currentSchema, table)
)

if err != nil {
return err
}
defer columns.Close()

if rowsErr != nil {
return rowsErr
}
defer rows.Close()
rawColumnTypes, err := rows.ColumnTypes()

for columns.Next() {
var (
column = migrator.ColumnType{
column = &migrator.ColumnType{
PrimaryKeyValue: sql.NullBool{Valid: true},
UniqueValue: sql.NullBool{Valid: true},
}
Expand Down Expand Up @@ -378,61 +370,80 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
column.DecimalSizeValue = datetimePrecision
}

for _, c := range rawColumnTypes {
if c.Name() == column.NameValue.String {
column.SQLColumnType = c
break
}
}
columnTypes = append(columnTypes, column)
}
columns.Close()

columnTypeRows, err := m.DB.Raw("SELECT c.column_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
if err != nil {
return err
// assign sql column type
{
rows, rowsErr := m.GetRows(currentSchema, table)
if rowsErr != nil {
return rowsErr
}
rawColumnTypes, err := rows.ColumnTypes()
if err != nil {
return err
}
for _, columnType := range columnTypes {
for _, c := range rawColumnTypes {
if c.Name() == columnType.Name() {
columnType.(*migrator.ColumnType).SQLColumnType = c
break
}
}
}
rows.Close()
}
defer columnTypeRows.Close()

for columnTypeRows.Next() {
var name, columnType string
columnTypeRows.Scan(&name, &columnType)
for idx, c := range columnTypes {
mc := c.(migrator.ColumnType)
if mc.NameValue.String == name {
switch columnType {
case "PRIMARY KEY":
mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
case "UNIQUE":
mc.UniqueValue = sql.NullBool{Bool: true, Valid: true}

// check primary, unique field
{
columnTypeRows, err := m.DB.Raw("SELECT c.column_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
if err != nil {
return err
}

for columnTypeRows.Next() {
var name, columnType string
columnTypeRows.Scan(&name, &columnType)
for _, c := range columnTypes {
mc := c.(*migrator.ColumnType)
if mc.NameValue.String == name {
switch columnType {
case "PRIMARY KEY":
mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
case "UNIQUE":
mc.UniqueValue = sql.NullBool{Bool: true, Valid: true}
}
break
}
columnTypes[idx] = mc
break
}
}
columnTypeRows.Close()
}

// Set column type
dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
// check column type
{
dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.relfilenode AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?)
WHERE a.attnum > 0 -- hide internal columns
AND NOT a.attisdropped -- hide deleted columns
AND b.relname = ?`, currentSchema, table).Rows()
if err != nil {
return err
}
defer dataTypeRows.Close()

for dataTypeRows.Next() {
var name, dataType string
columnTypeRows.Scan(&name, &dataType)
for idx, c := range columnTypes {
mc := c.(migrator.ColumnType)
if mc.NameValue.String == name {
mc.ColumnTypeValue = sql.NullString{String: dataType, Valid: true}
columnTypes[idx] = mc
break
if err != nil {
return err
}

for dataTypeRows.Next() {
var name, dataType string
dataTypeRows.Scan(&name, &dataType)
for _, c := range columnTypes {
mc := c.(*migrator.ColumnType)
if mc.NameValue.String == name {
mc.ColumnTypeValue = sql.NullString{String: dataType, Valid: true}
break
}
}
}
dataTypeRows.Close()
}

return err
Expand Down

0 comments on commit 55755e3

Please sign in to comment.