Skip to content

Commit

Permalink
feat: 添加 Context 相关方法
Browse files Browse the repository at this point in the history
  • Loading branch information
caixw committed Apr 9, 2024
1 parent 128a098 commit e50cbd0
Show file tree
Hide file tree
Showing 10 changed files with 222 additions and 80 deletions.
8 changes: 2 additions & 6 deletions core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,20 @@ type ConstraintType int8
// - {} 符号会被替换为 [Dialect.Quotes] 对应的符号;
// - # 会被替换为 [Engine.TablePrefix] 的返回值;
type Engine interface {
Dialect() Dialect

Query(query string, args ...any) (*sql.Rows, error)

QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)

QueryRow(query string, args ...any) *sql.Row

QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row

Exec(query string, args ...any) (sql.Result, error)

ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)

Prepare(query string) (*Stmt, error)

PrepareContext(ctx context.Context, query string) (*Stmt, error)

Dialect() Dialect

// TablePrefix 所有数据表拥有的统一表名前缀
//
// 当需要在一个数据库中创建不同的实例,
Expand Down
58 changes: 48 additions & 10 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,35 +92,73 @@ func (db *DB) Close() error {
// Version 数据库服务端的版本号
func (db *DB) Version() string { return db.version }

func (db *DB) LastInsertID(v TableNamer) (int64, error) { return lastInsertID(db, v) }
func (db *DB) LastInsertID(v TableNamer) (int64, error) {
return db.LastInsertIDContext(context.Background(), v)
}

func (db *DB) LastInsertIDContext(ctx context.Context, v TableNamer) (int64, error) {
return lastInsertID(ctx, db, v)
}

// Insert 插入数据
//
// NOTE: 若需一次性插入多条数据,请使用 [Tx.InsertMany]。
func (db *DB) Insert(v TableNamer) (sql.Result, error) { return insert(db, v) }
func (db *DB) Insert(v TableNamer) (sql.Result, error) {
return db.InsertContext(context.Background(), v)
}

func (db *DB) InsertContext(ctx context.Context, v TableNamer) (sql.Result, error) {
return insert(ctx, db, v)
}

func (db *DB) Delete(v TableNamer) (sql.Result, error) { return del(db, v) }
func (db *DB) Delete(v TableNamer) (sql.Result, error) {
return db.DeleteContext(context.Background(), v)
}

func (db *DB) Update(v TableNamer, cols ...string) (sql.Result, error) { return update(db, v, cols...) }
func (db *DB) DeleteContext(ctx context.Context, v TableNamer) (sql.Result, error) {
return del(ctx, db, v)
}

func (db *DB) Select(v TableNamer) (bool, error) { return find(db, v) }
func (db *DB) Update(v TableNamer, cols ...string) (sql.Result, error) {
return db.UpdateContext(context.Background(), v, cols...)
}

func (db *DB) Create(v TableNamer) error { return create(db, v) }
func (db *DB) UpdateContext(ctx context.Context, v TableNamer, cols ...string) (sql.Result, error) {
return update(ctx, db, v, cols...)
}

func (db *DB) Drop(v TableNamer) error { return drop(db, v) }
func (db *DB) Select(v TableNamer) (bool, error) { return db.SelectContext(context.Background(), v) }

func (db *DB) SelectContext(ctx context.Context, v TableNamer) (bool, error) { return find(ctx, db, v) }

func (db *DB) Create(v TableNamer) error { return db.CreateContext(context.Background(), v) }

func (db *DB) CreateContext(ctx context.Context, v TableNamer) error { return create(ctx, db, v) }

func (db *DB) Drop(v TableNamer) error { return db.DropContext(context.Background(), v) }

func (db *DB) DropContext(ctx context.Context, v TableNamer) error { return drop(ctx, db, v) }

func (db *DB) Truncate(v TableNamer) error {
return db.TruncateContext(context.Background(), v)
}

func (db *DB) TruncateContext(ctx context.Context, v TableNamer) error {
if !db.Dialect().TransactionalDDL() {
return truncate(db, v)
return truncate(ctx, db, v)
}
return db.DoTransaction(func(tx *Tx) error { return truncate(tx, v) })
return db.DoTransaction(func(tx *Tx) error { return truncate(ctx, tx, v) })
}

// InsertMany 一次插入多条数据
//
// 会自动转换成事务进行处理。
func (db *DB) InsertMany(max int, v ...TableNamer) error {
return db.DoTransaction(func(tx *Tx) error { return tx.InsertMany(max, v...) })
return db.InsertManyContext(context.Background(), max, v...)
}

func (db *DB) InsertManyContext(ctx context.Context, max int, v ...TableNamer) error {
return db.DoTransaction(func(tx *Tx) error { return tx.InsertManyContext(ctx, max, v...) })
}

func (db *DB) SQLBuilder() *sqlbuilder.SQLBuilder { return db.sqlBuilder }
Expand Down
45 changes: 23 additions & 22 deletions sqlbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package orm

import (
"context"
"database/sql"
"errors"
"fmt"
Expand Down Expand Up @@ -100,14 +101,14 @@ func getKV(rval reflect.Value, cols ...*core.Column) (keys []string, vals []any)
}

// 创建表或是视图
func create(e Engine, v TableNamer) error {
func create(ctx context.Context, e Engine, v TableNamer) error {
m, _, err := getModel(e, v)
if err != nil {
return err
}

if m.Type == core.View {
return createView(e, m)
return createView(ctx, e, m)
}

sb := e.SQLBuilder().CreateTable().Table(m.Name)
Expand Down Expand Up @@ -152,19 +153,19 @@ func create(e Engine, v TableNamer) error {
sb.PK(constraintName(m.Name, m.PrimaryKey.Name), cols...)
}

return sb.Exec()
return sb.ExecContext(ctx)
}

func createView(e Engine, m *core.Model) error {
func createView(ctx context.Context, e Engine, m *core.Model) error {
stmt := e.SQLBuilder().CreateView().Name(m.Name)

for _, col := range m.Columns {
stmt.Column(col.Name)
}
return stmt.FromQuery(m.ViewAs).Exec()
return stmt.FromQuery(m.ViewAs).ExecContext(ctx)
}

func truncate(e Engine, v TableNamer) error {
func truncate(ctx context.Context, e Engine, v TableNamer) error {
m, err := e.newModel(v)
if err != nil {
return err
Expand All @@ -181,24 +182,24 @@ func truncate(e Engine, v TableNamer) error {
stmt.Table(m.Name, "")
}

return stmt.Exec()
return stmt.ExecContext(ctx)
}

// 删除表或视图
func drop(e Engine, v TableNamer) error {
func drop(ctx context.Context, e Engine, v TableNamer) error {
m, err := e.newModel(v)
if err != nil {
return err
}

if m.Type == core.View {
return e.SQLBuilder().DropView().Name(m.Name).Exec()
return e.SQLBuilder().DropView().Name(m.Name).ExecContext(ctx)
}

return e.SQLBuilder().DropTable().Table(m.Name).Exec()
return e.SQLBuilder().DropTable().Table(m.Name).ExecContext(ctx)
}

func lastInsertID(e Engine, v TableNamer) (int64, error) {
func lastInsertID(ctx context.Context, e Engine, v TableNamer) (int64, error) {
m, rval, err := getModel(e, v)
if err != nil {
return 0, err
Expand Down Expand Up @@ -233,10 +234,10 @@ func lastInsertID(e Engine, v TableNamer) (int64, error) {
stmt.KeyValue(col.Name, field.Interface())
}

return stmt.LastInsertID(m.Name, m.AutoIncrement.Name)
return stmt.LastInsertIDContext(ctx, m.AutoIncrement.Name)
}

func insert(e Engine, v TableNamer) (sql.Result, error) {
func insert(ctx context.Context, e Engine, v TableNamer) (sql.Result, error) {
m, rval, err := getModel(e, v)
if err != nil {
return nil, err
Expand Down Expand Up @@ -267,14 +268,14 @@ func insert(e Engine, v TableNamer) (sql.Result, error) {
stmt.KeyValue(col.Name, field.Interface())
}

return stmt.Exec()
return stmt.ExecContext(ctx)
}

// 查找数据
//
// 根据 v 的 pk 或中唯一索引列查找一行数据,并赋值给 v。
// 若 v 为空,则不发生任何操作,v 可以是数组。
func find(e Engine, v TableNamer) (bool, error) {
func find(ctx context.Context, e Engine, v TableNamer) (bool, error) {
m, rval, err := getModel(e, v)
if err != nil {
return false, err
Expand All @@ -285,15 +286,15 @@ func find(e Engine, v TableNamer) (bool, error) {
return false, err
}

size, err := stmt.QueryObject(true, v)
size, err := stmt.QueryObjectContext(ctx, true, v)
if err != nil {
return false, err
}
return size > 0, nil
}

// for update 只能作用于事务
func forUpdate(tx *Tx, v TableNamer) error {
func forUpdate(ctx context.Context, tx *Tx, v TableNamer) error {
m, rval, err := getModel(tx, v)
if err != nil {
return err
Expand All @@ -314,7 +315,7 @@ func forUpdate(tx *Tx, v TableNamer) error {
return err
}

_, err = stmt.QueryObject(true, v)
_, err = stmt.QueryObjectContext(ctx, true, v)
return err
}

Expand All @@ -323,7 +324,7 @@ func forUpdate(tx *Tx, v TableNamer) error {
//
// 更新依据为每个对象的主键或是唯一索引列。
// 若不存在此两个类型的字段,则返回错误信息。
func update(e Engine, v TableNamer, cols ...string) (sql.Result, error) {
func update(ctx context.Context, e Engine, v TableNamer, cols ...string) (sql.Result, error) {
stmt := e.SQLBuilder().Update()

m, rval, err := getUpdateColumns(e, v, stmt, cols...)
Expand All @@ -335,7 +336,7 @@ func update(e Engine, v TableNamer, cols ...string) (sql.Result, error) {
return nil, err
}

return stmt.Exec()
return stmt.ExecContext(ctx)
}

func getUpdateColumns(e Engine, v TableNamer, stmt *sqlbuilder.UpdateStmt, cols ...string) (*core.Model, reflect.Value, error) {
Expand Down Expand Up @@ -378,7 +379,7 @@ func getUpdateColumns(e Engine, v TableNamer, stmt *sqlbuilder.UpdateStmt, cols
}

// 将 v 生成 delete 的 sql 语句
func del(e Engine, v TableNamer) (sql.Result, error) {
func del(ctx context.Context, e Engine, v TableNamer) (sql.Result, error) {
m, rval, err := getModel(e, v)
if err != nil {
return nil, err
Expand All @@ -393,7 +394,7 @@ func del(e Engine, v TableNamer) (sql.Result, error) {
return nil, err
}

return stmt.Exec()
return stmt.ExecContext(ctx)
}

var errInsertManyHasDifferentType = errors.New("InsertMany 必须是相同的数据类型")
Expand Down
2 changes: 1 addition & 1 deletion sqlbuilder/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func (stmt *InsertStmt) fromSelect(builder *core.Builder) (string, []any, error)
// 并根据表名和自增列 ID 返回当前行的自增 ID 值。
//
// NOTE: 对于指定了自增值的,其结果是未知的。
func (stmt *InsertStmt) LastInsertID(table, col string) (int64, error) {
func (stmt *InsertStmt) LastInsertID(col string) (int64, error) {
return stmt.LastInsertIDContext(context.Background(), col)
}

Expand Down
36 changes: 26 additions & 10 deletions sqlbuilder/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,25 +437,37 @@ func (stmt *SelectStmt) Union(all bool, sel ...*SelectStmt) *SelectStmt {
//
// 关于 objs 的类型,可以参考 [fetch.Object] 函数的相关介绍。
func (stmt *SelectStmt) QueryObject(strict bool, objs any) (size int, err error) {
rows, err := stmt.Query()
return stmt.QueryObjectContext(context.Background(), strict, objs)
}

func (stmt *SelectStmt) QueryObjectContext(ctx context.Context, strict bool, objs any) (size int, err error) {
rows, err := stmt.QueryContext(ctx)
if err != nil {
return 0, err
}
return queryObject(rows, strict, objs)
return fetchObject(rows, strict, objs)
}

// QueryString 查询指定列的第一行数据,并将其转换成 string
func (stmt *SelectStmt) QueryString(colName string) (v string, err error) {
rows, err := stmt.Query()
return stmt.QueryStringContext(context.Background(), colName)
}

func (stmt *SelectStmt) QueryStringContext(ctx context.Context, colName string) (v string, err error) {
rows, err := stmt.QueryContext(ctx)
if err != nil {
return "", err
}
return queryString(rows, colName)
return fetchString(rows, colName)
}

// QueryFloat 查询指定列的第一行数据,并将其转换成 float64
func (stmt *SelectStmt) QueryFloat(colName string) (float64, error) {
v, err := stmt.QueryString(colName)
return stmt.QueryFloatContext(context.Background(), colName)
}

func (stmt *SelectStmt) QueryFloatContext(ctx context.Context, colName string) (float64, error) {
v, err := stmt.QueryStringContext(ctx, colName)
if err != nil {
return 0, err
}
Expand All @@ -465,10 +477,14 @@ func (stmt *SelectStmt) QueryFloat(colName string) (float64, error) {

// QueryInt 查询指定列的第一行数据,并将其转换成 int64
func (stmt *SelectStmt) QueryInt(colName string) (int64, error) {
return stmt.QueryIntContext(context.Background(), colName)
}

func (stmt *SelectStmt) QueryIntContext(ctx context.Context, colName string) (int64, error) {
// NOTE: 可能会出现浮点数的情况。比如:
// select avg(xx) as avg form xxx where xxx
// 查询 avg 的值可能是 5.000 等值。
v, err := stmt.QueryString(colName)
v, err := stmt.QueryStringContext(ctx, colName)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -503,7 +519,7 @@ func (stmt *SelectQuery) QueryObject(strict bool, objs any, arg ...any) (size in
if err != nil {
return 0, err
}
return queryObject(rows, strict, objs)
return fetchObject(rows, strict, objs)
}

// QueryString 查询指定列的第一行数据,并将其转换成 string
Expand All @@ -512,7 +528,7 @@ func (stmt *SelectQuery) QueryString(colName string, arg ...any) (v string, err
if err != nil {
return "", err
}
return queryString(rows, colName)
return fetchString(rows, colName)
}

// QueryFloat 查询指定列的第一行数据,并将其转换成 float64
Expand Down Expand Up @@ -540,13 +556,13 @@ func (stmt *SelectQuery) QueryInt(colName string, arg ...any) (int64, error) {

func (stmt *SelectQuery) Close() error { return stmt.stmt.Close() }

func queryObject(rows *sql.Rows, strict bool, objs any) (size int, err error) {
func fetchObject(rows *sql.Rows, strict bool, objs any) (size int, err error) {
defer func() { err = errors.Join(err, rows.Close()) }()
size, err = fetch.Object(strict, rows, objs)
return
}

func queryString(rows *sql.Rows, colName string) (v string, err error) {
func fetchString(rows *sql.Rows, colName string) (v string, err error) {
defer func() { err = errors.Join(err, rows.Close()) }()

cols, err := fetch.ColumnString(true, colName, rows)
Expand Down
Loading

0 comments on commit e50cbd0

Please sign in to comment.