From b35958361a290fdd985611cfd549ab555af54b4f Mon Sep 17 00:00:00 2001 From: mackee Date: Fri, 19 Nov 2021 19:01:27 +0900 Subject: [PATCH 1/5] add feature: support bulk insert --- _example/user.gen.go | 86 +++++++++++++++++++++++++++-- _example/user_hook.go | 3 - _example/user_item.gen.go | 86 +++++++++++++++++++++++++++-- _example/user_test.go | 20 +++---- _example/user_withmysql_test.go | 8 ++- column.go | 80 ++++++++++++++++++++++++++- generator_go116.go | 2 +- template/insert.tmpl | 98 +++++++++++++++++++++++++++++++-- template/table.tmpl | 1 + 9 files changed, 354 insertions(+), 30 deletions(-) diff --git a/_example/user.gen.go b/_example/user.gen.go index 0ce425d..c76405e 100644 --- a/_example/user.gen.go +++ b/_example/user.gen.go @@ -3,6 +3,7 @@ package example import ( "context" + "fmt" "strconv" "strings" @@ -681,14 +682,14 @@ func (q userInsertSQL) ValueUpdatedAt(v mysql.NullTime) userInsertSQL { } func (q userInsertSQL) ToSql() (string, []interface{}, error) { - query, vs, err := q.toSql() + query, vs, err := q.userInsertSQLToSql() if err != nil { return "", []interface{}{}, err } return query + ";", vs, nil } -func (q userInsertSQL) toSql() (string, []interface{}, error) { +func (q userInsertSQL) userInsertSQLToSql() (string, []interface{}, error) { var err error var s interface{} = User{} if t, ok := s.(userDefaultInsertHooker); ok { @@ -750,8 +751,12 @@ type userDefaultInsertHooker interface { DefaultInsertHook(userInsertSQL) (userInsertSQL, error) } +type userInsertSQLToSqler interface { + userInsertSQLToSql() (string, []interface{}, error) +} + type userInsertOnDuplicateKeyUpdateSQL struct { - insertSQL userInsertSQL + insertSQL userInsertSQLToSqler onDuplicateKeyUpdateMap sqlla.SetMap } @@ -870,7 +875,7 @@ func (q userInsertOnDuplicateKeyUpdateSQL) ToSql() (string, []interface{}, error } } - query, vs, err := q.insertSQL.toSql() + query, vs, err := q.insertSQL.userInsertSQLToSql() if err != nil { return "", []interface{}{}, err } @@ -905,6 +910,79 @@ type userDefaultInsertOnDuplicateKeyUpdateHooker interface { DefaultInsertOnDuplicateKeyUpdateHook(userInsertOnDuplicateKeyUpdateSQL) (userInsertOnDuplicateKeyUpdateSQL, error) } +type userBulkInsertSQL struct { + insertSQLs []userInsertSQL +} + +func NewUserBulkInsertSQL(insertSQLs ...userInsertSQL) userBulkInsertSQL { + return userBulkInsertSQL{ + insertSQLs: insertSQLs, + } +} + +func (q userBulkInsertSQL) userInsertSQLToSql() (string, []interface{}, error) { + if len(q.insertSQLs) == 0 { + return "", []interface{}{}, fmt.Errorf("sqlla: This userBulkInsertSQL does not have InsertSQL") + } + iqs := make([]userInsertSQL, len(q.insertSQLs)) + copy(iqs, q.insertSQLs) + + var s interface{} = User{} + if t, ok := s.(userDefaultInsertHooker); ok { + for i, iq := range iqs { + var err error + iq, err = t.DefaultInsertHook(iq) + if err != nil { + return "", []interface{}{}, err + } + iqs[i] = iq + } + } + + sms := make(sqlla.SetMaps, 0, len(q.insertSQLs)) + for _, iq := range q.insertSQLs { + sms = append(sms, iq.setMap) + } + + query, vs, err := sms.ToInsertSql() + if err != nil { + return "", []interface{}{}, err + } + + return "INSERT INTO " + query, vs, nil +} + +func (q userBulkInsertSQL) ToSql() (string, []interface{}, error) { + query, vs, err := q.userInsertSQLToSql() + if err != nil { + return "", []interface{}{}, err + } + return query + ";", vs, nil +} + +func (q userBulkInsertSQL) OnDuplicateKeyUpdate() userInsertOnDuplicateKeyUpdateSQL { + return userInsertOnDuplicateKeyUpdateSQL{ + insertSQL: q, + onDuplicateKeyUpdateMap: sqlla.SetMap{}, + } +} + +func (q userBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (User, error) { + query, args, err := q.ToSql() + if err != nil { + return User{}, err + } + result, err := db.ExecContext(ctx, query, args...) + if err != nil { + return User{}, err + } + id, err := result.LastInsertId() + if err != nil { + return User{}, err + } + return NewUserSQL().Select().PkColumn(id).SingleContext(ctx, db) +} + type userDeleteSQL struct { userSQL } diff --git a/_example/user_hook.go b/_example/user_hook.go index 48d625c..36a0108 100644 --- a/_example/user_hook.go +++ b/_example/user_hook.go @@ -12,9 +12,6 @@ func (u User) DefaultInsertHook(q userInsertSQL) (userInsertSQL, error) { } func (u User) DefaultInsertOnDuplicateKeyUpdateHook(q userInsertOnDuplicateKeyUpdateSQL) (userInsertOnDuplicateKeyUpdateSQL, error) { - now := time.Now() - q.insertSQL = q.insertSQL.ValueUpdatedAt(mysql.NullTime{Time: now, Valid: true}) - return q.SameOnUpdateUpdatedAt(), nil } diff --git a/_example/user_item.gen.go b/_example/user_item.gen.go index 2ab147c..18f5a33 100644 --- a/_example/user_item.gen.go +++ b/_example/user_item.gen.go @@ -3,6 +3,7 @@ package example import ( "context" + "fmt" "strconv" "strings" @@ -614,14 +615,14 @@ func (q userItemInsertSQL) ValueUsedAt(v mysql.NullTime) userItemInsertSQL { } func (q userItemInsertSQL) ToSql() (string, []interface{}, error) { - query, vs, err := q.toSql() + query, vs, err := q.userItemInsertSQLToSql() if err != nil { return "", []interface{}{}, err } return query + ";", vs, nil } -func (q userItemInsertSQL) toSql() (string, []interface{}, error) { +func (q userItemInsertSQL) userItemInsertSQLToSql() (string, []interface{}, error) { var err error var s interface{} = UserItem{} if t, ok := s.(userItemDefaultInsertHooker); ok { @@ -683,8 +684,12 @@ type userItemDefaultInsertHooker interface { DefaultInsertHook(userItemInsertSQL) (userItemInsertSQL, error) } +type userItemInsertSQLToSqler interface { + userItemInsertSQLToSql() (string, []interface{}, error) +} + type userItemInsertOnDuplicateKeyUpdateSQL struct { - insertSQL userItemInsertSQL + insertSQL userItemInsertSQLToSqler onDuplicateKeyUpdateMap sqlla.SetMap } @@ -788,7 +793,7 @@ func (q userItemInsertOnDuplicateKeyUpdateSQL) ToSql() (string, []interface{}, e } } - query, vs, err := q.insertSQL.toSql() + query, vs, err := q.insertSQL.userItemInsertSQLToSql() if err != nil { return "", []interface{}{}, err } @@ -823,6 +828,79 @@ type userItemDefaultInsertOnDuplicateKeyUpdateHooker interface { DefaultInsertOnDuplicateKeyUpdateHook(userItemInsertOnDuplicateKeyUpdateSQL) (userItemInsertOnDuplicateKeyUpdateSQL, error) } +type userItemBulkInsertSQL struct { + insertSQLs []userItemInsertSQL +} + +func NewUserItemBulkInsertSQL(insertSQLs ...userItemInsertSQL) userItemBulkInsertSQL { + return userItemBulkInsertSQL{ + insertSQLs: insertSQLs, + } +} + +func (q userItemBulkInsertSQL) userItemInsertSQLToSql() (string, []interface{}, error) { + if len(q.insertSQLs) == 0 { + return "", []interface{}{}, fmt.Errorf("sqlla: This userItemBulkInsertSQL does not have InsertSQL") + } + iqs := make([]userItemInsertSQL, len(q.insertSQLs)) + copy(iqs, q.insertSQLs) + + var s interface{} = UserItem{} + if t, ok := s.(userItemDefaultInsertHooker); ok { + for i, iq := range iqs { + var err error + iq, err = t.DefaultInsertHook(iq) + if err != nil { + return "", []interface{}{}, err + } + iqs[i] = iq + } + } + + sms := make(sqlla.SetMaps, 0, len(q.insertSQLs)) + for _, iq := range q.insertSQLs { + sms = append(sms, iq.setMap) + } + + query, vs, err := sms.ToInsertSql() + if err != nil { + return "", []interface{}{}, err + } + + return "INSERT INTO " + query, vs, nil +} + +func (q userItemBulkInsertSQL) ToSql() (string, []interface{}, error) { + query, vs, err := q.userItemInsertSQLToSql() + if err != nil { + return "", []interface{}{}, err + } + return query + ";", vs, nil +} + +func (q userItemBulkInsertSQL) OnDuplicateKeyUpdate() userItemInsertOnDuplicateKeyUpdateSQL { + return userItemInsertOnDuplicateKeyUpdateSQL{ + insertSQL: q, + onDuplicateKeyUpdateMap: sqlla.SetMap{}, + } +} + +func (q userItemBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (UserItem, error) { + query, args, err := q.ToSql() + if err != nil { + return UserItem{}, err + } + result, err := db.ExecContext(ctx, query, args...) + if err != nil { + return UserItem{}, err + } + id, err := result.LastInsertId() + if err != nil { + return UserItem{}, err + } + return NewUserItemSQL().Select().PkColumn(id).SingleContext(ctx, db) +} + type userItemDeleteSQL struct { userItemSQL } diff --git a/_example/user_test.go b/_example/user_test.go index 577fd0c..c46971b 100644 --- a/_example/user_test.go +++ b/_example/user_test.go @@ -190,24 +190,24 @@ func TestInsert(t *testing.T) { if err != nil { t.Error("unexpected error:", err) } - switch query { - case "INSERT INTO user (`name`,`created_at`) VALUES(?,?);": - if !reflect.DeepEqual(args[0], "hogehoge") { - t.Error("unexpected args:", args) - } - case "INSERT INTO user (`created_at`,`name`) VALUES(?,?);": - if !reflect.DeepEqual(args[1], "hogehoge") { - t.Error("unexpected args:", args) - } - default: + expected := "INSERT INTO user (`created_at`,`name`) VALUES(?,?);" + if query != expected { t.Error("unexpected query:", query) } + if !reflect.DeepEqual(args[1], "hogehoge") { + t.Error("unexpected args:", args) + } } func TestInsertOnDuplicateKeyUpdate(t *testing.T) { + now := time.Now() q := NewUserSQL().Insert(). ValueID(1). ValueName("hogehoge"). + ValueUpdatedAt(mysql.NullTime{ + Valid: true, + Time: now, + }). OnDuplicateKeyUpdate(). ValueOnUpdateAge(sql.NullInt64{ Valid: true, diff --git a/_example/user_withmysql_test.go b/_example/user_withmysql_test.go index 4896412..6adff69 100644 --- a/_example/user_withmysql_test.go +++ b/_example/user_withmysql_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql" "github.com/mackee/go-sqlla/v2" "github.com/ory/dockertest/v3" @@ -84,12 +85,14 @@ func TestMain(m *testing.M) { func TestInsertOnDuplicateKeyUpdate__WithMySQL(t *testing.T) { ctx := context.Background() + now1 := time.Now() q1 := NewUserSQL().Insert(). ValueName("hogehoge"). ValueRate(3.14). ValueIconImage([]byte{}). - ValueAge(sql.NullInt64{Valid: true, Int64: 17}) + ValueAge(sql.NullInt64{Valid: true, Int64: 17}). + ValueUpdatedAt(mysql.NullTime{Valid: true, Time: now1}) query, args, _ := q1.ToSql() t.Logf("query=%s, args=%+v", query, args) r1, err := q1.ExecContext(ctx, db) @@ -97,12 +100,13 @@ func TestInsertOnDuplicateKeyUpdate__WithMySQL(t *testing.T) { t.Fatal("unexpected error:", err) } - time.Sleep(1 * time.Second) + now2 := now1.Add(1 * time.Second) q2 := NewUserSQL().Insert(). ValueName("hogehoge"). ValueAge(sql.NullInt64{Valid: true, Int64: 17}). ValueIconImage([]byte{}). + ValueUpdatedAt(mysql.NullTime{Valid: true, Time: now2}). OnDuplicateKeyUpdate(). RawValueOnUpdateAge(sqlla.SetMapRawValue("`age` + 1")) r2, err := q2.ExecContext(ctx, db) diff --git a/column.go b/column.go index df87197..4eabc62 100644 --- a/column.go +++ b/column.go @@ -1,5 +1,11 @@ package sqlla +import ( + "fmt" + "sort" + "strings" +) + type Where []Expr func (wh Where) ToSql() (string, []interface{}, error) { @@ -29,6 +35,41 @@ type SetMapRawValue string type SetMap map[string]interface{} +func (sm SetMap) NewIterator() *SetMapIterator { + keys := make(sort.StringSlice, 0, len(sm)) + for k := range sm { + keys = append(keys, k) + } + sort.Sort(keys) + return &SetMapIterator{ + sm: sm, + keys: keys, + cursor: -1, + } +} + +type SetMapIterator struct { + sm SetMap + cursor int + keys []string +} + +func (s *SetMapIterator) Iterate() bool { + s.cursor++ + if len(s.keys)-1 < s.cursor { + return false + } + return true +} + +func (s *SetMapIterator) Key() string { + return s.keys[s.cursor] +} + +func (s *SetMapIterator) Value() interface{} { + return s.sm[s.keys[s.cursor]] +} + func (sm SetMap) ToUpdateSql() (string, []interface{}, error) { var setColumns string vs := []interface{}{} @@ -49,11 +90,13 @@ func (sm SetMap) ToUpdateSql() (string, []interface{}, error) { return setColumns, vs, nil } -func (sm SetMap) ToInsertSql() (string, []interface{}, error) { +func (sm SetMap) ToInsertColumnsAndValues() (string, string, []interface{}) { qs, ps := "(", "(" vs := []interface{}{} columnCount := 0 - for k, v := range sm { + iter := sm.NewIterator() + for iter.Iterate() { + k, v := iter.Key(), iter.Value() if columnCount != 0 { qs += "," ps += "," @@ -65,6 +108,39 @@ func (sm SetMap) ToInsertSql() (string, []interface{}, error) { } qs += ")" ps += ")" + return qs, ps, vs +} +func (sm SetMap) ToInsertSql() (string, []interface{}, error) { + qs, ps, vs := sm.ToInsertColumnsAndValues() return qs + " VALUES" + ps, vs, nil } + +type SetMaps []SetMap + +func (s SetMaps) ToInsertSql() (string, []interface{}, error) { + if len(s) == 0 { + return "", nil, fmt.Errorf("sqlla: SetMaps is empty") + } + + first := s[0] + columns, values, vs := first.ToInsertColumnsAndValues() + var b strings.Builder + if _, err := b.WriteString(values); err != nil { + return "", nil, err + } + for i, _s := range s[1:] { + _columns, _values, _vs := _s.ToInsertColumnsAndValues() + if columns != _columns { + return "", nil, fmt.Errorf("sqlla: two SetMap are not match keys: [0]=%s, [%d]=%s", columns, i, _columns) + } + vs = append(vs, _vs...) + if _, err := b.WriteString(","); err != nil { + return "", nil, err + } + if _, err := b.WriteString(_values); err != nil { + return "", nil, err + } + } + return columns + " VALUES" + b.String(), vs, nil +} diff --git a/generator_go116.go b/generator_go116.go index e8d31d6..aab93b7 100644 --- a/generator_go116.go +++ b/generator_go116.go @@ -67,7 +67,7 @@ func WriteCode(w io.Writer, table *Table) error { } bs, err := format.Source(buf.Bytes()) if err != nil { - return errors.Wrapf(err, "fail to format") + return errors.Wrapf(err, "fail to format: table=%s", table.Name) } _, err = w.Write(bs) return err diff --git a/template/insert.tmpl b/template/insert.tmpl index 6ffa87e..5f5701d 100644 --- a/template/insert.tmpl +++ b/template/insert.tmpl @@ -16,14 +16,14 @@ func (q {{ $camelName }}SQL) Insert() {{ $camelName }}InsertSQL { {{ range .Columns }}{{ template "InsertColumn" . }}{{ end }} func (q {{ $camelName }}InsertSQL) ToSql() (string, []interface{}, error) { - query, vs, err := q.toSql() + query, vs, err := q.{{ $camelName }}InsertSQLToSql() if err != nil { return "", []interface{}{}, err } return query + ";", vs, nil } -func (q {{ $camelName }}InsertSQL) toSql() (string, []interface{}, error) { +func (q {{ $camelName }}InsertSQL) {{ $camelName }}InsertSQLToSql() (string, []interface{}, error) { var err error var s interface{} = {{ .StructName }}{} if t, ok := s.({{ $camelName }}DefaultInsertHooker); ok { @@ -110,8 +110,12 @@ type {{ $camelName }}DefaultInsertHooker interface { DefaultInsertHook({{ $camelName }}InsertSQL) ({{ $camelName }}InsertSQL, error) } +type {{ $camelName }}InsertSQLToSqler interface { + {{ $camelName }}InsertSQLToSql() (string, []interface{}, error) +} + type {{ $camelName }}InsertOnDuplicateKeyUpdateSQL struct { - insertSQL {{ $camelName }}InsertSQL + insertSQL {{ $camelName }}InsertSQLToSqler onDuplicateKeyUpdateMap sqlla.SetMap } @@ -127,7 +131,7 @@ func (q {{ $camelName }}InsertOnDuplicateKeyUpdateSQL) ToSql() (string, []interf } } - query, vs, err := q.insertSQL.toSql() + query, vs, err := q.insertSQL.{{ $camelName }}InsertSQLToSql() if err != nil { return "", []interface{}{}, err } @@ -173,4 +177,90 @@ func (q {{ $camelName }}InsertOnDuplicateKeyUpdateSQL) ExecContext(ctx context.C type {{ $camelName }}DefaultInsertOnDuplicateKeyUpdateHooker interface { DefaultInsertOnDuplicateKeyUpdateHook({{ $camelName }}InsertOnDuplicateKeyUpdateSQL) ({{ $camelName }}InsertOnDuplicateKeyUpdateSQL, error) } + +type {{ $camelName }}BulkInsertSQL struct { + insertSQLs []{{ $camelName }}InsertSQL +} + +func New{{ (.Name | toCamel) }}BulkInsertSQL(insertSQLs ...{{ $camelName }}InsertSQL) {{ $camelName }}BulkInsertSQL { + return {{ $camelName }}BulkInsertSQL{ + insertSQLs: insertSQLs, + } +} + +func (q {{ $camelName }}BulkInsertSQL) {{ $camelName }}InsertSQLToSql() (string, []interface{}, error) { + if len(q.insertSQLs) == 0 { + return "", []interface{}{}, fmt.Errorf("sqlla: This {{ $camelName }}BulkInsertSQL does not have InsertSQL") + } + iqs := make([]{{ $camelName }}InsertSQL, len(q.insertSQLs)) + copy(iqs, q.insertSQLs) + + var s interface{} = {{ .StructName }}{} + if t, ok := s.({{ $camelName }}DefaultInsertHooker); ok { + for i, iq := range iqs { + var err error + iq, err = t.DefaultInsertHook(iq) + if err != nil { + return "", []interface{}{}, err + } + iqs[i] = iq + } + } + + sms := make(sqlla.SetMaps, 0, len(q.insertSQLs)) + for _, iq := range q.insertSQLs { + sms = append(sms, iq.setMap) + } + + query, vs, err := sms.ToInsertSql() + if err != nil { + return "", []interface{}{}, err + } + + return "INSERT INTO " + query, vs, nil +} + +func (q {{ $camelName }}BulkInsertSQL) ToSql() (string, []interface{}, error) { + query, vs, err := q.{{ $camelName }}InsertSQLToSql() + if err != nil { + return "", []interface{}{}, err + } + return query + ";", vs, nil +} + +func (q {{ $camelName }}BulkInsertSQL) OnDuplicateKeyUpdate() {{ $camelName }}InsertOnDuplicateKeyUpdateSQL { + return {{ $camelName }}InsertOnDuplicateKeyUpdateSQL{ + insertSQL: q, + onDuplicateKeyUpdateMap: sqlla.SetMap{}, + } +} + +{{ if .HasPk -}} +func (q {{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ({{ .StructName }}, error) { +{{- else -}} +func (q {{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { +{{- end }} + query, args, err := q.ToSql() + if err != nil { + {{ if .HasPk -}} + return {{ .StructName }}{}, err + {{- else }} + return nil, err + {{- end }} + } + result, err := db.ExecContext(ctx, query, args...) + {{ if .HasPk -}} + if err != nil { + return {{ .StructName }}{}, err + } + id, err := result.LastInsertId() + if err != nil { + return {{ .StructName }}{}, err + } + return {{ $constructor }}().Select().PkColumn(id).SingleContext(ctx, db) + {{- else -}} + return result, err + {{- end }} +} + {{ end }} diff --git a/template/table.tmpl b/template/table.tmpl index abffbc1..2626549 100644 --- a/template/table.tmpl +++ b/template/table.tmpl @@ -5,6 +5,7 @@ import ( "strings" "strconv" "context" + "fmt" "database/sql" {{ range .AdditionalPackages -}} From 0af7ded9c9f76633076860c7c089077d6160f82d Mon Sep 17 00:00:00 2001 From: mackee Date: Fri, 19 Nov 2021 19:26:20 +0900 Subject: [PATCH 2/5] add tests of bulk insert --- _example/user.gen.go | 2 +- _example/user_item.gen.go | 2 +- _example/user_test.go | 22 ++++++++++++++++++++++ _example/user_withmysql_test.go | 1 - column.go | 2 +- template/insert.tmpl | 2 +- 6 files changed, 26 insertions(+), 5 deletions(-) diff --git a/_example/user.gen.go b/_example/user.gen.go index c76405e..967edef 100644 --- a/_example/user.gen.go +++ b/_example/user.gen.go @@ -949,7 +949,7 @@ func (q userBulkInsertSQL) userInsertSQLToSql() (string, []interface{}, error) { return "", []interface{}{}, err } - return "INSERT INTO " + query, vs, nil + return "INSERT INTO `user` " + query, vs, nil } func (q userBulkInsertSQL) ToSql() (string, []interface{}, error) { diff --git a/_example/user_item.gen.go b/_example/user_item.gen.go index 18f5a33..f31ad5f 100644 --- a/_example/user_item.gen.go +++ b/_example/user_item.gen.go @@ -867,7 +867,7 @@ func (q userItemBulkInsertSQL) userItemInsertSQLToSql() (string, []interface{}, return "", []interface{}{}, err } - return "INSERT INTO " + query, vs, nil + return "INSERT INTO `user_item` " + query, vs, nil } func (q userItemBulkInsertSQL) ToSql() (string, []interface{}, error) { diff --git a/_example/user_test.go b/_example/user_test.go index c46971b..a9da83e 100644 --- a/_example/user_test.go +++ b/_example/user_test.go @@ -6,6 +6,7 @@ import ( "os" "reflect" "regexp" + "strconv" "strings" "testing" "time" @@ -229,6 +230,27 @@ func TestInsertOnDuplicateKeyUpdate(t *testing.T) { } } +func TestBulkInsert(t *testing.T) { + items := make([]userItemInsertSQL, 0, 10) + for i := 1; i <= 10; i++ { + q := NewUserItemSQL().Insert(). + ValueUserID(42). + ValueItemID(strconv.Itoa(i)) + items = append(items, q) + } + query, vs, err := NewUserItemBulkInsertSQL(items...).ToSql() + if err != nil { + t.Error("unexpected error:", err) + } + expected := "INSERT INTO `user_item` (`item_id`,`user_id`) VALUES (?,?),(?,?),(?,?),(?,?),(?,?),(?,?),(?,?),(?,?),(?,?),(?,?);" + if query != expected { + t.Error("query is not match:", query) + } + if !reflect.DeepEqual(vs, []interface{}{"1", uint64(42), "2", uint64(42), "3", uint64(42), "4", uint64(42), "5", uint64(42), "6", uint64(42), "7", uint64(42), "8", uint64(42), "9", uint64(42), "10", uint64(42)}) { + t.Errorf("vs is not valid: %+v", vs) + } +} + func TestDelete(t *testing.T) { q := NewUserSQL().Delete().Name("hogehoge") query, args, err := q.ToSql() diff --git a/_example/user_withmysql_test.go b/_example/user_withmysql_test.go index 6adff69..18e71c8 100644 --- a/_example/user_withmysql_test.go +++ b/_example/user_withmysql_test.go @@ -123,5 +123,4 @@ func TestInsertOnDuplicateKeyUpdate__WithMySQL(t *testing.T) { if r2.UpdatedAt.Time.Unix() <= r1.UpdatedAt.Time.Unix() { t.Fatal("updated_at does not updated:", r1.UpdatedAt.Time.Unix(), r2.UpdatedAt.Time.Unix()) } - } diff --git a/column.go b/column.go index 4eabc62..f2b0c0a 100644 --- a/column.go +++ b/column.go @@ -142,5 +142,5 @@ func (s SetMaps) ToInsertSql() (string, []interface{}, error) { return "", nil, err } } - return columns + " VALUES" + b.String(), vs, nil + return columns + " VALUES " + b.String(), vs, nil } diff --git a/template/insert.tmpl b/template/insert.tmpl index 5f5701d..678f2c9 100644 --- a/template/insert.tmpl +++ b/template/insert.tmpl @@ -217,7 +217,7 @@ func (q {{ $camelName }}BulkInsertSQL) {{ $camelName }}InsertSQLToSql() (string, return "", []interface{}{}, err } - return "INSERT INTO " + query, vs, nil + return "INSERT INTO `{{ .Name }}` " + query, vs, nil } func (q {{ $camelName }}BulkInsertSQL) ToSql() (string, []interface{}, error) { From 5c461f08c932ebb780f7418433ba5afdd8c40528 Mon Sep 17 00:00:00 2001 From: mackee Date: Fri, 19 Nov 2021 19:39:14 +0900 Subject: [PATCH 3/5] change API: *SQL has a method the BulkInsert() --- _example/user.gen.go | 10 +++++++--- _example/user_item.gen.go | 10 +++++++--- _example/user_test.go | 6 +++--- template/insert.tmpl | 10 +++++++--- 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/_example/user.gen.go b/_example/user.gen.go index 967edef..01a2a95 100644 --- a/_example/user.gen.go +++ b/_example/user.gen.go @@ -914,12 +914,16 @@ type userBulkInsertSQL struct { insertSQLs []userInsertSQL } -func NewUserBulkInsertSQL(insertSQLs ...userInsertSQL) userBulkInsertSQL { - return userBulkInsertSQL{ - insertSQLs: insertSQLs, +func (q userSQL) BulkInsert() *userBulkInsertSQL { + return &userBulkInsertSQL{ + insertSQLs: []userInsertSQL{}, } } +func (q *userBulkInsertSQL) Append(iqs ...userInsertSQL) { + q.insertSQLs = append(q.insertSQLs, iqs...) +} + func (q userBulkInsertSQL) userInsertSQLToSql() (string, []interface{}, error) { if len(q.insertSQLs) == 0 { return "", []interface{}{}, fmt.Errorf("sqlla: This userBulkInsertSQL does not have InsertSQL") diff --git a/_example/user_item.gen.go b/_example/user_item.gen.go index f31ad5f..ab89f0e 100644 --- a/_example/user_item.gen.go +++ b/_example/user_item.gen.go @@ -832,12 +832,16 @@ type userItemBulkInsertSQL struct { insertSQLs []userItemInsertSQL } -func NewUserItemBulkInsertSQL(insertSQLs ...userItemInsertSQL) userItemBulkInsertSQL { - return userItemBulkInsertSQL{ - insertSQLs: insertSQLs, +func (q userItemSQL) BulkInsert() *userItemBulkInsertSQL { + return &userItemBulkInsertSQL{ + insertSQLs: []userItemInsertSQL{}, } } +func (q *userItemBulkInsertSQL) Append(iqs ...userItemInsertSQL) { + q.insertSQLs = append(q.insertSQLs, iqs...) +} + func (q userItemBulkInsertSQL) userItemInsertSQLToSql() (string, []interface{}, error) { if len(q.insertSQLs) == 0 { return "", []interface{}{}, fmt.Errorf("sqlla: This userItemBulkInsertSQL does not have InsertSQL") diff --git a/_example/user_test.go b/_example/user_test.go index a9da83e..441c9e1 100644 --- a/_example/user_test.go +++ b/_example/user_test.go @@ -231,14 +231,14 @@ func TestInsertOnDuplicateKeyUpdate(t *testing.T) { } func TestBulkInsert(t *testing.T) { - items := make([]userItemInsertSQL, 0, 10) + items := NewUserItemSQL().BulkInsert() for i := 1; i <= 10; i++ { q := NewUserItemSQL().Insert(). ValueUserID(42). ValueItemID(strconv.Itoa(i)) - items = append(items, q) + items.Append(q) } - query, vs, err := NewUserItemBulkInsertSQL(items...).ToSql() + query, vs, err := items.ToSql() if err != nil { t.Error("unexpected error:", err) } diff --git a/template/insert.tmpl b/template/insert.tmpl index 678f2c9..0486121 100644 --- a/template/insert.tmpl +++ b/template/insert.tmpl @@ -182,12 +182,16 @@ type {{ $camelName }}BulkInsertSQL struct { insertSQLs []{{ $camelName }}InsertSQL } -func New{{ (.Name | toCamel) }}BulkInsertSQL(insertSQLs ...{{ $camelName }}InsertSQL) {{ $camelName }}BulkInsertSQL { - return {{ $camelName }}BulkInsertSQL{ - insertSQLs: insertSQLs, +func (q {{ $camelName }}SQL) BulkInsert() *{{ $camelName }}BulkInsertSQL { + return &{{ $camelName }}BulkInsertSQL{ + insertSQLs: []{{ $camelName }}InsertSQL{}, } } +func (q *{{ $camelName }}BulkInsertSQL) Append(iqs ...{{ $camelName }}InsertSQL) { + q.insertSQLs = append(q.insertSQLs, iqs...) +} + func (q {{ $camelName }}BulkInsertSQL) {{ $camelName }}InsertSQLToSql() (string, []interface{}, error) { if len(q.insertSQLs) == 0 { return "", []interface{}{}, fmt.Errorf("sqlla: This {{ $camelName }}BulkInsertSQL does not have InsertSQL") From 03d208ca6d5f6bf02743606ee3fb1907601a5f20 Mon Sep 17 00:00:00 2001 From: mackee Date: Fri, 19 Nov 2021 19:50:27 +0900 Subject: [PATCH 4/5] add test for bulk insert with insert ~ on duplicate key update --- _example/user_test.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/_example/user_test.go b/_example/user_test.go index 441c9e1..0158719 100644 --- a/_example/user_test.go +++ b/_example/user_test.go @@ -251,6 +251,35 @@ func TestBulkInsert(t *testing.T) { } } +func TestBulkInsertWithOnDuplicateKeyUpdate(t *testing.T) { + items := NewUserItemSQL().BulkInsert() + items.Append(NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("1").ValueIsUsed(true)) + items.Append(NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("2").ValueIsUsed(true)) + + now := mysql.NullTime{ + Valid: true, + Time: time.Now(), + } + query, vs, err := items. + OnDuplicateKeyUpdate(). + SameOnUpdateIsUsed(). + ValueOnUpdateUsedAt(now). + ToSql() + if err != nil { + t.Error("unexpected error:", err) + } + bulkInsertQuery := "INSERT INTO `user_item` (`is_used`,`item_id`,`user_id`) VALUES (?,?,?),(?,?,?) " + expected1 := bulkInsertQuery + "ON DUPLICATE KEY UPDATE `is_used` = VALUES(`is_used`), `used_at` = ?;" + expected2 := bulkInsertQuery + "ON DUPLICATE KEY UPDATE `used_at` = ?, `is_used` = VALUES(`is_used`);" + if query != expected1 && query != expected2 { + t.Error("query is not match:", query) + } + if !reflect.DeepEqual(vs, []interface{}{true, "1", uint64(42), true, "2", uint64(42), now}) { + t.Errorf("vs is not valid: %+v", vs) + } + +} + func TestDelete(t *testing.T) { q := NewUserSQL().Delete().Name("hogehoge") query, args, err := q.ToSql() From 7a7df9cb1fa48741c5754dea83ff0b07afe50080 Mon Sep 17 00:00:00 2001 From: mackee Date: Fri, 19 Nov 2021 20:26:40 +0900 Subject: [PATCH 5/5] add tests of bulk insert with on duplicate key update on mysql --- _example/user.gen.go | 21 +++---- _example/user_item.gen.go | 21 +++---- _example/user_test.go | 6 +- _example/user_withmysql_test.go | 107 ++++++++++++++++++++++++++++++++ template/insert.tmpl | 29 ++------- 5 files changed, 130 insertions(+), 54 deletions(-) diff --git a/_example/user.gen.go b/_example/user.gen.go index 01a2a95..d1d9680 100644 --- a/_example/user.gen.go +++ b/_example/user.gen.go @@ -924,9 +924,9 @@ func (q *userBulkInsertSQL) Append(iqs ...userInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } -func (q userBulkInsertSQL) userInsertSQLToSql() (string, []interface{}, error) { +func (q *userBulkInsertSQL) userInsertSQLToSql() (string, []interface{}, error) { if len(q.insertSQLs) == 0 { - return "", []interface{}{}, fmt.Errorf("sqlla: This userBulkInsertSQL does not have InsertSQL") + return "", []interface{}{}, fmt.Errorf("sqlla: This userBulkInsertSQL's InsertSQL was empty") } iqs := make([]userInsertSQL, len(q.insertSQLs)) copy(iqs, q.insertSQLs) @@ -956,7 +956,7 @@ func (q userBulkInsertSQL) userInsertSQLToSql() (string, []interface{}, error) { return "INSERT INTO `user` " + query, vs, nil } -func (q userBulkInsertSQL) ToSql() (string, []interface{}, error) { +func (q *userBulkInsertSQL) ToSql() (string, []interface{}, error) { query, vs, err := q.userInsertSQLToSql() if err != nil { return "", []interface{}{}, err @@ -964,27 +964,20 @@ func (q userBulkInsertSQL) ToSql() (string, []interface{}, error) { return query + ";", vs, nil } -func (q userBulkInsertSQL) OnDuplicateKeyUpdate() userInsertOnDuplicateKeyUpdateSQL { +func (q *userBulkInsertSQL) OnDuplicateKeyUpdate() userInsertOnDuplicateKeyUpdateSQL { return userInsertOnDuplicateKeyUpdateSQL{ insertSQL: q, onDuplicateKeyUpdateMap: sqlla.SetMap{}, } } -func (q userBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (User, error) { +func (q *userBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { query, args, err := q.ToSql() if err != nil { - return User{}, err + return nil, err } result, err := db.ExecContext(ctx, query, args...) - if err != nil { - return User{}, err - } - id, err := result.LastInsertId() - if err != nil { - return User{}, err - } - return NewUserSQL().Select().PkColumn(id).SingleContext(ctx, db) + return result, err } type userDeleteSQL struct { diff --git a/_example/user_item.gen.go b/_example/user_item.gen.go index ab89f0e..f3dfee2 100644 --- a/_example/user_item.gen.go +++ b/_example/user_item.gen.go @@ -842,9 +842,9 @@ func (q *userItemBulkInsertSQL) Append(iqs ...userItemInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } -func (q userItemBulkInsertSQL) userItemInsertSQLToSql() (string, []interface{}, error) { +func (q *userItemBulkInsertSQL) userItemInsertSQLToSql() (string, []interface{}, error) { if len(q.insertSQLs) == 0 { - return "", []interface{}{}, fmt.Errorf("sqlla: This userItemBulkInsertSQL does not have InsertSQL") + return "", []interface{}{}, fmt.Errorf("sqlla: This userItemBulkInsertSQL's InsertSQL was empty") } iqs := make([]userItemInsertSQL, len(q.insertSQLs)) copy(iqs, q.insertSQLs) @@ -874,7 +874,7 @@ func (q userItemBulkInsertSQL) userItemInsertSQLToSql() (string, []interface{}, return "INSERT INTO `user_item` " + query, vs, nil } -func (q userItemBulkInsertSQL) ToSql() (string, []interface{}, error) { +func (q *userItemBulkInsertSQL) ToSql() (string, []interface{}, error) { query, vs, err := q.userItemInsertSQLToSql() if err != nil { return "", []interface{}{}, err @@ -882,27 +882,20 @@ func (q userItemBulkInsertSQL) ToSql() (string, []interface{}, error) { return query + ";", vs, nil } -func (q userItemBulkInsertSQL) OnDuplicateKeyUpdate() userItemInsertOnDuplicateKeyUpdateSQL { +func (q *userItemBulkInsertSQL) OnDuplicateKeyUpdate() userItemInsertOnDuplicateKeyUpdateSQL { return userItemInsertOnDuplicateKeyUpdateSQL{ insertSQL: q, onDuplicateKeyUpdateMap: sqlla.SetMap{}, } } -func (q userItemBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (UserItem, error) { +func (q *userItemBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { query, args, err := q.ToSql() if err != nil { - return UserItem{}, err + return nil, err } result, err := db.ExecContext(ctx, query, args...) - if err != nil { - return UserItem{}, err - } - id, err := result.LastInsertId() - if err != nil { - return UserItem{}, err - } - return NewUserItemSQL().Select().PkColumn(id).SingleContext(ctx, db) + return result, err } type userItemDeleteSQL struct { diff --git a/_example/user_test.go b/_example/user_test.go index 0158719..041b140 100644 --- a/_example/user_test.go +++ b/_example/user_test.go @@ -253,8 +253,10 @@ func TestBulkInsert(t *testing.T) { func TestBulkInsertWithOnDuplicateKeyUpdate(t *testing.T) { items := NewUserItemSQL().BulkInsert() - items.Append(NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("1").ValueIsUsed(true)) - items.Append(NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("2").ValueIsUsed(true)) + items.Append( + NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("1").ValueIsUsed(true), + NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("2").ValueIsUsed(true), + ) now := mysql.NullTime{ Valid: true, diff --git a/_example/user_withmysql_test.go b/_example/user_withmysql_test.go index 18e71c8..506aa47 100644 --- a/_example/user_withmysql_test.go +++ b/_example/user_withmysql_test.go @@ -9,6 +9,7 @@ import ( "io/ioutil" "log" "os" + "strconv" "strings" "testing" "time" @@ -124,3 +125,109 @@ func TestInsertOnDuplicateKeyUpdate__WithMySQL(t *testing.T) { t.Fatal("updated_at does not updated:", r1.UpdatedAt.Time.Unix(), r2.UpdatedAt.Time.Unix()) } } + +func TestBulkInsert__WithMySQL(t *testing.T) { + ctx := context.Background() + + if _, err := NewUserItemSQL().Delete().ExecContext(ctx, db); err != nil { + t.Fatal("unexpected error:", err) + } + + items := NewUserItemSQL().BulkInsert() + items.Append( + NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("1").ValueIsUsed(true), + NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("2").ValueIsUsed(true), + ) + + if _, err := items.ExecContext(ctx, db); err != nil { + t.Fatal("unexpected error:", err) + } + + uis, err := NewUserItemSQL().Select().AllContext(ctx, db) + if err != nil { + t.Fatal("unexpected error:", err) + } + for i, ui := range uis { + if ui.UserId != 42 { + t.Error("UserId is not match:", ui.UserId) + } + if ui.ItemId != strconv.Itoa(i+1) { + t.Errorf("ItemId is not match: index=%d, got=%s", i, ui.ItemId) + } + if !ui.IsUsed { + t.Error("IsUsed is false") + } + } +} + +func TestBulkInsertOnDuplicateKeyUpdate__WithMySQL(t *testing.T) { + ctx := context.Background() + + if _, err := NewUserItemSQL().Delete().ExecContext(ctx, db); err != nil { + t.Fatal("unexpected error:", err) + } + + items := NewUserItemSQL().BulkInsert() + items.Append( + NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("1").ValueIsUsed(false), + NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("2").ValueIsUsed(false), + ) + + if _, err := items.ExecContext(ctx, db); err != nil { + t.Fatal("unexpected error:", err) + } + + uis, err := NewUserItemSQL().Select().AllContext(ctx, db) + if err != nil { + t.Fatal("unexpected error:", err) + } + uitems := NewUserItemSQL().BulkInsert() + for _, ui := range uis { + uitems.Append( + NewUserItemSQL().Insert(). + ValueID(ui.Id). + ValueUserID(42). + ValueItemID(ui.ItemId). + ValueIsUsed(true), + ) + } + uitems.Append( + NewUserItemSQL().Insert(). + ValueID(uis[len(uis)-1].Id + 1). + ValueUserID(42). + ValueItemID("3"). + ValueIsUsed(true), + ) + now := time.Now() + dup := uitems.OnDuplicateKeyUpdate(). + SameOnUpdateIsUsed(). + ValueOnUpdateUsedAt(mysql.NullTime{ + Valid: true, + Time: now, + }) + + if _, err := dup.ExecContext(ctx, db); err != nil { + t.Fatal("unexpected error:", err) + } + + uuis, err := NewUserItemSQL().Select().OrderByID(sqlla.Asc).AllContext(ctx, db) + if err != nil { + t.Fatal("unexpected error:", err) + } + for i, ui := range uuis { + if !ui.IsUsed { + t.Errorf("IsUsed is false: index=%d", i) + } + switch i { + case 0, 1: + if !ui.UsedAt.Valid { + t.Errorf("UsedAt is not valid: index=%d", i) + } + case 2: + if ui.UsedAt.Valid { + t.Errorf("UsedAt is valid: index=%d", i) + } + } + } + +} diff --git a/template/insert.tmpl b/template/insert.tmpl index 0486121..a890641 100644 --- a/template/insert.tmpl +++ b/template/insert.tmpl @@ -192,9 +192,9 @@ func (q *{{ $camelName }}BulkInsertSQL) Append(iqs ...{{ $camelName }}InsertSQL) q.insertSQLs = append(q.insertSQLs, iqs...) } -func (q {{ $camelName }}BulkInsertSQL) {{ $camelName }}InsertSQLToSql() (string, []interface{}, error) { +func (q *{{ $camelName }}BulkInsertSQL) {{ $camelName }}InsertSQLToSql() (string, []interface{}, error) { if len(q.insertSQLs) == 0 { - return "", []interface{}{}, fmt.Errorf("sqlla: This {{ $camelName }}BulkInsertSQL does not have InsertSQL") + return "", []interface{}{}, fmt.Errorf("sqlla: This {{ $camelName }}BulkInsertSQL{{ "'s" }} InsertSQL was empty") } iqs := make([]{{ $camelName }}InsertSQL, len(q.insertSQLs)) copy(iqs, q.insertSQLs) @@ -224,7 +224,7 @@ func (q {{ $camelName }}BulkInsertSQL) {{ $camelName }}InsertSQLToSql() (string, return "INSERT INTO `{{ .Name }}` " + query, vs, nil } -func (q {{ $camelName }}BulkInsertSQL) ToSql() (string, []interface{}, error) { +func (q *{{ $camelName }}BulkInsertSQL) ToSql() (string, []interface{}, error) { query, vs, err := q.{{ $camelName }}InsertSQLToSql() if err != nil { return "", []interface{}{}, err @@ -232,39 +232,20 @@ func (q {{ $camelName }}BulkInsertSQL) ToSql() (string, []interface{}, error) { return query + ";", vs, nil } -func (q {{ $camelName }}BulkInsertSQL) OnDuplicateKeyUpdate() {{ $camelName }}InsertOnDuplicateKeyUpdateSQL { +func (q *{{ $camelName }}BulkInsertSQL) OnDuplicateKeyUpdate() {{ $camelName }}InsertOnDuplicateKeyUpdateSQL { return {{ $camelName }}InsertOnDuplicateKeyUpdateSQL{ insertSQL: q, onDuplicateKeyUpdateMap: sqlla.SetMap{}, } } -{{ if .HasPk -}} -func (q {{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ({{ .StructName }}, error) { -{{- else -}} -func (q {{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { -{{- end }} +func (q *{{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { query, args, err := q.ToSql() if err != nil { - {{ if .HasPk -}} - return {{ .StructName }}{}, err - {{- else }} return nil, err - {{- end }} } result, err := db.ExecContext(ctx, query, args...) - {{ if .HasPk -}} - if err != nil { - return {{ .StructName }}{}, err - } - id, err := result.LastInsertId() - if err != nil { - return {{ .StructName }}{}, err - } - return {{ $constructor }}().Select().PkColumn(id).SingleContext(ctx, db) - {{- else -}} return result, err - {{- end }} } {{ end }}