Skip to content

Commit

Permalink
Make BatchResults.Close safe to be called multiple times
Browse files Browse the repository at this point in the history
  • Loading branch information
jackc committed Feb 7, 2022
1 parent d02b2ed commit e8857f0
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 7 deletions.
33 changes: 26 additions & 7 deletions batch.go
Expand Up @@ -3,6 +3,7 @@ package pgx
import (
"context"
"errors"
"fmt"

"github.com/jackc/pgconn"
)
Expand Down Expand Up @@ -46,24 +47,28 @@ type BatchResults interface {

// Close closes the batch operation. This must be called before the underlying connection can be used again. Any error
// that occurred during a batch operation may have made it impossible to resyncronize the connection with the server.
// In this case the underlying connection will have been closed.
// In this case the underlying connection will have been closed. Close is safe to call multiple times.
Close() error
}

type batchResults struct {
ctx context.Context
conn *Conn
mrr *pgconn.MultiResultReader
err error
b *Batch
ix int
ctx context.Context
conn *Conn
mrr *pgconn.MultiResultReader
err error
b *Batch
ix int
closed bool
}

// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
func (br *batchResults) Exec() (pgconn.CommandTag, error) {
if br.err != nil {
return nil, br.err
}
if br.closed {
return nil, fmt.Errorf("batch already closed")
}

query, arguments, _ := br.nextQueryAndArgs()

Expand Down Expand Up @@ -114,6 +119,11 @@ func (br *batchResults) Query() (Rows, error) {
return &connRows{err: br.err, closed: true}, br.err
}

if br.closed {
alreadyClosedErr := fmt.Errorf("batch already closed")
return &connRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
}

rows := br.conn.getRows(br.ctx, query, arguments)

if !br.mrr.NextResult() {
Expand All @@ -140,6 +150,10 @@ func (br *batchResults) Query() (Rows, error) {

// QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc.
func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
if br.closed {
return nil, fmt.Errorf("batch already closed")
}

rows, err := br.Query()
if err != nil {
return nil, err
Expand Down Expand Up @@ -179,6 +193,11 @@ func (br *batchResults) Close() error {
return br.err
}

if br.closed {
return nil
}
br.closed = true

// log any queries that haven't yet been logged by Exec or Query
for {
query, args, ok := br.nextQueryAndArgs()
Expand Down
52 changes: 52 additions & 0 deletions pgxpool/pool_test.go
Expand Up @@ -979,3 +979,55 @@ func TestCreateMinPoolReturnsFirstError(t *testing.T) {
require.True(t, connectAttempts >= 5, "Expected %d got %d", 5, connectAttempts)
require.ErrorIs(t, err, mockErr)
}

func TestPoolSendBatchBatchCloseTwice(t *testing.T) {
t.Parallel()

pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer pool.Close()

errChan := make(chan error)
testCount := 5000

for i := 0; i < testCount; i++ {
go func() {
batch := &pgx.Batch{}
batch.Queue("select 1")
batch.Queue("select 2")

br := pool.SendBatch(context.Background(), batch)
defer br.Close()

var err error
var n int32
err = br.QueryRow().Scan(&n)
if err != nil {
errChan <- err
return
}
if n != 1 {
errChan <- fmt.Errorf("expected 1 got %v", n)
return
}

err = br.QueryRow().Scan(&n)
if err != nil {
errChan <- err
return
}
if n != 2 {
errChan <- fmt.Errorf("expected 2 got %v", n)
return
}

err = br.Close()
errChan <- err
}()
}

for i := 0; i < testCount; i++ {
err := <-errChan
assert.NoError(t, err)
}
}

0 comments on commit e8857f0

Please sign in to comment.