diff --git a/.gitignore b/.gitignore index a725465..9f11b75 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1 @@ -vendor/ \ No newline at end of file +.idea/ diff --git a/ddlmod.go b/ddlmod.go index 87886d2..39cc13a 100644 --- a/ddlmod.go +++ b/ddlmod.go @@ -13,15 +13,26 @@ import ( var ( sqliteSeparator = "`|\"|'|\t" - indexRegexp = regexp.MustCompile(fmt.Sprintf("(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\\w\\d-]+[%v]? ON (.*)$", sqliteSeparator, sqliteSeparator)) - tableRegexp = regexp.MustCompile(fmt.Sprintf("(?is)(CREATE TABLE [%v]?[\\w\\d-]+[%v]?)(?: \\((.*)\\))?", sqliteSeparator, sqliteSeparator)) + indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]? ON (.*)$`, sqliteSeparator, sqliteSeparator)) + tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator)) separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator)) - columnsRegexp = regexp.MustCompile(fmt.Sprintf("\\([%v]?([\\w\\d]+)[%v]?(?:,[%v]?([\\w\\d]+)[%v]){0,}\\)", sqliteSeparator, sqliteSeparator, sqliteSeparator, sqliteSeparator)) - columnRegexp = regexp.MustCompile(fmt.Sprintf("^[%v]?([\\w\\d]+)[%v]?\\s+([\\w\\(\\)\\d]+)(.*)$", sqliteSeparator, sqliteSeparator)) - defaultValueRegexp = regexp.MustCompile("(?i) DEFAULT \\(?(.+)?\\)?( |COLLATE|GENERATED|$)") + columnsRegexp = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator)) + columnRegexp = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator)) + defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`) regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`) ) +func getAllColumns(s string) []string { + allMatches := columnsRegexp.FindAllStringSubmatch(s, -1) + columns := make([]string, 0, len(allMatches)) + for _, matches := range allMatches { + if len(matches) > 1 { + columns = append(columns, matches[1]) + } + } + return columns +} + type ddl struct { head string fields []string @@ -98,15 +109,12 @@ func parseDDL(strs ...string) (*ddl, error) { } if strings.HasPrefix(fUpper, "PRIMARY KEY") { - matches := columnsRegexp.FindStringSubmatch(f) - if len(matches) > 1 { - for _, name := range matches[1:] { - for idx, column := range result.columns { - if column.NameValue.String == name { - column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} - result.columns[idx] = column - break - } + for _, name := range getAllColumns(f) { + for idx, column := range result.columns { + if column.NameValue.String == name { + column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} + result.columns[idx] = column + break } } } @@ -151,10 +159,10 @@ func parseDDL(strs ...string) (*ddl, error) { } } } else if matches := indexRegexp.FindStringSubmatch(str); len(matches) > 0 { - if columns := columnsRegexp.FindStringSubmatch(matches[1]); len(columns) == 1 { + for _, column := range getAllColumns(matches[1]) { for idx, c := range result.columns { - if c.NameValue.String == columns[0] { - c.UniqueValue = sql.NullBool{Bool: true, Valid: true} + if c.NameValue.String == column { + c.UniqueValue = sql.NullBool{Bool: strings.ToUpper(strings.Fields(str)[1]) == "UNIQUE", Valid: true} result.columns[idx] = c } } diff --git a/ddlmod_test.go b/ddlmod_test.go index edc1c47..763c3ce 100644 --- a/ddlmod_test.go +++ b/ddlmod_test.go @@ -20,7 +20,7 @@ func TestParseDDL(t *testing.T) { "CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)", }, 6, []migrator.ColumnType{ {NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}}, - {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, {NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, {NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, }, @@ -56,11 +56,47 @@ func TestParseDDL(t *testing.T) { ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, - UniqueValue: sql.NullBool{Valid: true}, + UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}, }, }, }, + { + "unique index", + []string{ + "CREATE TABLE `test-b` (`field` integer NOT NULL)", + "CREATE UNIQUE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0", + }, + 1, + []migrator.ColumnType{ + { + NameValue: sql.NullString{String: "field", Valid: true}, + DataTypeValue: sql.NullString{String: "integer", Valid: true}, + ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, + PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, + UniqueValue: sql.NullBool{Bool: true, Valid: true}, + NullableValue: sql.NullBool{Bool: false, Valid: true}, + }, + }, + }, + { + "non-unique index", + []string{ + "CREATE TABLE `test-c` (`field` integer NOT NULL)", + "CREATE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0", + }, + 1, + []migrator.ColumnType{ + { + NameValue: sql.NullString{String: "field", Valid: true}, + DataTypeValue: sql.NullString{String: "integer", Valid: true}, + ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, + PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, + UniqueValue: sql.NullBool{Bool: false, Valid: true}, + NullableValue: sql.NullBool{Bool: false, Valid: true}, + }, + }, + }, } for _, p := range params { @@ -80,6 +116,75 @@ func TestParseDDL(t *testing.T) { } } +func TestParseDDL_Whitespaces(t *testing.T) { + testColumns := []migrator.ColumnType{ + { + NameValue: sql.NullString{String: "id", Valid: true}, + DataTypeValue: sql.NullString{String: "integer", Valid: true}, + ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, + NullableValue: sql.NullBool{Bool: false, Valid: true}, + DefaultValueValue: sql.NullString{Valid: false}, + UniqueValue: sql.NullBool{Bool: true, Valid: true}, + PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, + }, + { + NameValue: sql.NullString{String: "dark_mode", Valid: true}, + DataTypeValue: sql.NullString{String: "numeric", Valid: true}, + ColumnTypeValue: sql.NullString{String: "numeric", Valid: true}, + NullableValue: sql.NullBool{Valid: true}, + DefaultValueValue: sql.NullString{String: "true", Valid: true}, + UniqueValue: sql.NullBool{Bool: false, Valid: true}, + PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, + }, + } + + params := []struct { + name string + sql []string + nFields int + columns []migrator.ColumnType + }{ + { + "with_newline", + []string{"CREATE TABLE `users`\n(\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"}, + 2, + testColumns, + }, + { + "with_newline_2", + []string{"CREATE TABLE `users` (\n\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"}, + 2, + testColumns, + }, + { + "with_missing_space", + []string{"CREATE TABLE `users`(id integer primary key unique, dark_mode numeric DEFAULT true)"}, + 2, + testColumns, + }, + { + "with_many_spaces", + []string{"CREATE TABLE `users` (id integer primary key unique, dark_mode numeric DEFAULT true)"}, + 2, + testColumns, + }, + } + for _, p := range params { + t.Run(p.name, func(t *testing.T) { + ddl, err := parseDDL(p.sql...) + + if err != nil { + panic(err.Error()) + } + + if len(ddl.fields) != p.nFields { + t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields)) + } + tests.AssertEqual(t, ddl.columns, p.columns) + }) + } +} + func TestParseDDL_error(t *testing.T) { params := []struct { name string diff --git a/error_translator.go b/error_translator.go new file mode 100644 index 0000000..f674cb1 --- /dev/null +++ b/error_translator.go @@ -0,0 +1,40 @@ +package sqlite + +import ( + "encoding/json" + + "gorm.io/gorm" +) + +// The error codes to map sqlite errors to gorm errors, here is a reference about error codes for sqlite https://www.sqlite.org/rescode.html. +var errCodes = map[int]error{ + 1555: gorm.ErrDuplicatedKey, + 2067: gorm.ErrDuplicatedKey, + 787: gorm.ErrForeignKeyViolated, +} + +type ErrMessage struct { + Code int `json:"Code"` + ExtendedCode int `json:"ExtendedCode"` + SystemErrno int `json:"SystemErrno"` +} + +// Translate it will translate the error to native gorm errors. +// We are not using go-sqlite3 error type intentionally here because it will need the CGO_ENABLED=1 and cross-C-compiler. +func (dialector Dialector) Translate(err error) error { + parsedErr, marshalErr := json.Marshal(err) + if marshalErr != nil { + return err + } + + var errMsg ErrMessage + unmarshalErr := json.Unmarshal(parsedErr, &errMsg) + if unmarshalErr != nil { + return err + } + + if translatedErr, found := errCodes[errMsg.ExtendedCode]; found { + return translatedErr + } + return err +} diff --git a/migrator.go b/migrator.go index 9d11cfc..fd2eeb4 100644 --- a/migrator.go +++ b/migrator.go @@ -271,28 +271,29 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - opts := m.BuildIndexOptions(idx.Fields, stmt) - values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + opts := m.BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} - createIndexSQL := "CREATE " - if idx.Class != "" { - createIndexSQL += idx.Class + " " - } - createIndexSQL += "INDEX ?" + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ?" - if idx.Type != "" { - createIndexSQL += " USING " + idx.Type - } - createIndexSQL += " ON ??" + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + createIndexSQL += " ON ??" - if idx.Where != "" { - createIndexSQL += " WHERE " + idx.Where - } + if idx.Where != "" { + createIndexSQL += " WHERE " + idx.Where + } - return m.DB.Exec(createIndexSQL, values...).Error + return m.DB.Exec(createIndexSQL, values...).Error + } } - return fmt.Errorf("failed to create index with name %v", name) }) } @@ -300,8 +301,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { func (m Migrator) HasIndex(value interface{}, name string) bool { var count int m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } } if name != "" { @@ -319,6 +322,9 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error var sql string m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql) if sql != "" { + if err := m.DropIndex(value, oldName); err != nil { + return err + } return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error } return fmt.Errorf("failed to find index with name %v", oldName) @@ -327,8 +333,10 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } } return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error @@ -390,7 +398,7 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string, return nil } - tableReg, err := regexp.Compile(" ('|`|\"| )" + table + "('|`|\"| ) ") + tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*") if err != nil { return err } diff --git a/sqlite.go b/sqlite.go index f4207fa..71eddf2 100644 --- a/sqlite.go +++ b/sqlite.go @@ -180,7 +180,12 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { case schema.String: return "text" case schema.Time: - return "datetime" + // Distinguish between schema.Time and tag time + if val, ok := field.TagSettings["TYPE"]; ok { + return val + } else { + return "datetime" + } case schema.Bytes: return "blob" }