Skip to content
This repository has been archived by the owner on Jul 14, 2022. It is now read-only.

Feature/ezmsg #87

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 71 additions & 5 deletions orm/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,15 @@ type DBTx struct {
err error
rowsAffected int64
wrappers []database.Wrapper
afterCommit func(err error)
}

func (store *DBStore) BeginTx() (*DBTx, error) {
tx, err := store.Begin()
return store.BeginTxContext(context.Background())
}

func (store *DBStore) BeginTxContext(ctx context.Context) (*DBTx, error) {
tx, err := store.DB.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
Expand All @@ -198,10 +203,18 @@ func (tx *DBTx) Close() error {
if tx.err != nil {
return tx.tx.Rollback()
}
return tx.tx.Commit()
err := tx.tx.Commit()
if tx.afterCommit != nil {
tx.afterCommit(err)
}
return err
}

func (tx *DBTx) Query(sql string, args ...interface{}) (*sql.Rows, error) {
return tx.QueryContext(context.Background(), sql, args...)
}

func (tx *DBTx) queryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) {
t1 := time.Now()
if tx.slowlog > 0 {
defer func(t time.Time) {
Expand All @@ -214,14 +227,18 @@ func (tx *DBTx) Query(sql string, args ...interface{}) (*sql.Rows, error) {
if tx.debug {
log.Println("DEBUG: ", sql, args)
}
result, err := tx.tx.Query(sql, args...)
result, err := tx.tx.QueryContext(ctx, sql, args...)
if err != nil {
tx.err = err
}
return result, tx.err
}

func (tx *DBTx) Exec(sql string, args ...interface{}) (sql.Result, error) {
return tx.ExecContext(context.Background(), sql, args...)
}

func (tx *DBTx) execContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) {
t1 := time.Now()
if tx.slowlog > 0 {
defer func(t time.Time) {
Expand All @@ -244,7 +261,7 @@ func (tx *DBTx) Exec(sql string, args ...interface{}) (sql.Result, error) {
func (tx *DBTx) QueryContext(ctx context.Context, query string,
args ...interface{}) (*sql.Rows, error) {
fn := func(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
return tx.tx.QueryContext(ctx, query, args...)
return tx.queryContext(ctx, query, args...)
}
for _, wp := range tx.wrappers {
fn = wp.WrapQueryContext(fn, query, args...)
Expand All @@ -255,7 +272,7 @@ func (tx *DBTx) QueryContext(ctx context.Context, query string,
func (tx *DBTx) ExecContext(ctx context.Context, query string,
args ...interface{}) (sql.Result, error) {
fn := func(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return tx.tx.ExecContext(ctx, query, args...)
return tx.execContext(ctx, query, args...)
}
for _, wp := range tx.wrappers {
fn = wp.WrapExecContext(fn, query, args...)
Expand All @@ -267,6 +284,14 @@ func (tx *DBTx) SetError(err error) {
tx.err = err
}

func (tx *DBTx) AfterCommit(afterCommit func(err error)) {
tx.afterCommit = afterCommit
}

func (tx *DBTx) GetStdTx() *sql.Tx {
return tx.tx
}

func TransactFunc(db *DBStore, txFunc func(*DBTx) error) (err error) {
tx, err := db.BeginTx()
if err != nil {
Expand All @@ -289,10 +314,51 @@ func TransactFunc(db *DBStore, txFunc func(*DBTx) error) (err error) {
return err
}

func TransactFuncContext(ctx context.Context, db *DBStore, txFunc func(ctx context.Context, tx *DBTx) error) (err error) {
tx, err := db.BeginTxContext(ctx)
if err != nil {
return err
}
defer func() {
if p := recover(); p != nil {
tx.SetError(fmt.Errorf("panic: %v", p))
tx.Close()
panic(p)
} else if err != nil {
tx.SetError(err)
tx.Close()
} else {
err = tx.Close()
}
}()

err = txFunc(ctx, tx)
return err
}

type Transactor interface {
Transact(tx *DBTx) error
}

func Transact(db *DBStore, t Transactor) error {
return TransactFunc(db, t.Transact)
}

type TransactorWithContext interface {
TransactContext(ctx context.Context, tx *DBTx) error
}

func TransactContext(ctx context.Context, db *DBStore, t TransactorWithContext) error {
return TransactFuncContext(ctx, db, t.TransactContext)
}

func BeginTx(ctx context.Context, db *sql.DB) (*DBTx, error) {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}

return &DBTx{
tx: tx,
}, nil
}