diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go index 2faed773b..88a238e36 100644 --- a/clause/benchmarks_test.go +++ b/clause/benchmarks_test.go @@ -7,7 +7,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" - "gorm.io/gorm/tests" + "gorm.io/gorm/utils/tests" ) func BenchmarkSelect(b *testing.B) { diff --git a/clause/clause_test.go b/clause/clause_test.go index f9d26a4ac..6239ff399 100644 --- a/clause/clause_test.go +++ b/clause/clause_test.go @@ -9,7 +9,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" - "gorm.io/gorm/tests" + "gorm.io/gorm/utils/tests" ) var db, _ = gorm.Open(tests.DummyDialector{}, nil) diff --git a/clause/expression_test.go b/clause/expression_test.go index 4e9376504..3059aea69 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -8,7 +8,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" - "gorm.io/gorm/tests" + "gorm.io/gorm/utils/tests" ) func TestExpr(t *testing.T) { diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go deleted file mode 100644 index b07f13c58..000000000 --- a/dialects/mssql/create.go +++ /dev/null @@ -1,225 +0,0 @@ -package mssql - -import ( - "reflect" - "sort" - - "gorm.io/gorm" - "gorm.io/gorm/callbacks" - "gorm.io/gorm/clause" - "gorm.io/gorm/schema" -) - -func Create(db *gorm.DB) { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.CreateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.String() == "" { - setIdentityInsert := false - c := db.Statement.Clauses["ON CONFLICT"] - onConflict, hasConflict := c.Expression.(clause.OnConflict) - - if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil { - setIdentityInsert = false - switch db.Statement.ReflectValue.Kind() { - case reflect.Struct: - _, isZero := field.ValueOf(db.Statement.ReflectValue) - setIdentityInsert = !isZero - case reflect.Slice: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - _, isZero := field.ValueOf(db.Statement.ReflectValue.Index(i)) - setIdentityInsert = !isZero - break - } - } - - if setIdentityInsert && (field.DataType == schema.Int || field.DataType == schema.Uint) { - setIdentityInsert = true - db.Statement.WriteString("SET IDENTITY_INSERT ") - db.Statement.WriteQuoted(db.Statement.Table) - db.Statement.WriteString(" ON;") - } else { - setIdentityInsert = false - } - } - - if hasConflict && len(db.Statement.Schema.PrimaryFields) > 0 { - MergeCreate(db, onConflict) - } else { - db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}}) - db.Statement.Build("INSERT") - db.Statement.WriteByte(' ') - - db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement)) - if values, ok := db.Statement.Clauses["VALUES"].Expression.(clause.Values); ok { - if len(values.Columns) > 0 { - db.Statement.WriteByte('(') - for idx, column := range values.Columns { - if idx > 0 { - db.Statement.WriteByte(',') - } - db.Statement.WriteQuoted(column) - } - db.Statement.WriteByte(')') - - outputInserted(db) - - db.Statement.WriteString(" VALUES ") - - for idx, value := range values.Values { - if idx > 0 { - db.Statement.WriteByte(',') - } - - db.Statement.WriteByte('(') - db.Statement.AddVar(db.Statement, value...) - db.Statement.WriteByte(')') - } - - db.Statement.WriteString(";") - } else { - db.Statement.WriteString("DEFAULT VALUES;") - } - } - } - - if setIdentityInsert { - db.Statement.WriteString("SET IDENTITY_INSERT ") - db.Statement.WriteQuoted(db.Statement.Table) - db.Statement.WriteString(" OFF;") - } - } - - if !db.DryRun { - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err == nil { - defer rows.Close() - - if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { - sortedKeys := []string{} - for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { - sortedKeys = append(sortedKeys, field.DBName) - } - sort.Strings(sortedKeys) - - returnningFields := make([]*schema.Field, len(sortedKeys)) - for idx, key := range sortedKeys { - returnningFields[idx] = db.Statement.Schema.LookUpField(key) - } - - values := make([]interface{}, len(returnningFields)) - - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for rows.Next() { - for idx, field := range returnningFields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() - } - - db.RowsAffected++ - db.AddError(rows.Scan(values...)) - } - case reflect.Struct: - for idx, field := range returnningFields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - } - - if rows.Next() { - db.RowsAffected++ - db.AddError(rows.Scan(values...)) - } - } - } - } else { - db.AddError(err) - } - } -} - -func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { - values := callbacks.ConvertToCreateValues(db.Statement) - - db.Statement.WriteString("MERGE INTO ") - db.Statement.WriteQuoted(db.Statement.Table) - db.Statement.WriteString(" USING (VALUES") - for idx, value := range values.Values { - if idx > 0 { - db.Statement.WriteByte(',') - } - - db.Statement.WriteByte('(') - db.Statement.AddVar(db.Statement, value...) - db.Statement.WriteByte(')') - } - - db.Statement.WriteString(") AS source (") - for idx, column := range values.Columns { - if idx > 0 { - db.Statement.WriteByte(',') - } - db.Statement.WriteQuoted(column.Name) - } - db.Statement.WriteString(") ON ") - - var where clause.Where - for _, field := range db.Statement.Schema.PrimaryFields { - where.Exprs = append(where.Exprs, clause.Eq{ - Column: clause.Column{Table: db.Statement.Table, Name: field.DBName}, - Value: clause.Column{Table: "source", Name: field.DBName}, - }) - } - where.Build(db.Statement) - - if len(onConflict.DoUpdates) > 0 { - db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ") - onConflict.DoUpdates.Build(db.Statement) - } - - db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (") - - for idx, column := range values.Columns { - if idx > 0 { - db.Statement.WriteByte(',') - } - db.Statement.WriteQuoted(column.Name) - } - - db.Statement.WriteString(") VALUES (") - - for idx, column := range values.Columns { - if idx > 0 { - db.Statement.WriteByte(',') - } - db.Statement.WriteQuoted(clause.Column{ - Table: "source", - Name: column.Name, - }) - } - - db.Statement.WriteString(")") - outputInserted(db) - db.Statement.WriteString(";") -} - -func outputInserted(db *gorm.DB) { - if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { - sortedKeys := []string{} - for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { - sortedKeys = append(sortedKeys, field.DBName) - } - sort.Strings(sortedKeys) - - db.Statement.WriteString(" OUTPUT") - for idx, key := range sortedKeys { - if idx > 0 { - db.Statement.WriteString(",") - } - db.Statement.WriteString(" INSERTED.") - db.Statement.AddVar(db.Statement, clause.Column{Name: key}) - } - } -} diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go deleted file mode 100644 index 3bb2086db..000000000 --- a/dialects/mssql/migrator.go +++ /dev/null @@ -1,142 +0,0 @@ -package mssql - -import ( - "fmt" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - "gorm.io/gorm/migrator" -) - -type Migrator struct { - migrator.Migrator -} - -func (m Migrator) HasTable(value interface{}) bool { - var count int - m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Raw( - "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", - stmt.Table, m.CurrentDatabase(), - ).Row().Scan(&count) - }) - return count > 0 -} - -func (m Migrator) RenameTable(oldName, newName interface{}) error { - var oldTable, newTable string - if v, ok := oldName.(string); ok { - oldTable = v - } else { - stmt := &gorm.Statement{DB: m.DB} - if err := stmt.Parse(oldName); err == nil { - oldTable = stmt.Table - } else { - return err - } - } - - if v, ok := newName.(string); ok { - newTable = v - } else { - stmt := &gorm.Statement{DB: m.DB} - if err := stmt.Parse(newName); err == nil { - newTable = stmt.Table - } else { - return err - } - } - - return m.DB.Exec( - "sp_rename @objname = ?, @newname = ?;", - clause.Table{Name: oldTable}, clause.Table{Name: newTable}, - ).Error -} - -func (m Migrator) HasColumn(value interface{}, field string) bool { - var count int64 - m.RunWithValue(value, func(stmt *gorm.Statement) error { - currentDatabase := m.DB.Migrator().CurrentDatabase() - name := field - if field := stmt.Schema.LookUpField(field); field != nil { - name = field.DBName - } - - return m.DB.Raw( - "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", - currentDatabase, stmt.Table, name, - ).Row().Scan(&count) - }) - - return count > 0 -} - -func (m Migrator) AlterColumn(value interface{}, field string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - return m.DB.Exec( - "ALTER TABLE ? ALTER COLUMN ? ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), - ).Error - } - return fmt.Errorf("failed to look up field with name: %s", field) - }) -} - -func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(oldName); field != nil { - oldName = field.DBName - } - - if field := stmt.Schema.LookUpField(newName); field != nil { - newName = field.DBName - } - - return m.DB.Exec( - "sp_rename @objname = ?, @newname = ?, @objtype = 'COLUMN';", - fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName}, - ).Error - }) -} - -func (m Migrator) HasIndex(value interface{}, name string) bool { - var count int - m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name - } - - return m.DB.Raw( - "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", - name, stmt.Table, - ).Row().Scan(&count) - }) - return count > 0 -} - -func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - - return m.DB.Exec( - "sp_rename @objname = ?, @newname = ?, @objtype = 'INDEX';", - fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName}, - ).Error - }) -} - -func (m Migrator) HasConstraint(value interface{}, name string) bool { - var count int64 - m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Raw( - `SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ? AND T.Name = ? AND I.TABLE_CATALOG = ?;`, - name, stmt.Table, m.CurrentDatabase(), - ).Row().Scan(&count) - }) - return count > 0 -} - -func (m Migrator) CurrentDatabase() (name string) { - m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name) - return -} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go deleted file mode 100644 index 3f87180cf..000000000 --- a/dialects/mssql/mssql.go +++ /dev/null @@ -1,127 +0,0 @@ -package mssql - -import ( - "database/sql" - "fmt" - "regexp" - "strconv" - - _ "github.com/denisenkom/go-mssqldb" - "gorm.io/gorm" - "gorm.io/gorm/callbacks" - "gorm.io/gorm/clause" - "gorm.io/gorm/logger" - "gorm.io/gorm/migrator" - "gorm.io/gorm/schema" -) - -type Dialector struct { - DSN string -} - -func (dialector Dialector) Name() string { - return "mssql" -} - -func Open(dsn string) gorm.Dialector { - return &Dialector{DSN: dsn} -} - -func (dialector Dialector) Initialize(db *gorm.DB) (err error) { - // register callbacks - callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) - db.Callback().Create().Replace("gorm:create", Create) - db.ConnPool, err = sql.Open("sqlserver", dialector.DSN) - - for k, v := range dialector.ClauseBuilders() { - db.ClauseBuilders[k] = v - } - return -} - -func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { - return map[string]clause.ClauseBuilder{ - "LIMIT": func(c clause.Clause, builder clause.Builder) { - if limit, ok := c.Expression.(clause.Limit); ok { - if limit.Offset > 0 { - builder.WriteString("OFFSET ") - builder.WriteString(strconv.Itoa(limit.Offset)) - builder.WriteString("ROWS") - } - - if limit.Limit > 0 { - if limit.Offset == 0 { - builder.WriteString(" OFFSET 0 ROWS") - } - builder.WriteString(" FETCH NEXT ") - builder.WriteString(strconv.Itoa(limit.Limit)) - builder.WriteString(" ROWS ONLY") - } - } - }, - } -} - -func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{ - DB: db, - Dialector: dialector, - CreateIndexAfterCreateTable: true, - }}} -} - -func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { - writer.WriteString("@p") - writer.WriteString(strconv.Itoa(len(stmt.Vars))) -} - -func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { - writer.WriteByte('"') - writer.WriteString(str) - writer.WriteByte('"') -} - -var numericPlaceholder = regexp.MustCompile("@p(\\d+)") - -func (dialector Dialector) Explain(sql string, vars ...interface{}) string { - return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) -} - -func (dialector Dialector) DataTypeOf(field *schema.Field) string { - switch field.DataType { - case schema.Bool: - return "bit" - case schema.Int, schema.Uint: - var sqlType string - switch { - case field.Size < 16: - sqlType = "smallint" - case field.Size < 31: - sqlType = "int" - default: - sqlType = "bigint" - } - - if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { - return sqlType + " IDENTITY(1,1)" - } - return sqlType - case schema.Float: - return "float" - case schema.String: - size := field.Size - if field.PrimaryKey && size == 0 { - size = 256 - } - if size > 0 && size <= 4000 { - return fmt.Sprintf("nvarchar(%d)", size) - } - return "nvarchar(MAX)" - case schema.Time: - return "datetimeoffset" - case schema.Bytes: - return "varbinary(MAX)" - } - - return "" -} diff --git a/dialects/mysql/migrator.go b/dialects/mysql/migrator.go deleted file mode 100644 index 8d3d20c6a..000000000 --- a/dialects/mysql/migrator.go +++ /dev/null @@ -1,58 +0,0 @@ -package mysql - -import ( - "fmt" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - "gorm.io/gorm/migrator" -) - -type Migrator struct { - migrator.Migrator -} - -func (m Migrator) AlterColumn(value interface{}, field string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - return m.DB.Exec( - "ALTER TABLE ? MODIFY COLUMN ? ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), - ).Error - } - return fmt.Errorf("failed to look up field with name: %s", field) - }) -} - -func (m Migrator) DropTable(values ...interface{}) error { - values = m.ReorderModels(values, false) - tx := m.DB.Session(&gorm.Session{}) - tx.Exec("SET FOREIGN_KEY_CHECKS = 0;") - for i := len(values) - 1; i >= 0; i-- { - if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { - return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error - }); err != nil { - return err - } - } - tx.Exec("SET FOREIGN_KEY_CHECKS = 1;") - return nil -} - -func (m Migrator) DropConstraint(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - for _, chk := range stmt.Schema.ParseCheckConstraints() { - if chk.Name == name { - return m.DB.Exec( - "ALTER TABLE ? DROP CHECK ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: name}, - ).Error - } - } - - return m.DB.Exec( - "ALTER TABLE ? DROP FOREIGN KEY ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: name}, - ).Error - }) -} diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go deleted file mode 100644 index 035a6d79f..000000000 --- a/dialects/mysql/mysql.go +++ /dev/null @@ -1,169 +0,0 @@ -package mysql - -import ( - "database/sql" - "fmt" - "math" - - _ "github.com/go-sql-driver/mysql" - "gorm.io/gorm" - "gorm.io/gorm/callbacks" - "gorm.io/gorm/clause" - "gorm.io/gorm/logger" - "gorm.io/gorm/migrator" - "gorm.io/gorm/schema" -) - -type Dialector struct { - DSN string -} - -func Open(dsn string) gorm.Dialector { - return &Dialector{DSN: dsn} -} - -func (dialector Dialector) Name() string { - return "mysql" -} - -func (dialector Dialector) Initialize(db *gorm.DB) (err error) { - // register callbacks - callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) - db.ConnPool, err = sql.Open("mysql", dialector.DSN) - - for k, v := range dialector.ClauseBuilders() { - db.ClauseBuilders[k] = v - } - return -} - -func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { - return map[string]clause.ClauseBuilder{ - "ON CONFLICT": func(c clause.Clause, builder clause.Builder) { - if onConflict, ok := c.Expression.(clause.OnConflict); ok { - builder.WriteString("ON DUPLICATE KEY UPDATE ") - if len(onConflict.DoUpdates) == 0 { - if s := builder.(*gorm.Statement).Schema; s != nil { - var column clause.Column - onConflict.DoNothing = false - - if s.PrioritizedPrimaryField != nil { - column = clause.Column{Name: s.PrioritizedPrimaryField.DBName} - } else { - for _, field := range s.FieldsByDBName { - column = clause.Column{Name: field.DBName} - break - } - } - onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}} - } - } - - onConflict.DoUpdates.Build(builder) - } else { - c.Build(builder) - } - }, - "VALUES": func(c clause.Clause, builder clause.Builder) { - if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 { - builder.WriteString("VALUES()") - return - } - c.Build(builder) - }, - } -} - -func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{ - DB: db, - Dialector: dialector, - }}} -} - -func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { - writer.WriteByte('?') -} - -func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { - writer.WriteByte('`') - writer.WriteString(str) - writer.WriteByte('`') -} - -func (dialector Dialector) Explain(sql string, vars ...interface{}) string { - return logger.ExplainSQL(sql, nil, `"`, vars...) -} - -func (dialector Dialector) DataTypeOf(field *schema.Field) string { - switch field.DataType { - case schema.Bool: - return "boolean" - case schema.Int, schema.Uint: - sqlType := "int" - switch { - case field.Size <= 8: - sqlType = "tinyint" - case field.Size <= 16: - sqlType = "smallint" - case field.Size <= 32: - sqlType = "int" - default: - sqlType = "bigint" - } - - if field.DataType == schema.Uint { - sqlType += " unsigned" - } - - if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { - sqlType += " AUTO_INCREMENT" - } - return sqlType - case schema.Float: - if field.Size <= 32 { - return "float" - } - return "double" - case schema.String: - size := field.Size - if size == 0 { - if field.PrimaryKey || field.HasDefaultValue { - size = 256 - } - } - - if size >= 65536 && size <= int(math.Pow(2, 24)) { - return "mediumtext" - } else if size > int(math.Pow(2, 24)) || size <= 0 { - return "longtext" - } - return fmt.Sprintf("varchar(%d)", size) - case schema.Time: - precision := "" - if field.Precision == 0 { - field.Precision = 3 - } - - if field.Precision > 0 { - precision = fmt.Sprintf("(%d)", field.Precision) - } - - if field.NotNull || field.PrimaryKey { - return "datetime" + precision - } - return "datetime" + precision + " NULL" - case schema.Bytes: - if field.Size > 0 && field.Size < 65536 { - return fmt.Sprintf("varbinary(%d)", field.Size) - } - - if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) { - return "mediumblob" - } - - return "longblob" - } - - return "" -} diff --git a/dialects/postgres/migrator.go b/dialects/postgres/migrator.go deleted file mode 100644 index 6b1085e33..000000000 --- a/dialects/postgres/migrator.go +++ /dev/null @@ -1,139 +0,0 @@ -package postgres - -import ( - "fmt" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - "gorm.io/gorm/migrator" - "gorm.io/gorm/schema" -) - -type Migrator struct { - migrator.Migrator -} - -func (m Migrator) CurrentDatabase() (name string) { - m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name) - return -} - -func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { - for _, opt := range opts { - str := stmt.Quote(opt.DBName) - if opt.Expression != "" { - str = opt.Expression - } - - if opt.Collate != "" { - str += " COLLATE " + opt.Collate - } - - if opt.Sort != "" { - str += " " + opt.Sort - } - results = append(results, clause.Expr{SQL: str}) - } - return -} - -func (m Migrator) HasIndex(value interface{}, name string) bool { - var count int64 - m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name - } - - return m.DB.Raw( - "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, name, - ).Row().Scan(&count) - }) - - return count > 0 -} - -func (m Migrator) CreateIndex(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - opts := m.BuildIndexOptions(idx.Fields, stmt) - values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} - - createIndexSQL := "CREATE " - if idx.Class != "" { - createIndexSQL += idx.Class + " " - } - createIndexSQL += "INDEX ?" - - if idx.Type != "" { - createIndexSQL += " USING " + idx.Type - } - createIndexSQL += " ON ??" - - if idx.Where != "" { - createIndexSQL += " WHERE " + idx.Where - } - - return m.DB.Exec(createIndexSQL, values...).Error - } - - return fmt.Errorf("failed to create index with name %v", name) - }) -} - -func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Exec( - "ALTER INDEX ? RENAME TO ?", - clause.Column{Name: oldName}, clause.Column{Name: newName}, - ).Error - }) -} - -func (m Migrator) DropIndex(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name - } - - return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error - }) -} - -func (m Migrator) HasTable(value interface{}) bool { - var count int64 - m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND table_type = ?", stmt.Table, "BASE TABLE").Row().Scan(&count) - }) - - return count > 0 -} - -func (m Migrator) DropTable(values ...interface{}) error { - values = m.ReorderModels(values, false) - tx := m.DB.Session(&gorm.Session{}) - for i := len(values) - 1; i >= 0; i-- { - if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { - return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error - }); err != nil { - return err - } - } - return nil -} - -func (m Migrator) HasColumn(value interface{}, field string) bool { - var count int64 - m.RunWithValue(value, func(stmt *gorm.Statement) error { - name := field - if field := stmt.Schema.LookUpField(field); field != nil { - name = field.DBName - } - - return m.DB.Raw( - "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?", - stmt.Table, name, - ).Row().Scan(&count) - }) - - return count > 0 -} diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go deleted file mode 100644 index 57e51d581..000000000 --- a/dialects/postgres/postgres.go +++ /dev/null @@ -1,102 +0,0 @@ -package postgres - -import ( - "database/sql" - "fmt" - "regexp" - "strconv" - - "gorm.io/gorm" - "gorm.io/gorm/callbacks" - "gorm.io/gorm/clause" - "gorm.io/gorm/logger" - "gorm.io/gorm/migrator" - "gorm.io/gorm/schema" - _ "github.com/lib/pq" -) - -type Dialector struct { - DSN string -} - -func Open(dsn string) gorm.Dialector { - return &Dialector{DSN: dsn} -} - -func (dialector Dialector) Name() string { - return "postgres" -} - -func (dialector Dialector) Initialize(db *gorm.DB) (err error) { - // register callbacks - callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ - WithReturning: true, - }) - db.ConnPool, err = sql.Open("postgres", dialector.DSN) - return -} - -func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{ - DB: db, - Dialector: dialector, - CreateIndexAfterCreateTable: true, - }}} -} - -func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { - writer.WriteByte('$') - writer.WriteString(strconv.Itoa(len(stmt.Vars))) -} - -func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { - writer.WriteByte('"') - writer.WriteString(str) - writer.WriteByte('"') -} - -var numericPlaceholder = regexp.MustCompile("\\$(\\d+)") - -func (dialector Dialector) Explain(sql string, vars ...interface{}) string { - return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) -} - -func (dialector Dialector) DataTypeOf(field *schema.Field) string { - switch field.DataType { - case schema.Bool: - return "boolean" - case schema.Int, schema.Uint: - if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { - switch { - case field.Size < 16: - return "smallserial" - case field.Size < 31: - return "serial" - default: - return "bigserial" - } - } else { - switch { - case field.Size < 16: - return "smallint" - case field.Size < 31: - return "integer" - default: - return "bigint" - } - } - case schema.Float: - return "decimal" - case schema.String: - if field.Size > 0 { - return fmt.Sprintf("varchar(%d)", field.Size) - } - return "text" - case schema.Time: - return "timestamptz" - case schema.Bytes: - return "bytea" - } - - return "" -} diff --git a/dialects/sqlite/migrator.go b/dialects/sqlite/migrator.go deleted file mode 100644 index 14c682ca7..000000000 --- a/dialects/sqlite/migrator.go +++ /dev/null @@ -1,211 +0,0 @@ -package sqlite - -import ( - "fmt" - "regexp" - "strings" - - "gorm.io/gorm" - "gorm.io/gorm/clause" - "gorm.io/gorm/migrator" - "gorm.io/gorm/schema" -) - -type Migrator struct { - migrator.Migrator -} - -func (m Migrator) HasTable(value interface{}) bool { - var count int - m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count) - }) - return count > 0 -} - -func (m Migrator) HasColumn(value interface{}, name string) bool { - var count int - m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(name); field != nil { - name = field.DBName - } - - return m.DB.Raw( - "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", - "table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", - ).Row().Scan(&count) - }) - return count > 0 -} - -func (m Migrator) AlterColumn(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(name); field != nil { - var ( - createSQL string - newTableName = stmt.Table + "__temp" - ) - - m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) - - if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { - tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") - if err != nil { - return err - } - - createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) - createSQL = reg.ReplaceAllString(createSQL, "?") - - var columns []string - columnTypes, _ := m.DB.Migrator().ColumnTypes(value) - for _, columnType := range columnTypes { - columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) - } - - createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) - return m.DB.Exec(createSQL, m.FullDataTypeOf(field)).Error - } else { - return err - } - } else { - return fmt.Errorf("failed to alter field with name %v", name) - } - }) -} - -func (m Migrator) DropColumn(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(name); field != nil { - name = field.DBName - } - - var ( - createSQL string - newTableName = stmt.Table + "__temp" - ) - - m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) - - if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { - tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") - if err != nil { - return err - } - - createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) - createSQL = reg.ReplaceAllString(createSQL, "") - - var columns []string - columnTypes, _ := m.DB.Migrator().ColumnTypes(value) - for _, columnType := range columnTypes { - if columnType.Name() != name { - columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) - } - } - - createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) - - return m.DB.Exec(createSQL).Error - } else { - return err - } - }) -} - -func (m Migrator) CreateConstraint(interface{}, string) error { - return gorm.ErrNotImplemented -} - -func (m Migrator) DropConstraint(interface{}, string) error { - return gorm.ErrNotImplemented -} - -func (m Migrator) CurrentDatabase() (name string) { - var null interface{} - m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null) - return -} - -func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { - for _, opt := range opts { - str := stmt.Quote(opt.DBName) - if opt.Expression != "" { - str = opt.Expression - } - - if opt.Collate != "" { - str += " COLLATE " + opt.Collate - } - - if opt.Sort != "" { - str += " " + opt.Sort - } - results = append(results, clause.Expr{SQL: str}) - } - return -} - -func (m Migrator) CreateIndex(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - opts := m.BuildIndexOptions(idx.Fields, stmt) - values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} - - createIndexSQL := "CREATE " - if idx.Class != "" { - createIndexSQL += idx.Class + " " - } - createIndexSQL += "INDEX ?" - - if idx.Type != "" { - createIndexSQL += " USING " + idx.Type - } - createIndexSQL += " ON ??" - - if idx.Where != "" { - createIndexSQL += " WHERE " + idx.Where - } - - return m.DB.Exec(createIndexSQL, values...).Error - } - - return fmt.Errorf("failed to create index with name %v", name) - }) -} - -func (m Migrator) HasIndex(value interface{}, name string) bool { - var count int - m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name - } - - m.DB.Raw( - "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name, - ).Row().Scan(&count) - return nil - }) - return count > 0 -} - -func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - var sql string - m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql) - if sql != "" { - return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error - } - return fmt.Errorf("failed to find index with name %v", oldName) - }) -} - -func (m Migrator) DropIndex(value interface{}, name string) error { - return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name - } - - return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error - }) -} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go deleted file mode 100644 index 238ad7f92..000000000 --- a/dialects/sqlite/sqlite.go +++ /dev/null @@ -1,80 +0,0 @@ -package sqlite - -import ( - "database/sql" - - "gorm.io/gorm" - "gorm.io/gorm/callbacks" - "gorm.io/gorm/clause" - "gorm.io/gorm/logger" - "gorm.io/gorm/migrator" - "gorm.io/gorm/schema" - _ "github.com/mattn/go-sqlite3" -) - -type Dialector struct { - DSN string -} - -func Open(dsn string) gorm.Dialector { - return &Dialector{DSN: dsn} -} - -func (dialector Dialector) Name() string { - return "sqlite" -} - -func (dialector Dialector) Initialize(db *gorm.DB) (err error) { - // register callbacks - callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ - LastInsertIDReversed: true, - }) - db.ConnPool, err = sql.Open("sqlite3", dialector.DSN) - return -} - -func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{ - DB: db, - Dialector: dialector, - CreateIndexAfterCreateTable: true, - }}} -} - -func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { - writer.WriteByte('?') -} - -func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { - writer.WriteByte('`') - writer.WriteString(str) - writer.WriteByte('`') -} - -func (dialector Dialector) Explain(sql string, vars ...interface{}) string { - return logger.ExplainSQL(sql, nil, `"`, vars...) -} - -func (dialector Dialector) DataTypeOf(field *schema.Field) string { - switch field.DataType { - case schema.Bool: - return "numeric" - case schema.Int, schema.Uint: - if field.AutoIncrement { - // https://www.sqlite.org/autoinc.html - return "integer PRIMARY KEY AUTOINCREMENT" - } else { - return "integer" - } - case schema.Float: - return "real" - case schema.String: - return "text" - case schema.Time: - return "datetime" - case schema.Bytes: - return "blob" - } - - return "" -} diff --git a/go.mod b/go.mod index 26877c7a7..faf63a46b 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,6 @@ module gorm.io/gorm go 1.14 require ( - github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc - github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 - github.com/go-sql-driver/mysql v1.5.0 github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.1.1 - github.com/lib/pq v1.1.1 - github.com/mattn/go-sqlite3 v2.0.1+incompatible - gorm.io/gorm v1.9.12 ) diff --git a/schema/field_test.go b/schema/field_test.go index 7a47f195d..fe88891f2 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -9,7 +9,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/schema" - "gorm.io/gorm/tests" + "gorm.io/gorm/utils/tests" ) func TestFieldValuerAndSetter(t *testing.T) { diff --git a/schema/model_test.go b/schema/model_test.go index 068b30509..a13372b59 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -5,7 +5,7 @@ import ( "time" "gorm.io/gorm" - "gorm.io/gorm/tests" + "gorm.io/gorm/utils/tests" ) type User struct { diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index b966164e7..f2ed41458 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -7,7 +7,7 @@ import ( "testing" "gorm.io/gorm/schema" - "gorm.io/gorm/tests" + "gorm.io/gorm/utils/tests" ) func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { diff --git a/schema/schema_test.go b/schema/schema_test.go index 6902cbf2a..1029f74f5 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -5,7 +5,7 @@ import ( "testing" "gorm.io/gorm/schema" - "gorm.io/gorm/tests" + "gorm.io/gorm/utils/tests" ) func TestParseSchema(t *testing.T) { diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 27b82ecb0..35419666f 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestBelongsToAssociation(t *testing.T) { diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index 88df85321..7ef0c2183 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestHasManyAssociation(t *testing.T) { diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go index 9ddfa9c51..f32a692d0 100644 --- a/tests/associations_has_one_test.go +++ b/tests/associations_has_one_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestHasOneAssociation(t *testing.T) { diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index d79cdc178..ba9695b71 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestMany2ManyAssociation(t *testing.T) { diff --git a/tests/associations_test.go b/tests/associations_test.go index 2e30df8b2..44262109f 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { diff --git a/tests/count_test.go b/tests/count_test.go index d8cfa405d..63238089f 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestCount(t *testing.T) { diff --git a/tests/create_test.go b/tests/create_test.go index 2f853c61a..c497014e8 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -6,7 +6,7 @@ import ( "github.com/jinzhu/now" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestCreate(t *testing.T) { diff --git a/tests/customize_column_test.go b/tests/customize_column_test.go index 0db40869c..98dea4945 100644 --- a/tests/customize_column_test.go +++ b/tests/customize_column_test.go @@ -3,8 +3,6 @@ package tests_test import ( "testing" "time" - - . "gorm.io/gorm/tests" ) func TestCustomizeColumn(t *testing.T) { diff --git a/tests/delete_test.go b/tests/delete_test.go index 0fe2ee75d..66c396d19 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -5,7 +5,7 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestDelete(t *testing.T) { diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 748294606..9a1436feb 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -4,7 +4,6 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" ) func TestEmbeddedStruct(t *testing.T) { diff --git a/tests/go.mod b/tests/go.mod new file mode 100644 index 000000000..3954c442f --- /dev/null +++ b/tests/go.mod @@ -0,0 +1,14 @@ +module gorm.io/gorm/tests + +go 1.14 + +require ( + github.com/jinzhu/now v1.1.1 + gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 + gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 + gorm.io/driver/sqlite v0.0.0-20200602015323-284b563f81c8 + gorm.io/driver/sqlserver v0.0.0-20200602015206-ef9f739c6a30 + gorm.io/gorm v1.9.12 +) + +replace gorm.io/gorm => ../ diff --git a/tests/group_by_test.go b/tests/group_by_test.go index 5a9543484..cb4c4f43e 100644 --- a/tests/group_by_test.go +++ b/tests/group_by_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestGroupBy(t *testing.T) { diff --git a/tests/utils.go b/tests/helper_test.go similarity index 66% rename from tests/utils.go rename to tests/helper_test.go index 0b4b138ec..b05f52972 100644 --- a/tests/utils.go +++ b/tests/helper_test.go @@ -1,17 +1,13 @@ -package tests +package tests_test import ( - "database/sql/driver" - "fmt" - "go/ast" - "reflect" "sort" "strconv" "strings" "testing" "time" - "gorm.io/gorm/utils" + . "gorm.io/gorm/utils/tests" ) type Config struct { @@ -73,101 +69,6 @@ func GetUser(name string, config Config) *User { return &user } -func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { - for _, name := range names { - got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() - expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() - t.Run(name, func(t *testing.T) { - AssertEqual(t, got, expect) - }) - } -} - -func AssertEqual(t *testing.T, got, expect interface{}) { - if !reflect.DeepEqual(got, expect) { - isEqual := func() { - if curTime, ok := got.(time.Time); ok { - format := "2006-01-02T15:04:05Z07:00" - - if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) { - t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) - } - } else if fmt.Sprint(got) != fmt.Sprint(expect) { - t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) - } - } - - if fmt.Sprint(got) == fmt.Sprint(expect) { - return - } - - if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { - t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) - return - } - - if valuer, ok := got.(driver.Valuer); ok { - got, _ = valuer.Value() - } - - if valuer, ok := expect.(driver.Valuer); ok { - expect, _ = valuer.Value() - } - - if got != nil { - got = reflect.Indirect(reflect.ValueOf(got)).Interface() - } - - if expect != nil { - expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() - } - - if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() { - t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) - return - } - - if reflect.ValueOf(got).Kind() == reflect.Slice { - if reflect.ValueOf(expect).Kind() == reflect.Slice { - if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { - for i := 0; i < reflect.ValueOf(got).Len(); i++ { - name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) - t.Run(name, func(t *testing.T) { - AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) - }) - } - } else { - name := reflect.ValueOf(got).Type().Elem().Name() - t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) - } - return - } - } - - if reflect.ValueOf(got).Kind() == reflect.Struct { - if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { - for i := 0; i < reflect.ValueOf(got).NumField(); i++ { - if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { - field := reflect.ValueOf(got).Field(i) - t.Run(fieldStruct.Name, func(t *testing.T) { - AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) - }) - } - } - return - } - } - - if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { - got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() - isEqual() - } else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { - expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() - isEqual() - } - } -} - func CheckPet(t *testing.T, pet Pet, expect Pet) { if pet.ID != 0 { var newPet Pet diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 418713a69..e2850c274 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -6,7 +6,6 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" ) type Product struct { diff --git a/tests/joins_table_test.go b/tests/joins_table_test.go index 5738d8f47..b8c1be77e 100644 --- a/tests/joins_table_test.go +++ b/tests/joins_table_test.go @@ -5,7 +5,6 @@ import ( "time" "gorm.io/gorm" - . "gorm.io/gorm/tests" ) type Person struct { diff --git a/tests/joins_test.go b/tests/joins_test.go index 651b20c6d..f01c82113 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -5,7 +5,7 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestJoins(t *testing.T) { diff --git a/tests/main_test.go b/tests/main_test.go index 2d466125c..ff293e6e4 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestMain(m *testing.M) { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index b511ab40a..5293898f1 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -7,7 +7,7 @@ import ( "time" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestMigrate(t *testing.T) { diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index 139cde699..05267bbb0 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -4,8 +4,6 @@ import ( "reflect" "sort" "testing" - - . "gorm.io/gorm/tests" ) type Blog struct { @@ -36,8 +34,8 @@ func compareTags(tags []Tag, contents []string) bool { } func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { - if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { - t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") @@ -125,8 +123,8 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { } func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { - if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { - t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") @@ -246,8 +244,8 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { } func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { - if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { - t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags") diff --git a/tests/named_polymorphic_test.go b/tests/named_polymorphic_test.go index 99a7865ab..616557846 100644 --- a/tests/named_polymorphic_test.go +++ b/tests/named_polymorphic_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) type Hamster struct { diff --git a/tests/non_std_test.go b/tests/non_std_test.go index b3ac65451..d3561b11e 100644 --- a/tests/non_std_test.go +++ b/tests/non_std_test.go @@ -3,8 +3,6 @@ package tests_test import ( "testing" "time" - - . "gorm.io/gorm/tests" ) type Animal struct { diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index 42e94fa0c..98f24daf3 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -7,7 +7,6 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" ) func toJSONString(v interface{}) []byte { @@ -691,8 +690,8 @@ func TestNestedPreload12(t *testing.T) { } func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { - if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" { - t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment") + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } type ( diff --git a/tests/preload_test.go b/tests/preload_test.go index e4ecdc87c..06e38f096 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -6,7 +6,7 @@ import ( "testing" "gorm.io/gorm/clause" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestNestedPreload(t *testing.T) { diff --git a/tests/query_test.go b/tests/query_test.go index 9d15a41f7..f6fb1081d 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -9,7 +9,7 @@ import ( "time" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestFind(t *testing.T) { diff --git a/tests/scan_test.go b/tests/scan_test.go index 262ac9a71..d6a372bb1 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestScan(t *testing.T) { diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 7dad081f4..7d72db150 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -11,7 +11,7 @@ import ( "time" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestScannerValuer(t *testing.T) { diff --git a/tests/scopes_test.go b/tests/scopes_test.go index a2a7de3f4..c9787d362 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -4,7 +4,7 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func NameIn1And2(d *gorm.DB) *gorm.DB { diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index 24b064982..c632c7534 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestSoftDelete(t *testing.T) { diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 0f3a56ed5..278a5b963 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -4,7 +4,7 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestRow(t *testing.T) { diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 3a1b45c8d..95245804a 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -18,8 +18,13 @@ for dialect in "${dialects[@]}" ; do if [ "$GORM_VERBOSE" = "" ] then DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./... + cd tests + DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./... else DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./... + cd tests + DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./... fi + cd .. fi done diff --git a/tests/tests.go b/tests/tests_test.go similarity index 87% rename from tests/tests.go rename to tests/tests_test.go index 42902685d..40816c3c2 100644 --- a/tests/tests.go +++ b/tests/tests_test.go @@ -1,4 +1,4 @@ -package tests +package tests_test import ( "log" @@ -7,12 +7,13 @@ import ( "path/filepath" "time" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/driver/sqlserver" "gorm.io/gorm" - "gorm.io/gorm/dialects/mssql" - "gorm.io/gorm/dialects/mysql" - "gorm.io/gorm/dialects/postgres" - "gorm.io/gorm/dialects/sqlite" "gorm.io/gorm/logger" + . "gorm.io/gorm/utils/tests" ) var DB *gorm.DB @@ -40,17 +41,17 @@ func OpenTestConnection() (db *gorm.DB, err error) { dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" } db, err = gorm.Open(postgres.Open(dbDSN), &gorm.Config{}) - case "mssql": + case "sqlserver": // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; // CREATE DATABASE gorm; // USE gorm; // CREATE USER gorm FROM LOGIN gorm; // sp_changedbowner 'gorm'; - log.Println("testing mssql...") + log.Println("testing sqlserver...") if dbDSN == "" { dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" } - db, err = gorm.Open(mssql.Open(dbDSN), &gorm.Config{}) + db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) default: log.Println("testing sqlite3...") db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) @@ -90,8 +91,3 @@ func RunMigrations() { } } } - -func Now() *time.Time { - now := time.Now() - return &now -} diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 4ff1b485c..b810e3bb5 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -6,7 +6,7 @@ import ( "testing" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestTransaction(t *testing.T) { diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go index 7c578b387..47076e691 100644 --- a/tests/update_belongs_to_test.go +++ b/tests/update_belongs_to_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestUpdateBelongsTo(t *testing.T) { diff --git a/tests/update_has_many_test.go b/tests/update_has_many_test.go index 5501c5193..01ea2e3ae 100644 --- a/tests/update_has_many_test.go +++ b/tests/update_has_many_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestUpdateHasManyAssociations(t *testing.T) { diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index 721c302a0..7b29f424b 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestUpdateHasOne(t *testing.T) { diff --git a/tests/update_many2many_test.go b/tests/update_many2many_test.go index 5548444fe..a46deeb04 100644 --- a/tests/update_many2many_test.go +++ b/tests/update_many2many_test.go @@ -3,7 +3,7 @@ package tests_test import ( "testing" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestUpdateMany2ManyAssociations(t *testing.T) { diff --git a/tests/update_test.go b/tests/update_test.go index aef7f4ce1..524e9ea6d 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -8,7 +8,7 @@ import ( "time" "gorm.io/gorm" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestUpdate(t *testing.T) { diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 87b223b47..412be305e 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -5,7 +5,7 @@ import ( "time" "gorm.io/gorm/clause" - . "gorm.io/gorm/tests" + . "gorm.io/gorm/utils/tests" ) func TestUpsert(t *testing.T) { diff --git a/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go similarity index 100% rename from tests/dummy_dialecter.go rename to utils/tests/dummy_dialecter.go diff --git a/tests/model.go b/utils/tests/models.go similarity index 100% rename from tests/model.go rename to utils/tests/models.go diff --git a/utils/tests/utils.go b/utils/tests/utils.go new file mode 100644 index 000000000..5248e6200 --- /dev/null +++ b/utils/tests/utils.go @@ -0,0 +1,112 @@ +package tests + +import ( + "database/sql/driver" + "fmt" + "go/ast" + "reflect" + "testing" + "time" + + "gorm.io/gorm/utils" +) + +func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { + for _, name := range names { + got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() + expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() + t.Run(name, func(t *testing.T) { + AssertEqual(t, got, expect) + }) + } +} + +func AssertEqual(t *testing.T, got, expect interface{}) { + if !reflect.DeepEqual(got, expect) { + isEqual := func() { + if curTime, ok := got.(time.Time); ok { + format := "2006-01-02T15:04:05Z07:00" + + if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) { + t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) + } + } else if fmt.Sprint(got) != fmt.Sprint(expect) { + t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) + } + } + + if fmt.Sprint(got) == fmt.Sprint(expect) { + return + } + + if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) + return + } + + if valuer, ok := got.(driver.Valuer); ok { + got, _ = valuer.Value() + } + + if valuer, ok := expect.(driver.Valuer); ok { + expect, _ = valuer.Value() + } + + if got != nil { + got = reflect.Indirect(reflect.ValueOf(got)).Interface() + } + + if expect != nil { + expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() + } + + if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) + return + } + + if reflect.ValueOf(got).Kind() == reflect.Slice { + if reflect.ValueOf(expect).Kind() == reflect.Slice { + if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { + for i := 0; i < reflect.ValueOf(got).Len(); i++ { + name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) + t.Run(name, func(t *testing.T) { + AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) + }) + } + } else { + name := reflect.ValueOf(got).Type().Elem().Name() + t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) + } + return + } + } + + if reflect.ValueOf(got).Kind() == reflect.Struct { + if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { + for i := 0; i < reflect.ValueOf(got).NumField(); i++ { + if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { + field := reflect.ValueOf(got).Field(i) + t.Run(fieldStruct.Name, func(t *testing.T) { + AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) + }) + } + } + return + } + } + + if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { + got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() + isEqual() + } else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { + expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() + isEqual() + } + } +} + +func Now() *time.Time { + now := time.Now() + return &now +}