diff --git a/_example/user.gen.go b/_example/user.gen.go index 0ce425d..d1d9680 100644 --- a/_example/user.gen.go +++ b/_example/user.gen.go @@ -3,6 +3,7 @@ package example import ( "context" + "fmt" "strconv" "strings" @@ -681,14 +682,14 @@ func (q userInsertSQL) ValueUpdatedAt(v mysql.NullTime) userInsertSQL { } func (q userInsertSQL) ToSql() (string, []interface{}, error) { - query, vs, err := q.toSql() + query, vs, err := q.userInsertSQLToSql() if err != nil { return "", []interface{}{}, err } return query + ";", vs, nil } -func (q userInsertSQL) toSql() (string, []interface{}, error) { +func (q userInsertSQL) userInsertSQLToSql() (string, []interface{}, error) { var err error var s interface{} = User{} if t, ok := s.(userDefaultInsertHooker); ok { @@ -750,8 +751,12 @@ type userDefaultInsertHooker interface { DefaultInsertHook(userInsertSQL) (userInsertSQL, error) } +type userInsertSQLToSqler interface { + userInsertSQLToSql() (string, []interface{}, error) +} + type userInsertOnDuplicateKeyUpdateSQL struct { - insertSQL userInsertSQL + insertSQL userInsertSQLToSqler onDuplicateKeyUpdateMap sqlla.SetMap } @@ -870,7 +875,7 @@ func (q userInsertOnDuplicateKeyUpdateSQL) ToSql() (string, []interface{}, error } } - query, vs, err := q.insertSQL.toSql() + query, vs, err := q.insertSQL.userInsertSQLToSql() if err != nil { return "", []interface{}{}, err } @@ -905,6 +910,76 @@ type userDefaultInsertOnDuplicateKeyUpdateHooker interface { DefaultInsertOnDuplicateKeyUpdateHook(userInsertOnDuplicateKeyUpdateSQL) (userInsertOnDuplicateKeyUpdateSQL, error) } +type userBulkInsertSQL struct { + insertSQLs []userInsertSQL +} + +func (q userSQL) BulkInsert() *userBulkInsertSQL { + return &userBulkInsertSQL{ + insertSQLs: []userInsertSQL{}, + } +} + +func (q *userBulkInsertSQL) Append(iqs ...userInsertSQL) { + q.insertSQLs = append(q.insertSQLs, iqs...) +} + +func (q *userBulkInsertSQL) userInsertSQLToSql() (string, []interface{}, error) { + if len(q.insertSQLs) == 0 { + return "", []interface{}{}, fmt.Errorf("sqlla: This userBulkInsertSQL's InsertSQL was empty") + } + iqs := make([]userInsertSQL, len(q.insertSQLs)) + copy(iqs, q.insertSQLs) + + var s interface{} = User{} + if t, ok := s.(userDefaultInsertHooker); ok { + for i, iq := range iqs { + var err error + iq, err = t.DefaultInsertHook(iq) + if err != nil { + return "", []interface{}{}, err + } + iqs[i] = iq + } + } + + sms := make(sqlla.SetMaps, 0, len(q.insertSQLs)) + for _, iq := range q.insertSQLs { + sms = append(sms, iq.setMap) + } + + query, vs, err := sms.ToInsertSql() + if err != nil { + return "", []interface{}{}, err + } + + return "INSERT INTO `user` " + query, vs, nil +} + +func (q *userBulkInsertSQL) ToSql() (string, []interface{}, error) { + query, vs, err := q.userInsertSQLToSql() + if err != nil { + return "", []interface{}{}, err + } + return query + ";", vs, nil +} + +func (q *userBulkInsertSQL) OnDuplicateKeyUpdate() userInsertOnDuplicateKeyUpdateSQL { + return userInsertOnDuplicateKeyUpdateSQL{ + insertSQL: q, + onDuplicateKeyUpdateMap: sqlla.SetMap{}, + } +} + +func (q *userBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + result, err := db.ExecContext(ctx, query, args...) + return result, err +} + type userDeleteSQL struct { userSQL } diff --git a/_example/user_hook.go b/_example/user_hook.go index 48d625c..36a0108 100644 --- a/_example/user_hook.go +++ b/_example/user_hook.go @@ -12,9 +12,6 @@ func (u User) DefaultInsertHook(q userInsertSQL) (userInsertSQL, error) { } func (u User) DefaultInsertOnDuplicateKeyUpdateHook(q userInsertOnDuplicateKeyUpdateSQL) (userInsertOnDuplicateKeyUpdateSQL, error) { - now := time.Now() - q.insertSQL = q.insertSQL.ValueUpdatedAt(mysql.NullTime{Time: now, Valid: true}) - return q.SameOnUpdateUpdatedAt(), nil } diff --git a/_example/user_item.gen.go b/_example/user_item.gen.go index 2ab147c..f3dfee2 100644 --- a/_example/user_item.gen.go +++ b/_example/user_item.gen.go @@ -3,6 +3,7 @@ package example import ( "context" + "fmt" "strconv" "strings" @@ -614,14 +615,14 @@ func (q userItemInsertSQL) ValueUsedAt(v mysql.NullTime) userItemInsertSQL { } func (q userItemInsertSQL) ToSql() (string, []interface{}, error) { - query, vs, err := q.toSql() + query, vs, err := q.userItemInsertSQLToSql() if err != nil { return "", []interface{}{}, err } return query + ";", vs, nil } -func (q userItemInsertSQL) toSql() (string, []interface{}, error) { +func (q userItemInsertSQL) userItemInsertSQLToSql() (string, []interface{}, error) { var err error var s interface{} = UserItem{} if t, ok := s.(userItemDefaultInsertHooker); ok { @@ -683,8 +684,12 @@ type userItemDefaultInsertHooker interface { DefaultInsertHook(userItemInsertSQL) (userItemInsertSQL, error) } +type userItemInsertSQLToSqler interface { + userItemInsertSQLToSql() (string, []interface{}, error) +} + type userItemInsertOnDuplicateKeyUpdateSQL struct { - insertSQL userItemInsertSQL + insertSQL userItemInsertSQLToSqler onDuplicateKeyUpdateMap sqlla.SetMap } @@ -788,7 +793,7 @@ func (q userItemInsertOnDuplicateKeyUpdateSQL) ToSql() (string, []interface{}, e } } - query, vs, err := q.insertSQL.toSql() + query, vs, err := q.insertSQL.userItemInsertSQLToSql() if err != nil { return "", []interface{}{}, err } @@ -823,6 +828,76 @@ type userItemDefaultInsertOnDuplicateKeyUpdateHooker interface { DefaultInsertOnDuplicateKeyUpdateHook(userItemInsertOnDuplicateKeyUpdateSQL) (userItemInsertOnDuplicateKeyUpdateSQL, error) } +type userItemBulkInsertSQL struct { + insertSQLs []userItemInsertSQL +} + +func (q userItemSQL) BulkInsert() *userItemBulkInsertSQL { + return &userItemBulkInsertSQL{ + insertSQLs: []userItemInsertSQL{}, + } +} + +func (q *userItemBulkInsertSQL) Append(iqs ...userItemInsertSQL) { + q.insertSQLs = append(q.insertSQLs, iqs...) +} + +func (q *userItemBulkInsertSQL) userItemInsertSQLToSql() (string, []interface{}, error) { + if len(q.insertSQLs) == 0 { + return "", []interface{}{}, fmt.Errorf("sqlla: This userItemBulkInsertSQL's InsertSQL was empty") + } + iqs := make([]userItemInsertSQL, len(q.insertSQLs)) + copy(iqs, q.insertSQLs) + + var s interface{} = UserItem{} + if t, ok := s.(userItemDefaultInsertHooker); ok { + for i, iq := range iqs { + var err error + iq, err = t.DefaultInsertHook(iq) + if err != nil { + return "", []interface{}{}, err + } + iqs[i] = iq + } + } + + sms := make(sqlla.SetMaps, 0, len(q.insertSQLs)) + for _, iq := range q.insertSQLs { + sms = append(sms, iq.setMap) + } + + query, vs, err := sms.ToInsertSql() + if err != nil { + return "", []interface{}{}, err + } + + return "INSERT INTO `user_item` " + query, vs, nil +} + +func (q *userItemBulkInsertSQL) ToSql() (string, []interface{}, error) { + query, vs, err := q.userItemInsertSQLToSql() + if err != nil { + return "", []interface{}{}, err + } + return query + ";", vs, nil +} + +func (q *userItemBulkInsertSQL) OnDuplicateKeyUpdate() userItemInsertOnDuplicateKeyUpdateSQL { + return userItemInsertOnDuplicateKeyUpdateSQL{ + insertSQL: q, + onDuplicateKeyUpdateMap: sqlla.SetMap{}, + } +} + +func (q *userItemBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + result, err := db.ExecContext(ctx, query, args...) + return result, err +} + type userItemDeleteSQL struct { userItemSQL } diff --git a/_example/user_test.go b/_example/user_test.go index 577fd0c..041b140 100644 --- a/_example/user_test.go +++ b/_example/user_test.go @@ -6,6 +6,7 @@ import ( "os" "reflect" "regexp" + "strconv" "strings" "testing" "time" @@ -190,24 +191,24 @@ func TestInsert(t *testing.T) { if err != nil { t.Error("unexpected error:", err) } - switch query { - case "INSERT INTO user (`name`,`created_at`) VALUES(?,?);": - if !reflect.DeepEqual(args[0], "hogehoge") { - t.Error("unexpected args:", args) - } - case "INSERT INTO user (`created_at`,`name`) VALUES(?,?);": - if !reflect.DeepEqual(args[1], "hogehoge") { - t.Error("unexpected args:", args) - } - default: + expected := "INSERT INTO user (`created_at`,`name`) VALUES(?,?);" + if query != expected { t.Error("unexpected query:", query) } + if !reflect.DeepEqual(args[1], "hogehoge") { + t.Error("unexpected args:", args) + } } func TestInsertOnDuplicateKeyUpdate(t *testing.T) { + now := time.Now() q := NewUserSQL().Insert(). ValueID(1). ValueName("hogehoge"). + ValueUpdatedAt(mysql.NullTime{ + Valid: true, + Time: now, + }). OnDuplicateKeyUpdate(). ValueOnUpdateAge(sql.NullInt64{ Valid: true, @@ -229,6 +230,58 @@ func TestInsertOnDuplicateKeyUpdate(t *testing.T) { } } +func TestBulkInsert(t *testing.T) { + items := NewUserItemSQL().BulkInsert() + for i := 1; i <= 10; i++ { + q := NewUserItemSQL().Insert(). + ValueUserID(42). + ValueItemID(strconv.Itoa(i)) + items.Append(q) + } + query, vs, err := items.ToSql() + if err != nil { + t.Error("unexpected error:", err) + } + expected := "INSERT INTO `user_item` (`item_id`,`user_id`) VALUES (?,?),(?,?),(?,?),(?,?),(?,?),(?,?),(?,?),(?,?),(?,?),(?,?);" + if query != expected { + t.Error("query is not match:", query) + } + if !reflect.DeepEqual(vs, []interface{}{"1", uint64(42), "2", uint64(42), "3", uint64(42), "4", uint64(42), "5", uint64(42), "6", uint64(42), "7", uint64(42), "8", uint64(42), "9", uint64(42), "10", uint64(42)}) { + t.Errorf("vs is not valid: %+v", vs) + } +} + +func TestBulkInsertWithOnDuplicateKeyUpdate(t *testing.T) { + items := NewUserItemSQL().BulkInsert() + items.Append( + NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("1").ValueIsUsed(true), + NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("2").ValueIsUsed(true), + ) + + now := mysql.NullTime{ + Valid: true, + Time: time.Now(), + } + query, vs, err := items. + OnDuplicateKeyUpdate(). + SameOnUpdateIsUsed(). + ValueOnUpdateUsedAt(now). + ToSql() + if err != nil { + t.Error("unexpected error:", err) + } + bulkInsertQuery := "INSERT INTO `user_item` (`is_used`,`item_id`,`user_id`) VALUES (?,?,?),(?,?,?) " + expected1 := bulkInsertQuery + "ON DUPLICATE KEY UPDATE `is_used` = VALUES(`is_used`), `used_at` = ?;" + expected2 := bulkInsertQuery + "ON DUPLICATE KEY UPDATE `used_at` = ?, `is_used` = VALUES(`is_used`);" + if query != expected1 && query != expected2 { + t.Error("query is not match:", query) + } + if !reflect.DeepEqual(vs, []interface{}{true, "1", uint64(42), true, "2", uint64(42), now}) { + t.Errorf("vs is not valid: %+v", vs) + } + +} + func TestDelete(t *testing.T) { q := NewUserSQL().Delete().Name("hogehoge") query, args, err := q.ToSql() diff --git a/_example/user_withmysql_test.go b/_example/user_withmysql_test.go index 4896412..506aa47 100644 --- a/_example/user_withmysql_test.go +++ b/_example/user_withmysql_test.go @@ -9,10 +9,12 @@ import ( "io/ioutil" "log" "os" + "strconv" "strings" "testing" "time" + "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql" "github.com/mackee/go-sqlla/v2" "github.com/ory/dockertest/v3" @@ -84,12 +86,14 @@ func TestMain(m *testing.M) { func TestInsertOnDuplicateKeyUpdate__WithMySQL(t *testing.T) { ctx := context.Background() + now1 := time.Now() q1 := NewUserSQL().Insert(). ValueName("hogehoge"). ValueRate(3.14). ValueIconImage([]byte{}). - ValueAge(sql.NullInt64{Valid: true, Int64: 17}) + ValueAge(sql.NullInt64{Valid: true, Int64: 17}). + ValueUpdatedAt(mysql.NullTime{Valid: true, Time: now1}) query, args, _ := q1.ToSql() t.Logf("query=%s, args=%+v", query, args) r1, err := q1.ExecContext(ctx, db) @@ -97,12 +101,13 @@ func TestInsertOnDuplicateKeyUpdate__WithMySQL(t *testing.T) { t.Fatal("unexpected error:", err) } - time.Sleep(1 * time.Second) + now2 := now1.Add(1 * time.Second) q2 := NewUserSQL().Insert(). ValueName("hogehoge"). ValueAge(sql.NullInt64{Valid: true, Int64: 17}). ValueIconImage([]byte{}). + ValueUpdatedAt(mysql.NullTime{Valid: true, Time: now2}). OnDuplicateKeyUpdate(). RawValueOnUpdateAge(sqlla.SetMapRawValue("`age` + 1")) r2, err := q2.ExecContext(ctx, db) @@ -119,5 +124,110 @@ func TestInsertOnDuplicateKeyUpdate__WithMySQL(t *testing.T) { if r2.UpdatedAt.Time.Unix() <= r1.UpdatedAt.Time.Unix() { t.Fatal("updated_at does not updated:", r1.UpdatedAt.Time.Unix(), r2.UpdatedAt.Time.Unix()) } +} + +func TestBulkInsert__WithMySQL(t *testing.T) { + ctx := context.Background() + + if _, err := NewUserItemSQL().Delete().ExecContext(ctx, db); err != nil { + t.Fatal("unexpected error:", err) + } + + items := NewUserItemSQL().BulkInsert() + items.Append( + NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("1").ValueIsUsed(true), + NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("2").ValueIsUsed(true), + ) + + if _, err := items.ExecContext(ctx, db); err != nil { + t.Fatal("unexpected error:", err) + } + + uis, err := NewUserItemSQL().Select().AllContext(ctx, db) + if err != nil { + t.Fatal("unexpected error:", err) + } + for i, ui := range uis { + if ui.UserId != 42 { + t.Error("UserId is not match:", ui.UserId) + } + if ui.ItemId != strconv.Itoa(i+1) { + t.Errorf("ItemId is not match: index=%d, got=%s", i, ui.ItemId) + } + if !ui.IsUsed { + t.Error("IsUsed is false") + } + } +} + +func TestBulkInsertOnDuplicateKeyUpdate__WithMySQL(t *testing.T) { + ctx := context.Background() + + if _, err := NewUserItemSQL().Delete().ExecContext(ctx, db); err != nil { + t.Fatal("unexpected error:", err) + } + + items := NewUserItemSQL().BulkInsert() + items.Append( + NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("1").ValueIsUsed(false), + NewUserItemSQL().Insert().ValueUserID(42).ValueItemID("2").ValueIsUsed(false), + ) + + if _, err := items.ExecContext(ctx, db); err != nil { + t.Fatal("unexpected error:", err) + } + + uis, err := NewUserItemSQL().Select().AllContext(ctx, db) + if err != nil { + t.Fatal("unexpected error:", err) + } + uitems := NewUserItemSQL().BulkInsert() + for _, ui := range uis { + uitems.Append( + NewUserItemSQL().Insert(). + ValueID(ui.Id). + ValueUserID(42). + ValueItemID(ui.ItemId). + ValueIsUsed(true), + ) + } + uitems.Append( + NewUserItemSQL().Insert(). + ValueID(uis[len(uis)-1].Id + 1). + ValueUserID(42). + ValueItemID("3"). + ValueIsUsed(true), + ) + now := time.Now() + dup := uitems.OnDuplicateKeyUpdate(). + SameOnUpdateIsUsed(). + ValueOnUpdateUsedAt(mysql.NullTime{ + Valid: true, + Time: now, + }) + + if _, err := dup.ExecContext(ctx, db); err != nil { + t.Fatal("unexpected error:", err) + } + + uuis, err := NewUserItemSQL().Select().OrderByID(sqlla.Asc).AllContext(ctx, db) + if err != nil { + t.Fatal("unexpected error:", err) + } + for i, ui := range uuis { + if !ui.IsUsed { + t.Errorf("IsUsed is false: index=%d", i) + } + switch i { + case 0, 1: + if !ui.UsedAt.Valid { + t.Errorf("UsedAt is not valid: index=%d", i) + } + case 2: + if ui.UsedAt.Valid { + t.Errorf("UsedAt is valid: index=%d", i) + } + } + } } diff --git a/column.go b/column.go index df87197..f2b0c0a 100644 --- a/column.go +++ b/column.go @@ -1,5 +1,11 @@ package sqlla +import ( + "fmt" + "sort" + "strings" +) + type Where []Expr func (wh Where) ToSql() (string, []interface{}, error) { @@ -29,6 +35,41 @@ type SetMapRawValue string type SetMap map[string]interface{} +func (sm SetMap) NewIterator() *SetMapIterator { + keys := make(sort.StringSlice, 0, len(sm)) + for k := range sm { + keys = append(keys, k) + } + sort.Sort(keys) + return &SetMapIterator{ + sm: sm, + keys: keys, + cursor: -1, + } +} + +type SetMapIterator struct { + sm SetMap + cursor int + keys []string +} + +func (s *SetMapIterator) Iterate() bool { + s.cursor++ + if len(s.keys)-1 < s.cursor { + return false + } + return true +} + +func (s *SetMapIterator) Key() string { + return s.keys[s.cursor] +} + +func (s *SetMapIterator) Value() interface{} { + return s.sm[s.keys[s.cursor]] +} + func (sm SetMap) ToUpdateSql() (string, []interface{}, error) { var setColumns string vs := []interface{}{} @@ -49,11 +90,13 @@ func (sm SetMap) ToUpdateSql() (string, []interface{}, error) { return setColumns, vs, nil } -func (sm SetMap) ToInsertSql() (string, []interface{}, error) { +func (sm SetMap) ToInsertColumnsAndValues() (string, string, []interface{}) { qs, ps := "(", "(" vs := []interface{}{} columnCount := 0 - for k, v := range sm { + iter := sm.NewIterator() + for iter.Iterate() { + k, v := iter.Key(), iter.Value() if columnCount != 0 { qs += "," ps += "," @@ -65,6 +108,39 @@ func (sm SetMap) ToInsertSql() (string, []interface{}, error) { } qs += ")" ps += ")" + return qs, ps, vs +} +func (sm SetMap) ToInsertSql() (string, []interface{}, error) { + qs, ps, vs := sm.ToInsertColumnsAndValues() return qs + " VALUES" + ps, vs, nil } + +type SetMaps []SetMap + +func (s SetMaps) ToInsertSql() (string, []interface{}, error) { + if len(s) == 0 { + return "", nil, fmt.Errorf("sqlla: SetMaps is empty") + } + + first := s[0] + columns, values, vs := first.ToInsertColumnsAndValues() + var b strings.Builder + if _, err := b.WriteString(values); err != nil { + return "", nil, err + } + for i, _s := range s[1:] { + _columns, _values, _vs := _s.ToInsertColumnsAndValues() + if columns != _columns { + return "", nil, fmt.Errorf("sqlla: two SetMap are not match keys: [0]=%s, [%d]=%s", columns, i, _columns) + } + vs = append(vs, _vs...) + if _, err := b.WriteString(","); err != nil { + return "", nil, err + } + if _, err := b.WriteString(_values); err != nil { + return "", nil, err + } + } + return columns + " VALUES " + b.String(), vs, nil +} diff --git a/generator_go116.go b/generator_go116.go index e8d31d6..aab93b7 100644 --- a/generator_go116.go +++ b/generator_go116.go @@ -67,7 +67,7 @@ func WriteCode(w io.Writer, table *Table) error { } bs, err := format.Source(buf.Bytes()) if err != nil { - return errors.Wrapf(err, "fail to format") + return errors.Wrapf(err, "fail to format: table=%s", table.Name) } _, err = w.Write(bs) return err diff --git a/template/insert.tmpl b/template/insert.tmpl index 6ffa87e..a890641 100644 --- a/template/insert.tmpl +++ b/template/insert.tmpl @@ -16,14 +16,14 @@ func (q {{ $camelName }}SQL) Insert() {{ $camelName }}InsertSQL { {{ range .Columns }}{{ template "InsertColumn" . }}{{ end }} func (q {{ $camelName }}InsertSQL) ToSql() (string, []interface{}, error) { - query, vs, err := q.toSql() + query, vs, err := q.{{ $camelName }}InsertSQLToSql() if err != nil { return "", []interface{}{}, err } return query + ";", vs, nil } -func (q {{ $camelName }}InsertSQL) toSql() (string, []interface{}, error) { +func (q {{ $camelName }}InsertSQL) {{ $camelName }}InsertSQLToSql() (string, []interface{}, error) { var err error var s interface{} = {{ .StructName }}{} if t, ok := s.({{ $camelName }}DefaultInsertHooker); ok { @@ -110,8 +110,12 @@ type {{ $camelName }}DefaultInsertHooker interface { DefaultInsertHook({{ $camelName }}InsertSQL) ({{ $camelName }}InsertSQL, error) } +type {{ $camelName }}InsertSQLToSqler interface { + {{ $camelName }}InsertSQLToSql() (string, []interface{}, error) +} + type {{ $camelName }}InsertOnDuplicateKeyUpdateSQL struct { - insertSQL {{ $camelName }}InsertSQL + insertSQL {{ $camelName }}InsertSQLToSqler onDuplicateKeyUpdateMap sqlla.SetMap } @@ -127,7 +131,7 @@ func (q {{ $camelName }}InsertOnDuplicateKeyUpdateSQL) ToSql() (string, []interf } } - query, vs, err := q.insertSQL.toSql() + query, vs, err := q.insertSQL.{{ $camelName }}InsertSQLToSql() if err != nil { return "", []interface{}{}, err } @@ -173,4 +177,75 @@ func (q {{ $camelName }}InsertOnDuplicateKeyUpdateSQL) ExecContext(ctx context.C type {{ $camelName }}DefaultInsertOnDuplicateKeyUpdateHooker interface { DefaultInsertOnDuplicateKeyUpdateHook({{ $camelName }}InsertOnDuplicateKeyUpdateSQL) ({{ $camelName }}InsertOnDuplicateKeyUpdateSQL, error) } + +type {{ $camelName }}BulkInsertSQL struct { + insertSQLs []{{ $camelName }}InsertSQL +} + +func (q {{ $camelName }}SQL) BulkInsert() *{{ $camelName }}BulkInsertSQL { + return &{{ $camelName }}BulkInsertSQL{ + insertSQLs: []{{ $camelName }}InsertSQL{}, + } +} + +func (q *{{ $camelName }}BulkInsertSQL) Append(iqs ...{{ $camelName }}InsertSQL) { + q.insertSQLs = append(q.insertSQLs, iqs...) +} + +func (q *{{ $camelName }}BulkInsertSQL) {{ $camelName }}InsertSQLToSql() (string, []interface{}, error) { + if len(q.insertSQLs) == 0 { + return "", []interface{}{}, fmt.Errorf("sqlla: This {{ $camelName }}BulkInsertSQL{{ "'s" }} InsertSQL was empty") + } + iqs := make([]{{ $camelName }}InsertSQL, len(q.insertSQLs)) + copy(iqs, q.insertSQLs) + + var s interface{} = {{ .StructName }}{} + if t, ok := s.({{ $camelName }}DefaultInsertHooker); ok { + for i, iq := range iqs { + var err error + iq, err = t.DefaultInsertHook(iq) + if err != nil { + return "", []interface{}{}, err + } + iqs[i] = iq + } + } + + sms := make(sqlla.SetMaps, 0, len(q.insertSQLs)) + for _, iq := range q.insertSQLs { + sms = append(sms, iq.setMap) + } + + query, vs, err := sms.ToInsertSql() + if err != nil { + return "", []interface{}{}, err + } + + return "INSERT INTO `{{ .Name }}` " + query, vs, nil +} + +func (q *{{ $camelName }}BulkInsertSQL) ToSql() (string, []interface{}, error) { + query, vs, err := q.{{ $camelName }}InsertSQLToSql() + if err != nil { + return "", []interface{}{}, err + } + return query + ";", vs, nil +} + +func (q *{{ $camelName }}BulkInsertSQL) OnDuplicateKeyUpdate() {{ $camelName }}InsertOnDuplicateKeyUpdateSQL { + return {{ $camelName }}InsertOnDuplicateKeyUpdateSQL{ + insertSQL: q, + onDuplicateKeyUpdateMap: sqlla.SetMap{}, + } +} + +func (q *{{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + result, err := db.ExecContext(ctx, query, args...) + return result, err +} + {{ end }} diff --git a/template/table.tmpl b/template/table.tmpl index abffbc1..2626549 100644 --- a/template/table.tmpl +++ b/template/table.tmpl @@ -5,6 +5,7 @@ import ( "strings" "strconv" "context" + "fmt" "database/sql" {{ range .AdditionalPackages -}}