Skip to content

Commit

Permalink
Add transaction context support
Browse files Browse the repository at this point in the history
  • Loading branch information
jackc committed May 20, 2017
1 parent 2df4b14 commit d1fd222
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 19 deletions.
6 changes: 3 additions & 3 deletions conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ func (p *ConnPool) QueryRowEx(ctx context.Context, sql string, options *QueryExO
// Begin acquires a connection and begins a transaction on it. When the
// transaction is closed the connection will be automatically released.
func (p *ConnPool) Begin() (*Tx, error) {
return p.BeginEx(nil)
return p.BeginEx(context.Background(), nil)
}

// Prepare creates a prepared statement on a connection in the pool to test the
Expand Down Expand Up @@ -499,14 +499,14 @@ func (p *ConnPool) Deallocate(name string) (err error) {
// BeginEx acquires a connection and starts a transaction with txOptions
// determining the transaction mode. When the transaction is closed the
// connection will be automatically released.
func (p *ConnPool) BeginEx(txOptions *TxOptions) (*Tx, error) {
func (p *ConnPool) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) {
for {
c, err := p.Acquire()
if err != nil {
return nil, err
}

tx, err := c.BeginEx(txOptions)
tx, err := c.BeginEx(ctx, txOptions)
if err != nil {
alive := c.IsAlive()
p.Release(c)
Expand Down
3 changes: 2 additions & 1 deletion conn_pool_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pgx_test

import (
"context"
"errors"
"fmt"
"net"
Expand Down Expand Up @@ -635,7 +636,7 @@ func TestConnPoolTransactionIso(t *testing.T) {
pool := createConnPool(t, 2)
defer pool.Close()

tx, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable})
tx, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
if err != nil {
t.Fatalf("pool.BeginEx failed: %v", err)
}
Expand Down
25 changes: 25 additions & 0 deletions pgmock/pgmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pgmock
import (
"errors"
"fmt"
"io"
"net"
"reflect"

Expand Down Expand Up @@ -38,6 +39,9 @@ func (s *Server) ServeOne() error {
if err != nil {
return err
}
defer conn.Close()

s.Close()

backend, err := pgproto3.NewBackend(conn, conn)
if err != nil {
Expand Down Expand Up @@ -167,6 +171,27 @@ func SendMessage(msg pgproto3.BackendMessage) Step {
return &sendMessageStep{msg: msg}
}

type waitForCloseMessageStep struct{}

func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
for {
msg, err := backend.Receive()
if err == io.EOF {
return nil
} else if err != nil {
return err
}

if _, ok := msg.(*pgproto3.Terminate); ok {
return nil
}
}
}

func WaitForClose() Step {
return &waitForCloseMessageStep{}
}

func AcceptUnauthenticatedConnRequestSteps() []Step {
return []Step{
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
Expand Down
2 changes: 1 addition & 1 deletion stdlib/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
pgxOpts.AccessMode = pgx.ReadOnly
}

return c.conn.BeginEx(&pgxOpts)
return c.conn.BeginEx(ctx, &pgxOpts)
}

func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
Expand Down
11 changes: 7 additions & 4 deletions stdlib/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,7 @@ func TestConnPingContextCancel(t *testing.T) {
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
script.Steps = append(script.Steps,
pgmock.ExpectMessage(&pgproto3.Query{String: ";"}),
pgmock.WaitForClose(),
)

server, err := pgmock.NewServer(script)
Expand All @@ -855,7 +856,7 @@ func TestConnPingContextCancel(t *testing.T) {
}
defer server.Close()

errChan := make(chan error)
errChan := make(chan error, 1)
go func() {
errChan <- server.ServeOne()
}()
Expand All @@ -864,7 +865,7 @@ func TestConnPingContextCancel(t *testing.T) {
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
// defer closeDB(t, db) // mock DB doesn't close correctly yet
defer closeDB(t, db)

ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)

Expand Down Expand Up @@ -900,6 +901,7 @@ func TestConnPrepareContextCancel(t *testing.T) {
pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select now()"}),
pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}),
pgmock.ExpectMessage(&pgproto3.Sync{}),
pgmock.WaitForClose(),
)

server, err := pgmock.NewServer(script)
Expand All @@ -917,7 +919,7 @@ func TestConnPrepareContextCancel(t *testing.T) {
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
// defer closeDB(t, db) // mock DB doesn't close correctly yet
defer closeDB(t, db)

ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)

Expand Down Expand Up @@ -950,6 +952,7 @@ func TestConnExecContextCancel(t *testing.T) {
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
script.Steps = append(script.Steps,
pgmock.ExpectMessage(&pgproto3.Query{String: "create temporary table exec_context_test(id serial primary key)"}),
pgmock.WaitForClose(),
)

server, err := pgmock.NewServer(script)
Expand All @@ -967,7 +970,7 @@ func TestConnExecContextCancel(t *testing.T) {
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
// defer closeDB(t, db) // mock DB doesn't close correctly yet
defer closeDB(t, db)

ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)

Expand Down
28 changes: 22 additions & 6 deletions tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package pgx

import (
"bytes"
"context"
"errors"
"fmt"
"time"
)

type TxIsoLevel string
Expand Down Expand Up @@ -56,12 +58,13 @@ var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback")
// Begin starts a transaction with the default transaction mode for the
// current connection. To use a specific transaction mode see BeginEx.
func (c *Conn) Begin() (*Tx, error) {
return c.BeginEx(nil)
return c.BeginEx(context.Background(), nil)
}

// BeginEx starts a transaction with txOptions determining the transaction
// mode.
func (c *Conn) BeginEx(txOptions *TxOptions) (*Tx, error) {
// mode. Unlike database/sql, the context only affects the begin command. i.e.
// there is no auto-rollback on context cancelation.
func (c *Conn) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) {
var beginSQL string
if txOptions == nil {
beginSQL = "begin"
Expand All @@ -81,8 +84,11 @@ func (c *Conn) BeginEx(txOptions *TxOptions) (*Tx, error) {
beginSQL = buf.String()
}

_, err := c.Exec(beginSQL)
_, err := c.ExecEx(ctx, beginSQL, nil)
if err != nil {
// begin should never fail unless there is an underlying connection issue or
// a context timeout. In either case, the connection is possibly broken.
c.die(errors.New("failed to begin transaction"))
return nil, err
}

Expand All @@ -102,11 +108,16 @@ type Tx struct {

// Commit commits the transaction
func (tx *Tx) Commit() error {
return tx.CommitEx(context.Background())
}

// CommitEx commits the transaction with a context.
func (tx *Tx) CommitEx(ctx context.Context) error {
if tx.status != TxStatusInProgress {
return ErrTxClosed
}

commandTag, err := tx.conn.Exec("commit")
commandTag, err := tx.conn.ExecEx(ctx, "commit", nil)
if err == nil && commandTag == "COMMIT" {
tx.status = TxStatusCommitSuccess
} else if err == nil && commandTag == "ROLLBACK" {
Expand All @@ -115,6 +126,8 @@ func (tx *Tx) Commit() error {
} else {
tx.status = TxStatusCommitFailure
tx.err = err
// A commit failure leaves the connection in an undefined state
tx.conn.die(errors.New("commit failed"))
}

if tx.connPool != nil {
Expand All @@ -133,11 +146,14 @@ func (tx *Tx) Rollback() error {
return ErrTxClosed
}

_, tx.err = tx.conn.Exec("rollback")
ctx, _ := context.WithTimeout(context.Background(), 15*time.Second)
_, tx.err = tx.conn.ExecEx(ctx, "rollback", nil)
if tx.err == nil {
tx.status = TxStatusRollbackSuccess
} else {
tx.status = TxStatusRollbackFailure
// A rollback failure leaves the connection in an undefined state
tx.conn.die(errors.New("rollback failed"))
}

if tx.connPool != nil {
Expand Down
112 changes: 108 additions & 4 deletions tx_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package pgx_test

import (
"context"
"fmt"
"testing"
"time"

"github.com/jackc/pgx"
"github.com/jackc/pgx/pgmock"
"github.com/jackc/pgx/pgproto3"
)

func TestTransactionSuccessfulCommit(t *testing.T) {
Expand Down Expand Up @@ -107,13 +112,13 @@ func TestTxCommitSerializationFailure(t *testing.T) {
}
defer pool.Exec(`drop table tx_serializable_sums`)

tx1, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable})
tx1, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
if err != nil {
t.Fatalf("BeginEx failed: %v", err)
}
defer tx1.Rollback()

tx2, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable})
tx2, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
if err != nil {
t.Fatalf("BeginEx failed: %v", err)
}
Expand Down Expand Up @@ -190,7 +195,7 @@ func TestBeginExIsoLevels(t *testing.T) {

isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted}
for _, iso := range isoLevels {
tx, err := conn.BeginEx(&pgx.TxOptions{IsoLevel: iso})
tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: iso})
if err != nil {
t.Fatalf("conn.BeginEx failed: %v", err)
}
Expand All @@ -214,7 +219,7 @@ func TestBeginExReadOnly(t *testing.T) {
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)

tx, err := conn.BeginEx(&pgx.TxOptions{AccessMode: pgx.ReadOnly})
tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{AccessMode: pgx.ReadOnly})
if err != nil {
t.Fatalf("conn.BeginEx failed: %v", err)
}
Expand All @@ -226,6 +231,105 @@ func TestBeginExReadOnly(t *testing.T) {
}
}

func TestConnBeginExContextCancel(t *testing.T) {
t.Parallel()

script := &pgmock.Script{
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
}
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
script.Steps = append(script.Steps,
pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}),
pgmock.WaitForClose(),
)

server, err := pgmock.NewServer(script)
if err != nil {
t.Fatal(err)
}
defer server.Close()

errChan := make(chan error, 1)
go func() {
errChan <- server.ServeOne()
}()

mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
if err != nil {
t.Fatal(err)
}

conn := mustConnect(t, mockConfig)

ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond)

_, err = conn.BeginEx(ctx, nil)
if err != context.DeadlineExceeded {
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
}

if conn.IsAlive() {
t.Error("expected conn to be dead after BeginEx failure")
}

if err := <-errChan; err != nil {
t.Errorf("mock server err: %v", err)
}
}

func TestTxCommitExCancel(t *testing.T) {
t.Parallel()

script := &pgmock.Script{
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
}
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
script.Steps = append(script.Steps,
pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}),
pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}),
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}),
pgmock.WaitForClose(),
)

server, err := pgmock.NewServer(script)
if err != nil {
t.Fatal(err)
}
defer server.Close()

errChan := make(chan error, 1)
go func() {
errChan <- server.ServeOne()
}()

mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
if err != nil {
t.Fatal(err)
}

conn := mustConnect(t, mockConfig)
defer conn.Close()

tx, err := conn.Begin()
if err != nil {
t.Fatal(err)
}

ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond)
err = tx.CommitEx(ctx)
if err != context.DeadlineExceeded {
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
}

if conn.IsAlive() {
t.Error("expected conn to be dead after CommitEx failure")
}

if err := <-errChan; err != nil {
t.Errorf("mock server err: %v", err)
}
}

func TestTxStatus(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit d1fd222

Please sign in to comment.