Skip to content

Commit

Permalink
Merge branch 'orig'
Browse files Browse the repository at this point in the history
  • Loading branch information
glebarez committed Jul 9, 2023
2 parents 963e2cc + 397ec6f commit 205217f
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 43 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -1 +1 @@
vendor/
.idea/
42 changes: 25 additions & 17 deletions ddlmod.go
Expand Up @@ -13,15 +13,26 @@ import (

var (
sqliteSeparator = "`|\"|'|\t"
indexRegexp = regexp.MustCompile(fmt.Sprintf("(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\\w\\d-]+[%v]? ON (.*)$", sqliteSeparator, sqliteSeparator))
tableRegexp = regexp.MustCompile(fmt.Sprintf("(?is)(CREATE TABLE [%v]?[\\w\\d-]+[%v]?)(?: \\((.*)\\))?", sqliteSeparator, sqliteSeparator))
indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]? ON (.*)$`, sqliteSeparator, sqliteSeparator))
tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator))
separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
columnsRegexp = regexp.MustCompile(fmt.Sprintf("\\([%v]?([\\w\\d]+)[%v]?(?:,[%v]?([\\w\\d]+)[%v]){0,}\\)", sqliteSeparator, sqliteSeparator, sqliteSeparator, sqliteSeparator))
columnRegexp = regexp.MustCompile(fmt.Sprintf("^[%v]?([\\w\\d]+)[%v]?\\s+([\\w\\(\\)\\d]+)(.*)$", sqliteSeparator, sqliteSeparator))
defaultValueRegexp = regexp.MustCompile("(?i) DEFAULT \\(?(.+)?\\)?( |COLLATE|GENERATED|$)")
columnsRegexp = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator))
columnRegexp = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator))
defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`)
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
)

func getAllColumns(s string) []string {
allMatches := columnsRegexp.FindAllStringSubmatch(s, -1)
columns := make([]string, 0, len(allMatches))
for _, matches := range allMatches {
if len(matches) > 1 {
columns = append(columns, matches[1])
}
}
return columns
}

type ddl struct {
head string
fields []string
Expand Down Expand Up @@ -98,15 +109,12 @@ func parseDDL(strs ...string) (*ddl, error) {
}

if strings.HasPrefix(fUpper, "PRIMARY KEY") {
matches := columnsRegexp.FindStringSubmatch(f)
if len(matches) > 1 {
for _, name := range matches[1:] {
for idx, column := range result.columns {
if column.NameValue.String == name {
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
result.columns[idx] = column
break
}
for _, name := range getAllColumns(f) {
for idx, column := range result.columns {
if column.NameValue.String == name {
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
result.columns[idx] = column
break
}
}
}
Expand Down Expand Up @@ -151,10 +159,10 @@ func parseDDL(strs ...string) (*ddl, error) {
}
}
} else if matches := indexRegexp.FindStringSubmatch(str); len(matches) > 0 {
if columns := columnsRegexp.FindStringSubmatch(matches[1]); len(columns) == 1 {
for _, column := range getAllColumns(matches[1]) {
for idx, c := range result.columns {
if c.NameValue.String == columns[0] {
c.UniqueValue = sql.NullBool{Bool: true, Valid: true}
if c.NameValue.String == column {
c.UniqueValue = sql.NullBool{Bool: strings.ToUpper(strings.Fields(str)[1]) == "UNIQUE", Valid: true}
result.columns[idx] = c
}
}
Expand Down
109 changes: 107 additions & 2 deletions ddlmod_test.go
Expand Up @@ -20,7 +20,7 @@ func TestParseDDL(t *testing.T) {
"CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)",
}, 6, []migrator.ColumnType{
{NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
Expand Down Expand Up @@ -56,11 +56,47 @@ func TestParseDDL(t *testing.T) {
ColumnTypeValue: sql.NullString{String: "int", Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
UniqueValue: sql.NullBool{Valid: true},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
PrimaryKeyValue: sql.NullBool{Valid: true},
},
},
},
{
"unique index",
[]string{
"CREATE TABLE `test-b` (`field` integer NOT NULL)",
"CREATE UNIQUE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
},
1,
[]migrator.ColumnType{
{
NameValue: sql.NullString{String: "field", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
},
},
},
{
"non-unique index",
[]string{
"CREATE TABLE `test-c` (`field` integer NOT NULL)",
"CREATE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
},
1,
[]migrator.ColumnType{
{
NameValue: sql.NullString{String: "field", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
},
},
},
}

for _, p := range params {
Expand All @@ -80,6 +116,75 @@ func TestParseDDL(t *testing.T) {
}
}

func TestParseDDL_Whitespaces(t *testing.T) {
testColumns := []migrator.ColumnType{
{
NameValue: sql.NullString{String: "id", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
},
{
NameValue: sql.NullString{String: "dark_mode", Valid: true},
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
NullableValue: sql.NullBool{Valid: true},
DefaultValueValue: sql.NullString{String: "true", Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
},
}

params := []struct {
name string
sql []string
nFields int
columns []migrator.ColumnType
}{
{
"with_newline",
[]string{"CREATE TABLE `users`\n(\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
2,
testColumns,
},
{
"with_newline_2",
[]string{"CREATE TABLE `users` (\n\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
2,
testColumns,
},
{
"with_missing_space",
[]string{"CREATE TABLE `users`(id integer primary key unique, dark_mode numeric DEFAULT true)"},
2,
testColumns,
},
{
"with_many_spaces",
[]string{"CREATE TABLE `users` (id integer primary key unique, dark_mode numeric DEFAULT true)"},
2,
testColumns,
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
ddl, err := parseDDL(p.sql...)

if err != nil {
panic(err.Error())
}

if len(ddl.fields) != p.nFields {
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
}
tests.AssertEqual(t, ddl.columns, p.columns)
})
}
}

func TestParseDDL_error(t *testing.T) {
params := []struct {
name string
Expand Down
40 changes: 40 additions & 0 deletions error_translator.go
@@ -0,0 +1,40 @@
package sqlite

import (
"encoding/json"

"gorm.io/gorm"
)

// The error codes to map sqlite errors to gorm errors, here is a reference about error codes for sqlite https://www.sqlite.org/rescode.html.
var errCodes = map[int]error{
1555: gorm.ErrDuplicatedKey,
2067: gorm.ErrDuplicatedKey,
787: gorm.ErrForeignKeyViolated,
}

type ErrMessage struct {
Code int `json:"Code"`
ExtendedCode int `json:"ExtendedCode"`
SystemErrno int `json:"SystemErrno"`
}

// Translate it will translate the error to native gorm errors.
// We are not using go-sqlite3 error type intentionally here because it will need the CGO_ENABLED=1 and cross-C-compiler.
func (dialector Dialector) Translate(err error) error {
parsedErr, marshalErr := json.Marshal(err)
if marshalErr != nil {
return err
}

var errMsg ErrMessage
unmarshalErr := json.Unmarshal(parsedErr, &errMsg)
if unmarshalErr != nil {
return err
}

if translatedErr, found := errCodes[errMsg.ExtendedCode]; found {
return translatedErr
}
return err
}
52 changes: 30 additions & 22 deletions migrator.go
Expand Up @@ -271,37 +271,40 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem

func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}

createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ?"
createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ?"

if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}
createIndexSQL += " ON ??"
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}
createIndexSQL += " ON ??"

if idx.Where != "" {
createIndexSQL += " WHERE " + idx.Where
}
if idx.Where != "" {
createIndexSQL += " WHERE " + idx.Where
}

return m.DB.Exec(createIndexSQL, values...).Error
return m.DB.Exec(createIndexSQL, values...).Error
}
}

return fmt.Errorf("failed to create index with name %v", name)
})
}

func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
}

if name != "" {
Expand All @@ -319,6 +322,9 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error
var sql string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
if sql != "" {
if err := m.DropIndex(value, oldName); err != nil {
return err
}
return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
}
return fmt.Errorf("failed to find index with name %v", oldName)
Expand All @@ -327,8 +333,10 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error

func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
}

return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
Expand Down Expand Up @@ -390,7 +398,7 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string,
return nil
}

tableReg, err := regexp.Compile(" ('|`|\"| )" + table + "('|`|\"| ) ")
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
if err != nil {
return err
}
Expand Down
7 changes: 6 additions & 1 deletion sqlite.go
Expand Up @@ -180,7 +180,12 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
case schema.String:
return "text"
case schema.Time:
return "datetime"
// Distinguish between schema.Time and tag time
if val, ok := field.TagSettings["TYPE"]; ok {
return val
} else {
return "datetime"
}
case schema.Bytes:
return "blob"
}
Expand Down

0 comments on commit 205217f

Please sign in to comment.