diff --git a/gorm/transaction.go b/gorm/transaction.go index faa83ff1..197635cc 100644 --- a/gorm/transaction.go +++ b/gorm/transaction.go @@ -69,7 +69,7 @@ func BeginFromContext(ctx context.Context) (*gorm.DB, error) { if txn.parent == nil { return nil, ErrCtxTxnNoDB } - db := txn.Begin() + db := txn.beginWithContext(ctx) if db.Error != nil { return nil, db.Error } @@ -89,7 +89,7 @@ func BeginWithOptionsFromContext(ctx context.Context, opts *sql.TxOptions) (*gor if txn.parent == nil { return nil, ErrCtxTxnNoDB } - db := txn.BeginWithOptions(opts) + db := txn.beginWithContextAndOptions(ctx, opts) if db.Error != nil { return nil, db.Error } @@ -99,11 +99,15 @@ func BeginWithOptionsFromContext(ctx context.Context, opts *sql.TxOptions) (*gor // Begin starts new transaction by calling `*gorm.DB.Begin()` // Returns new instance of `*gorm.DB` (error can be checked by `*gorm.DB.Error`) func (t *Transaction) Begin() *gorm.DB { + return t.beginWithContext(context.Background()) +} + +func (t *Transaction) beginWithContext(ctx context.Context) *gorm.DB { t.mu.Lock() defer t.mu.Unlock() if t.current == nil { - t.current = t.parent.Begin() + t.current = t.parent.BeginTx(ctx, nil) } return t.current @@ -112,16 +116,21 @@ func (t *Transaction) Begin() *gorm.DB { // BeginWithOptions starts new transaction by calling `*gorm.DB.BeginTx()` // Returns new instance of `*gorm.DB` (error can be checked by `*gorm.DB.Error`) func (t *Transaction) BeginWithOptions(opts *sql.TxOptions) *gorm.DB { + return t.beginWithContextAndOptions(context.Background(), opts) +} + +func (t *Transaction) beginWithContextAndOptions(ctx context.Context, opts *sql.TxOptions) *gorm.DB { t.mu.Lock() defer t.mu.Unlock() if t.current == nil { - t.current = t.parent.BeginTx(context.Background(), opts) + t.current = t.parent.BeginTx(ctx, opts) } return t.current } + // Rollback terminates transaction by calling `*gorm.DB.Rollback()` // Reset current transaction and returns an error if any. func (t *Transaction) Rollback() error { diff --git a/gorm/transaction_test.go b/gorm/transaction_test.go index 8625a92d..208332f8 100644 --- a/gorm/transaction_test.go +++ b/gorm/transaction_test.go @@ -369,6 +369,7 @@ func TestBeginFromContext_Bad(t *testing.T) { tests := []struct { desc string withOpts bool + contextCanceled bool }{ { desc: "begin without options", @@ -378,11 +379,26 @@ func TestBeginFromContext_Bad(t *testing.T) { desc: "begin with options", withOpts: true, }, + { + desc: "canceled context without context", + withOpts: true, + contextCanceled: true, + }, + { + desc: "canceled context with options", + withOpts: false, + contextCanceled: true, + }, } for _, test := range tests { test := test t.Run(test.desc, func(t *testing.T) { ctx := context.Background() + if test.contextCanceled { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + cancel() + } // Case: Transaction missing from context txn1, err := beginFromContext(ctx, test.withOpts)