From 694cc7efeec0f80d83206630bdf63a32d3eb0515 Mon Sep 17 00:00:00 2001 From: Mikhail Knyazhev Date: Mon, 18 Jul 2022 06:35:02 +0300 Subject: [PATCH] add lite query --- README.md | 123 +++++++++++++++++++++++++-- common.go | 8 ++ db.go | 2 +- errors.go | 1 - example/basic/main.go | 60 +++++++++++++ example/lite/main.go | 80 ++++++++++++++++++ example/main.go | 33 -------- models.go | 100 ---------------------- models_test.go | 32 ------- result.go | 23 ----- stmt.go | 7 -- stmt_exec.go | 97 +++++++++++++++++++++ stmt_ping.go | 13 --- stmt_query.go | 80 ++++++++++++++++++ stmt_raw.go | 36 ++++---- stmt_test.go | 191 ++++++++++++++++++++++++++++++++++++++++++ stmt_tx.go | 72 ++++++++++++++++ types/soft_delete.go | 4 - 18 files changed, 726 insertions(+), 236 deletions(-) create mode 100644 common.go delete mode 100644 errors.go create mode 100644 example/basic/main.go create mode 100644 example/lite/main.go delete mode 100644 example/main.go delete mode 100644 models.go delete mode 100644 models_test.go delete mode 100644 result.go create mode 100644 stmt_exec.go delete mode 100644 stmt_ping.go create mode 100644 stmt_query.go create mode 100644 stmt_test.go create mode 100644 stmt_tx.go delete mode 100644 types/soft_delete.go diff --git a/README.md b/README.md index 162986e..3ac414d 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,9 @@ The library provides a nice and simple ActiveRecord implementation for working with your database. Each database table has a corresponding "Model" that is used to interact with this table. Models allow you to query data in tables, as well as insert new records in the table. -# DEMO +# Examples + +## Init connection ```go package main @@ -24,21 +26,71 @@ import ( ) func main() { - - conn := mysql.New(&mysql.Config{Pool: []mysql.Item{}}) - err := conn.Reconnect() - if err != nil { + conn := mysql.New(&mysql.Config{ + Pool: []mysql.Item{ + { + Name: "main_db", + Host: "127.0.0.1", + Port: 3306, + Schema: "test_table", + User: "demo", + Password: "1234", + }, + }}) + defer conn.Close() + if err := conn.Reconnect(); err != nil { panic(err.Error()) } db := orm.NewDB(conn, orm.Plugins{Logger: plugins.StdOutLog, Metrics: plugins.StdOutMetric}) - pool := db.Pool("") + pool := db.Pool("main_db") - if err = pool.Ping(); err != nil { + if err := pool.Ping(); err != nil { panic(err.Error()) } - err = pool.Call("demo_metric", func(ctx context.Context, conn *sql.DB) error { + // use pool[main_db] here + err := pool.CallContext("query name", context.Background(), func(ctx context.Context, db *sql.DB) error {...} + err := pool.TxContext("query name", context.Background(), func(context.Context, *sql.Tx) error) error {...} +} +``` + +## Basic query + +```go +package main + +import ( + "context" + "database/sql" + + "github.com/deweppro/go-orm" + "github.com/deweppro/go-orm/plugins" + "github.com/deweppro/go-orm/schema/mysql" +) + +func main() { + ... + + var userName string + err := pool.CallContext("user_name", context.Background(), func(ctx context.Context, db *sql.DB) error { + rows, err := db.QueryContext(ctx, "select `name` from `users` where `id`=?", 10) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + if err = rows.Scan(&userName); err != nil { + return err + } + } + if err = rows.Close(); err != nil { + return err + } + if err = rows.Err(); err != nil { + return err + } return nil }) if err != nil { @@ -47,3 +99,58 @@ func main() { } ``` + +## Lite query + +```go +package main + +import ( + "context" + "fmt" + + "github.com/deweppro/go-orm" + "github.com/deweppro/go-orm/plugins" + "github.com/deweppro/go-orm/schema/mysql" +) + +func main() { + ... + + var userName string + err := pool.QueryContext("user_name", context.Background(), func(q orm.Querier) { + q.SQL("select `name` from `users` limit 1") + q.Bind(func(bind orm.Scanner) error { + return bind.Scan(&userName) + }) + }) + + err = pool.ExecContext("user_name", context.Background(), func(e orm.Executor) { + e.SQL("insert into `users` (`id`, `name`) values (?, ?);") + e.Params(3, "cccc") + + e.Bind(func(result orm.Result) error { + fmt.Printf("RowsAffected=%d LastInsertId=%d", result.RowsAffected, result.LastInsertId) + return nil + }) + }) + + err = pool.TransactionContext("", context.Background(), func(v orm.Tx) { + v.Exec(func(e orm.Executor) { + e.SQL("insert into `users` (`id`, `name`) values (?, ?);") + e.Params(3, "cccc") + + e.Bind(func(result orm.Result) error { + fmt.Printf("RowsAffected=%d LastInsertId=%d", result.RowsAffected, result.LastInsertId) + return nil + }) + }) + v.Query(func(q orm.Querier) { + q.SQL("select `name` from `users` limit 1") + q.Bind(func(bind orm.Scanner) error { + return bind.Scan(&userName) + }) + }) + }) +} +``` \ No newline at end of file diff --git a/common.go b/common.go new file mode 100644 index 0000000..585d079 --- /dev/null +++ b/common.go @@ -0,0 +1,8 @@ +package orm + +import "github.com/deweppro/go-errors" + +var ( + //ErrInvalidModelPool if sync pool has invalid model type + ErrInvalidModelPool = errors.New("invalid internal model pool") +) diff --git a/db.go b/db.go index 284859a..640447b 100644 --- a/db.go +++ b/db.go @@ -34,6 +34,6 @@ func NewDB(c schema.Connector, plug Plugins) *DB { } //Pool getting pool connections by name -func (d *DB) Pool(name string) StmtInterface { +func (d *DB) Pool(name string) *Stmt { return newStmt(name, d.conn, d.plug) } diff --git a/errors.go b/errors.go deleted file mode 100644 index 5800656..0000000 --- a/errors.go +++ /dev/null @@ -1 +0,0 @@ -package orm diff --git a/example/basic/main.go b/example/basic/main.go new file mode 100644 index 0000000..1a02587 --- /dev/null +++ b/example/basic/main.go @@ -0,0 +1,60 @@ +package main + +import ( + "context" + "database/sql" + + "github.com/deweppro/go-orm" + "github.com/deweppro/go-orm/plugins" + "github.com/deweppro/go-orm/schema/mysql" +) + +func main() { + conn := mysql.New(&mysql.Config{ + Pool: []mysql.Item{ + { + Name: "main_db", + Host: "127.0.0.1", + Port: 3306, + Schema: "test_table", + User: "demo", + Password: "1234", + }, + }}) + defer conn.Close() //nolint: errcheck + if err := conn.Reconnect(); err != nil { + panic(err.Error()) + } + + db := orm.NewDB(conn, orm.Plugins{Logger: plugins.StdOutLog, Metrics: plugins.StdOutMetric}) + pool := db.Pool("main_db") + + if err := pool.Ping(); err != nil { + panic(err.Error()) + } + + var userName string + err := pool.CallContext("user_name", context.Background(), func(ctx context.Context, db *sql.DB) error { + rows, err := db.QueryContext(ctx, "select `name` from `users` where `id`=?", 10) + if err != nil { + return err + } + defer rows.Close() //nolint: errcheck + + for rows.Next() { + if err = rows.Scan(&userName); err != nil { + return err + } + } + if err = rows.Close(); err != nil { + return err + } + if err = rows.Err(); err != nil { + return err + } + return nil + }) + if err != nil { + panic(err.Error()) + } +} diff --git a/example/lite/main.go b/example/lite/main.go new file mode 100644 index 0000000..b47407b --- /dev/null +++ b/example/lite/main.go @@ -0,0 +1,80 @@ +package main + +import ( + "context" + "fmt" + + "github.com/deweppro/go-orm" + "github.com/deweppro/go-orm/plugins" + "github.com/deweppro/go-orm/schema/mysql" +) + +func main() { + conn := mysql.New(&mysql.Config{ + Pool: []mysql.Item{ + { + Name: "main_db", + Host: "127.0.0.1", + Port: 3306, + Schema: "test_table", + User: "demo", + Password: "1234", + }, + }}) + + if err := conn.Reconnect(); err != nil { + panic(err.Error()) + } + + db := orm.NewDB(conn, orm.Plugins{Logger: plugins.StdOutLog, Metrics: plugins.StdOutMetric}) + pool := db.Pool("main_db") + + if err := pool.Ping(); err != nil { + panic(err.Error()) + } + + var userName string + err := pool.QueryContext("user_name", context.Background(), func(q orm.Querier) { + q.SQL("select `name` from `users` limit 1") + q.Bind(func(bind orm.Scanner) error { + return bind.Scan(&userName) + }) + }) + if err != nil { + panic(err.Error()) + } + + err = pool.ExecContext("user_name", context.Background(), func(e orm.Executor) { + e.SQL("insert into `users` (`id`, `name`) values (?, ?);") + e.Params(3, "cccc") + + e.Bind(func(result orm.Result) error { + fmt.Printf("RowsAffected=%d LastInsertId=%d", result.RowsAffected, result.LastInsertId) + return nil + }) + }) + if err != nil { + panic(err.Error()) + } + + err = pool.TransactionContext("", context.Background(), func(v orm.Tx) { + v.Exec(func(e orm.Executor) { + e.SQL("insert into `users` (`id`, `name`) values (?, ?);") + e.Params(3, "cccc") + + e.Bind(func(result orm.Result) error { + fmt.Printf("RowsAffected=%d LastInsertId=%d", result.RowsAffected, result.LastInsertId) + return nil + }) + }) + v.Query(func(q orm.Querier) { + q.SQL("select `name` from `users` limit 1") + q.Bind(func(bind orm.Scanner) error { + return bind.Scan(&userName) + }) + }) + }) + if err != nil { + panic(err.Error()) + } +} diff --git a/example/main.go b/example/main.go deleted file mode 100644 index 458cc68..0000000 --- a/example/main.go +++ /dev/null @@ -1,33 +0,0 @@ -package main - -import ( - "context" - "database/sql" - - "github.com/deweppro/go-orm" - "github.com/deweppro/go-orm/plugins" - "github.com/deweppro/go-orm/schema/mysql" -) - -func main() { - - conn := mysql.New(&mysql.Config{Pool: []mysql.Item{}}) - err := conn.Reconnect() - if err != nil { - panic(err.Error()) - } - - db := orm.NewDB(conn, orm.Plugins{Logger: plugins.StdOutLog, Metrics: plugins.StdOutMetric}) - pool := db.Pool("") - - if err = pool.Ping(); err != nil { - panic(err.Error()) - } - - err = pool.Call("demo_metric", func(ctx context.Context, db *sql.DB) error { - return nil - }) - if err != nil { - panic(err.Error()) - } -} diff --git a/models.go b/models.go deleted file mode 100644 index 93c63a7..0000000 --- a/models.go +++ /dev/null @@ -1,100 +0,0 @@ -package orm - -import ( - "fmt" - "reflect" - "strings" -) - -const ( - tagKey = "orm" -) - -var ( - typeTableName = reflect.TypeOf(TableName("")) -) - -type ( - //TableName field of table name - TableName string - ormModel struct { - Table TableName - Path string - Type reflect.Type - Origin reflect.Value - Fields []ormModelField - } - ormModelField struct { - Name string - Col string - Val interface{} - Empty bool - } -) - -func parseModel(v interface{}) *ormModel { - val := reflect.ValueOf(v) - ref := val.Type() - - switch ref.Kind() { - case reflect.Struct: - return decodeType(ref, val) - case reflect.Ptr: - return decodeType(ref.Elem(), val.Elem()) - } - - return nil -} - -func decodeType(t reflect.Type, v reflect.Value) *ormModel { - mod := &ormModel{ - Path: t.PkgPath() + ":" + t.Name(), - Type: t, - Fields: make([]ormModelField, 0), - } - - for i := 0; i < t.NumField(); i++ { - typ := t.Field(i) - val := v.FieldByName(typ.Name) - tag := typ.Tag.Get(tagKey) - - if typ.Type.AssignableTo(typeTableName) { - mod.Table = TableName(tag) - continue - } - - if len(tag) == 0 { - continue - } - - tags := strings.Split(tag, ";") - - field := ormModelField{ - Name: typ.Name, - Col: tags[0], - Val: getValueByType(val), - Empty: val.IsZero(), - } - - mod.Fields = append(mod.Fields, field) - } - - return mod -} - -func getValueByType(v reflect.Value) interface{} { - switch v.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return v.Int() - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return v.Uint() - case reflect.String: - return v.String() - case reflect.Float32, reflect.Float64: - return v.Float() - case reflect.Bool: - return v.Bool() - default: - panic(fmt.Sprintf("unknow type - %v", v.Kind())) - } -} diff --git a/models_test.go b/models_test.go deleted file mode 100644 index b5a764b..0000000 --- a/models_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package orm - -import ( - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func Test_parseModel(t *testing.T) { - type TestModel struct { - _ TableName `orm:"tests"` - ID int64 `orm:"id;index"` - Email string `orm:"email"` - T time.Duration `orm:"time"` - B byte - Link *int - } - - res := parseModel(TestModel{ - ID: 77884, - Email: "aaaaaaaa", - T: time.Minute, - B: 0, - }) - fmt.Println(res) - res = parseModel(&TestModel{}) - fmt.Println(res) - - require.Nil(t, nil) -} diff --git a/result.go b/result.go deleted file mode 100644 index cc65af1..0000000 --- a/result.go +++ /dev/null @@ -1,23 +0,0 @@ -package orm - -//Result model -type Result struct { - Err error - Rows int64 -} - -//Resulter interface -type Resulter interface { - Error() error - RowsAffected() int64 -} - -//Err ... -func (r *Result) Error() error { - return r.Err -} - -//RowsAffected ... -func (r *Result) RowsAffected() int64 { - return r.Rows -} diff --git a/stmt.go b/stmt.go index 019e13b..0d7d345 100644 --- a/stmt.go +++ b/stmt.go @@ -1,7 +1,6 @@ package orm import ( - "context" "database/sql" ) @@ -17,12 +16,6 @@ type ( Dialect() string Pool(string) (*sql.DB, error) } - //StmtInterface statement interface - StmtInterface interface { - Call(string, func(context.Context, *sql.DB) error) error - Tx(string, func(context.Context, *sql.Tx) error) error - Ping() error - } ) //newStmt init new statement diff --git a/stmt_exec.go b/stmt_exec.go new file mode 100644 index 0000000..c9cf7d1 --- /dev/null +++ b/stmt_exec.go @@ -0,0 +1,97 @@ +package orm + +import ( + "context" + "database/sql" + "sync" +) + +var poolExec = sync.Pool{New: func() interface{} { return &exec{} }} + +type exec struct { + Q string + P [][]interface{} + B func(result Result) error +} + +func (v *exec) SQL(query string, args ...interface{}) { + v.Q = query + v.Params(args...) +} + +func (v *exec) Params(args ...interface{}) { + if len(args) > 0 { + v.P = append(v.P, args) + } +} +func (v *exec) Bind(call func(result Result) error) { + v.B = call +} + +func (v *exec) Reset() *exec { + v.Q, v.P, v.B = "", nil, nil + return v +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +type ( + //Result exec result model + Result struct { + RowsAffected int64 + LastInsertId int64 + } + //Executor interface for generate execute query + Executor interface { + SQL(query string, args ...interface{}) + Params(args ...interface{}) + Bind(call func(result Result) error) + } +) + +//ExecContext ... +func (s *Stmt) ExecContext(name string, ctx context.Context, call func(q Executor)) error { + return s.CallContext(name, ctx, func(ctx context.Context, db *sql.DB) error { + return callExecContext(ctx, db, call) + }) +} + +func callExecContext(ctx context.Context, db dbGetter, call func(q Executor)) error { + q, ok := poolExec.Get().(*exec) + if !ok { + return ErrInvalidModelPool + } + defer poolExec.Put(q.Reset()) + + call(q) + + stmt, err := db.PrepareContext(ctx, q.Q) + if err != nil { + return err + } + defer stmt.Close() //nolint: errcheck + var total Result + for _, params := range q.P { + result, err0 := stmt.Exec(params...) + if err0 != nil { + return err0 + } + rows, err0 := result.RowsAffected() + if err0 != nil { + return err0 + } + total.RowsAffected += rows + rows, err0 = result.LastInsertId() + if err0 != nil { + return err0 + } + total.LastInsertId = rows + } + if err = stmt.Close(); err != nil { + return err + } + if q.B == nil { + return nil + } + return q.B(total) +} diff --git a/stmt_ping.go b/stmt_ping.go deleted file mode 100644 index a2d37f3..0000000 --- a/stmt_ping.go +++ /dev/null @@ -1,13 +0,0 @@ -package orm - -import ( - "context" - "database/sql" -) - -//Ping database ping -func (s *Stmt) Ping() error { - return s.Call("ping", func(ctx context.Context, db *sql.DB) error { - return db.PingContext(ctx) - }) -} diff --git a/stmt_query.go b/stmt_query.go new file mode 100644 index 0000000..6ac4776 --- /dev/null +++ b/stmt_query.go @@ -0,0 +1,80 @@ +package orm + +import ( + "context" + "database/sql" + "sync" +) + +var poolQuery = sync.Pool{New: func() interface{} { return &query{} }} + +type query struct { + Q string + P []interface{} + B func(bind Scanner) error +} + +func (v *query) SQL(query string, args ...interface{}) { + v.Q, v.P = query, args +} + +func (v *query) Bind(call func(bind Scanner) error) { + v.B = call +} + +func (v *query) Reset() *query { + v.Q, v.P, v.B = "", nil, nil + return v +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +type ( + //Scanner interface for bind data + Scanner interface { + Scan(args ...interface{}) error + } + + //Querier interface for generate query + Querier interface { + SQL(query string, args ...interface{}) + Bind(call func(bind Scanner) error) + } +) + +//QueryContext ... +func (s *Stmt) QueryContext(name string, ctx context.Context, call func(q Querier)) error { + return s.CallContext(name, ctx, func(ctx context.Context, db *sql.DB) error { + return callQueryContext(ctx, db, call) + }) +} + +func callQueryContext(ctx context.Context, db dbGetter, call func(q Querier)) error { + q, ok := poolQuery.Get().(*query) + if !ok { + return ErrInvalidModelPool + } + defer poolQuery.Put(q.Reset()) + + call(q) + + rows, err := db.QueryContext(ctx, q.Q, q.P...) + if err != nil { + return err + } + defer rows.Close() //nolint: errcheck + if q.B != nil { + for rows.Next() { + if err = q.B(rows); err != nil { + return err + } + } + } + if err = rows.Close(); err != nil { + return err + } + if err = rows.Err(); err != nil { + return err + } + return nil +} diff --git a/stmt_raw.go b/stmt_raw.go index 86e534e..97c75d2 100644 --- a/stmt_raw.go +++ b/stmt_raw.go @@ -3,13 +3,19 @@ package orm import ( "context" "database/sql" + + "github.com/deweppro/go-errors" ) -//Call basic query execution -func (s *Stmt) Call(name string, callFunc func(context.Context, *sql.DB) error) error { - ctx, cncl := context.WithCancel(context.Background()) - defer cncl() +//Ping database ping +func (s *Stmt) Ping() error { + return s.CallContext("ping", context.Background(), func(ctx context.Context, db *sql.DB) error { + return db.PingContext(ctx) + }) +} +//CallContext basic query execution +func (s *Stmt) CallContext(name string, ctx context.Context, callFunc func(context.Context, *sql.DB) error) error { pool, err := s.db.Pool(s.name) if err != nil { return err @@ -20,20 +26,22 @@ func (s *Stmt) Call(name string, callFunc func(context.Context, *sql.DB) error) return err } -//Tx the basic execution of a query in a transaction -func (s *Stmt) Tx(name string, callFunc func(context.Context, *sql.Tx) error) error { - return s.Call(name, func(ctx context.Context, db *sql.DB) error { - tx, err := db.BeginTx(ctx, nil) +//TxContext the basic execution of a query in a transaction +func (s *Stmt) TxContext(name string, ctx context.Context, callFunc func(context.Context, *sql.Tx) error) error { + return s.CallContext(name, ctx, func(ctx context.Context, db *sql.DB) error { + dbx, err := db.BeginTx(ctx, nil) if err != nil { return err } - defer func() { - if err := tx.Rollback(); err != nil { - s.plug.Logger.Errorf("tx rollback: %s", err.Error()) - } - }() + err = callFunc(ctx, dbx) + if err != nil { + return errors.Wrap( + errors.WrapMessage(err, "execute tx"), + errors.WrapMessage(dbx.Rollback(), "rollback tx"), + ) + } - return callFunc(ctx, tx) + return dbx.Commit() }) } diff --git a/stmt_test.go b/stmt_test.go new file mode 100644 index 0000000..e741e37 --- /dev/null +++ b/stmt_test.go @@ -0,0 +1,191 @@ +package orm_test + +import ( + "context" + "database/sql" + "io/ioutil" + "os" + "testing" + + "github.com/deweppro/go-orm" + "github.com/deweppro/go-orm/schema/sqlite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnit_Stmt(t *testing.T) { + file, err := ioutil.TempFile("/tmp", "prefix") + require.NoError(t, err) + defer os.Remove(file.Name()) //nolint: errcheck + + conn := sqlite.New(&sqlite.Config{Pool: []sqlite.Item{{Name: "main", File: file.Name()}}}) + require.NoError(t, conn.Reconnect()) + defer conn.Close() //nolint: errcheck + pool := orm.NewDB(conn, orm.Plugins{}).Pool("main") + + err = pool.CallContext("init", context.Background(), func(ctx context.Context, db *sql.DB) error { + sqls := []string{ + `create table users ( + id INTEGER PRIMARY KEY AUTOINCREMENT UNIQUE, + name TEXT + );`, + "insert into `users` (`id`, `name`) values (1, 'aaaa');", + "insert into `users` (`id`, `name`) values (2, 'bbbb');", + } + + for _, item := range sqls { + if _, err = db.ExecContext(ctx, item); err != nil { + return err + } + } + return nil + }) + require.NoError(t, err) + + err = pool.QueryContext("", context.Background(), func(q orm.Querier) { + q.SQL("select `name` from `users` where `id` = ?", 1) + q.Bind(func(bind orm.Scanner) error { + name := "" + assert.NoError(t, bind.Scan(&name)) + assert.Equal(t, "aaaa", name) + return nil + }) + }) + assert.NoError(t, err) + + var result []string + err = pool.QueryContext("", context.Background(), func(q orm.Querier) { + q.SQL("select `name` from `users`") + q.Bind(func(bind orm.Scanner) error { + name := "" + assert.NoError(t, bind.Scan(&name)) + result = append(result, name) + return nil + }) + }) + assert.NoError(t, err) + assert.Equal(t, []string{"aaaa", "bbbb"}, result) + + err = pool.ExecContext("", context.Background(), func(e orm.Executor) { + e.SQL("insert into `users` (`id`, `name`) values (?, ?);") + e.Params(3, "cccc") + e.Params(4, "dddd") + + e.Bind(func(result orm.Result) error { + assert.Equal(t, int64(2), result.RowsAffected) + assert.Equal(t, int64(4), result.LastInsertId) + return nil + }) + }) + assert.NoError(t, err) + + var result2 []string + err = pool.QueryContext("", context.Background(), func(q orm.Querier) { + q.SQL("select `name` from `users`") + q.Bind(func(bind orm.Scanner) error { + name := "" + err = bind.Scan(&name) + result2 = append(result2, name) + return err + }) + }) + assert.NoError(t, err) + assert.Equal(t, []string{"aaaa", "bbbb", "cccc", "dddd"}, result2) + + var result3 []string + err = pool.TransactionContext("", context.Background(), func(v orm.Tx) { + v.Exec(func(e orm.Executor) { + e.SQL("insert into `users` (`id`, `name`) values (?, ?);") + e.Params(10, "abcd") + e.Params(11, "efgh") + e.Bind(func(result orm.Result) error { + assert.Equal(t, int64(2), result.RowsAffected) + assert.Equal(t, int64(11), result.LastInsertId) + return nil + }) + }) + v.Query(func(q orm.Querier) { + q.SQL("select `name` from `users`") + q.Bind(func(bind orm.Scanner) error { + name := "" + err = bind.Scan(&name) + result3 = append(result3, name) + return err + }) + }) + }) + assert.NoError(t, err) + assert.Equal(t, []string{"aaaa", "bbbb", "cccc", "dddd", "abcd", "efgh"}, result3) + + var result4 []string + err = pool.QueryContext("", context.Background(), func(q orm.Querier) { + q.SQL("select `name` from `users`") + q.Bind(func(bind orm.Scanner) error { + name := "" + err = bind.Scan(&name) + result4 = append(result4, name) + return err + }) + }) + assert.NoError(t, err) + assert.Equal(t, []string{"aaaa", "bbbb", "cccc", "dddd", "abcd", "efgh"}, result4) +} + +func BenchmarkStmt(b *testing.B) { + file, err := ioutil.TempFile("/tmp", "prefix") + require.NoError(b, err) + defer os.Remove(file.Name()) //nolint: errcheck + + conn := sqlite.New(&sqlite.Config{Pool: []sqlite.Item{{Name: "main", File: file.Name()}}}) + require.NoError(b, conn.Reconnect()) + defer conn.Close() //nolint: errcheck + pool := orm.NewDB(conn, orm.Plugins{}).Pool("main") + + err = pool.CallContext("init", context.Background(), func(ctx context.Context, db *sql.DB) error { + sqls := []string{ + `create table users ( + id INTEGER PRIMARY KEY AUTOINCREMENT UNIQUE, + name TEXT + );`, + } + + for _, item := range sqls { + if _, err = db.ExecContext(ctx, item); err != nil { + return err + } + } + return nil + }) + require.NoError(b, err) + + b.Run("insert", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + for i := 1; i < b.N; i++ { + err = pool.ExecContext("", context.Background(), func(e orm.Executor) { + i := i + e.SQL("insert or ignore into `users` (`id`, `name`) values (?, ?);") + e.Params(i, "cccc") + }) + assert.NoError(b, err) + } + }) + + var name string + b.Run("select", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + for i := 1; i < b.N; i++ { + err = pool.QueryContext("", context.Background(), func(q orm.Querier) { + i := i + q.SQL("select `name` from `users` where `id` = ?", i) + q.Bind(func(bind orm.Scanner) error { + return bind.Scan(&name) + }) + }) + assert.NoError(b, err) + } + }) +} diff --git a/stmt_tx.go b/stmt_tx.go new file mode 100644 index 0000000..0f22319 --- /dev/null +++ b/stmt_tx.go @@ -0,0 +1,72 @@ +package orm + +import ( + "context" + "database/sql" + "fmt" + "sync" +) + +var poolTx = sync.Pool{New: func() interface{} { return &tx{} }} + +type ( + Tx interface { + Exec(vv ...func(e Executor)) + Query(vv ...func(q Querier)) + } + + tx struct { + v []interface{} + } + + dbGetter interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + } +) + +func (v *tx) Exec(vv ...func(q Executor)) { + for _, f := range vv { + v.v = append(v.v, f) + } +} + +func (v *tx) Query(vv ...func(q Querier)) { + for _, f := range vv { + v.v = append(v.v, f) + } +} + +func (v *tx) Reset() *tx { + v.v = v.v[:0] + return v +} + +func (s *Stmt) TransactionContext(name string, ctx context.Context, call func(v Tx)) error { + q, ok := poolTx.Get().(*tx) + if !ok { + return ErrInvalidModelPool + } + defer poolTx.Put(q.Reset()) + + call(q) + + return s.TxContext(name, ctx, func(ctx context.Context, tx *sql.Tx) error { + for i, c := range q.v { + if cc, ok := c.(func(q Executor)); ok { + if err := callExecContext(ctx, tx, cc); err != nil { + return err + } + continue + } + if cc, ok := c.(func(q Querier)); ok { + if err := callQueryContext(ctx, tx, cc); err != nil { + return err + } + continue + } + return fmt.Errorf("unknown query model #%d", i) + } + return nil + }) +} diff --git a/types/soft_delete.go b/types/soft_delete.go deleted file mode 100644 index 72fdaa2..0000000 --- a/types/soft_delete.go +++ /dev/null @@ -1,4 +0,0 @@ -package types - -//DeletedAt soft deleting model -type DeletedAt TimeAt