Handle writes that could deadlock with reads from the server
This commit adds a background reader that can optionally buffer reads.
It is used whenever a potentially blocking write is made to the server.
The background reader is started on a slight delay, so there should be no
meaningful performance impact: it doesn't run at all for quick queries, and
its overhead is minimal relative to slower queries.
jackc committed Jun 12, 2023
1 parent 85136a8 commit 26c79eb
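
For illustration, the pattern described in the commit message amounts to roughly the following sketch. This is a hypothetical reading of the change using the bgreader API added in this commit; the actual pgconn wiring (flushWithPotentialWriteReadDeadlock and the delayed start) lives in hunks not shown on this page, and flushWithoutDeadlocking is an invented name.

```go
package pgconn

import "github.com/jackc/pgx/v5/pgconn/internal/bgreader"

// flushWithoutDeadlocking is a hypothetical sketch (not the code in this
// commit): start the background reader so data the server sends while our
// write is blocked keeps getting drained, perform the potentially blocking
// write, then stop the background reader. Subsequent reads are served from
// the buffered results before falling back to the underlying connection.
func flushWithoutDeadlocking(flush func() error, bgr *bgreader.BGReader) error {
	bgr.Start() // the real change starts this on a slight delay to avoid overhead on quick queries
	defer bgr.Stop()
	return flush()
}
```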
Showing 5 changed files with 316 additions and 15 deletions.
4 changes: 2 additions & 2 deletions pgconn/auth_scram.go
@@ -42,7 +42,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
Data: sc.clientFirstMessage(),
}
c.frontend.Send(saslInitialResponse)
err = c.frontend.Flush()
err = c.flushWithPotentialWriteReadDeadlock()
if err != nil {
return err
}
@@ -62,7 +62,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
Data: []byte(sc.clientFinalMessage()),
}
c.frontend.Send(saslResponse)
err = c.frontend.Flush()
err = c.flushWithPotentialWriteReadDeadlock()
if err != nil {
return err
}
132 changes: 132 additions & 0 deletions pgconn/internal/bgreader/bgreader.go
@@ -0,0 +1,132 @@
// Package bgreader provides an io.Reader that can optionally buffer reads in the background.
package bgreader

import (
"io"
"sync"

"github.com/jackc/pgx/v5/internal/iobufpool"
)

const (
bgReaderStatusStopped = iota
bgReaderStatusRunning
bgReaderStatusStopping
)

// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use.
type BGReader struct {
r io.Reader

cond *sync.Cond
bgReaderStatus int32
readResults []readResult
}

type readResult struct {
buf *[]byte
err error
}

// Start starts the background reader. If the background reader is already running, this is a no-op. The background
// reader will stop automatically when the underlying reader returns an error.
func (r *BGReader) Start() {
r.cond.L.Lock()
defer r.cond.L.Unlock()

switch r.bgReaderStatus {
case bgReaderStatusStopped:
r.bgReaderStatus = bgReaderStatusRunning
go r.bgRead()
case bgReaderStatusRunning:
// no-op
case bgReaderStatusStopping:
r.bgReaderStatus = bgReaderStatusRunning
}
}

// Stop tells the background reader to stop after the in-progress Read returns. It is safe to call Stop when the
// background reader is not running.
func (r *BGReader) Stop() {
r.cond.L.Lock()
defer r.cond.L.Unlock()

switch r.bgReaderStatus {
case bgReaderStatusStopped:
// no-op
case bgReaderStatusRunning:
r.bgReaderStatus = bgReaderStatusStopping
case bgReaderStatusStopping:
// no-op
}
}

func (r *BGReader) bgRead() {
keepReading := true
for keepReading {
buf := iobufpool.Get(8192)
n, err := r.r.Read(*buf)
*buf = (*buf)[:n]

r.cond.L.Lock()
r.readResults = append(r.readResults, readResult{buf: buf, err: err})
if r.bgReaderStatus == bgReaderStatusStopping || err != nil {
r.bgReaderStatus = bgReaderStatusStopped
keepReading = false
}
r.cond.L.Unlock()
r.cond.Broadcast()
}
}

// Read implements the io.Reader interface.
func (r *BGReader) Read(p []byte) (int, error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()

if len(r.readResults) > 0 {
return r.readFromReadResults(p)
}

// There are no unread background read results. If the background reader is stopped, read directly from the
// underlying reader.
if r.bgReaderStatus == bgReaderStatusStopped {
return r.r.Read(p)
}

// Wait for results from the background reader
for len(r.readResults) == 0 {
r.cond.Wait()
}
return r.readFromReadResults(p)
}

// readFromReadResults reads a result previously read by the background reader. r.cond.L must be held.
func (r *BGReader) readFromReadResults(p []byte) (int, error) {
buf := r.readResults[0].buf
var err error

n := copy(p, *buf)
if n == len(*buf) {
err = r.readResults[0].err
iobufpool.Put(buf)
if len(r.readResults) == 1 {
r.readResults = nil
} else {
r.readResults = r.readResults[1:]
}
} else {
*buf = (*buf)[n:]
r.readResults[0].buf = buf
}

return n, err
}

// New returns a new BGReader that reads from r. The background reader is not started; call Start to begin
// buffering reads in the background.
func New(r io.Reader) *BGReader {
return &BGReader{
r: r,
cond: &sync.Cond{
L: &sync.Mutex{},
},
}
}
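
As a standalone illustration of the API defined above (the tests in the next file exercise the same surface), a caller wraps any io.Reader and toggles background buffering around it. Note the package is internal to pgx, so a snippet like this only compiles inside that module.

```go
package main

import (
	"fmt"
	"io"
	"strings"

	"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
)

func main() {
	// Wrap any io.Reader; nothing is buffered until Start is called.
	bgr := bgreader.New(strings.NewReader("foo bar baz"))

	bgr.Start() // begin draining the underlying reader in the background

	// Reads are served from buffered results first, then directly from the
	// underlying reader once the background reader has stopped.
	buf, err := io.ReadAll(bgr)
	if err != nil {
		fmt.Println("read error:", err)
		return
	}

	bgr.Stop() // safe even though the background reader already stopped on EOF

	fmt.Printf("%s\n", buf) // foo bar baz
}
```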
140 changes: 140 additions & 0 deletions pgconn/internal/bgreader/bgreader_test.go
@@ -0,0 +1,140 @@
package bgreader_test

import (
"bytes"
"errors"
"io"
"math/rand"
"testing"
"time"

"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
"github.com/stretchr/testify/require"
)

func TestBGReaderReadWhenStopped(t *testing.T) {
r := bytes.NewReader([]byte("foo bar baz"))
bgr := bgreader.New(r)
buf, err := io.ReadAll(bgr)
require.NoError(t, err)
require.Equal(t, []byte("foo bar baz"), buf)
}

func TestBGReaderReadWhenStarted(t *testing.T) {
r := bytes.NewReader([]byte("foo bar baz"))
bgr := bgreader.New(r)
bgr.Start()
buf, err := io.ReadAll(bgr)
require.NoError(t, err)
require.Equal(t, []byte("foo bar baz"), buf)
}

type mockReadFunc func(p []byte) (int, error)

type mockReader struct {
readFuncs []mockReadFunc
}

func (r *mockReader) Read(p []byte) (int, error) {
if len(r.readFuncs) == 0 {
return 0, io.EOF
}

fn := r.readFuncs[0]
r.readFuncs = r.readFuncs[1:]

return fn(p)
}

func TestBGReaderReadWaitsForBackgroundRead(t *testing.T) {
rr := &mockReader{
readFuncs: []mockReadFunc{
func(p []byte) (int, error) { time.Sleep(1 * time.Second); return copy(p, []byte("foo")), nil },
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
func(p []byte) (int, error) { return copy(p, []byte("baz")), nil },
},
}
bgr := bgreader.New(rr)
bgr.Start()
buf := make([]byte, 3)
n, err := bgr.Read(buf)
require.NoError(t, err)
require.EqualValues(t, 3, n)
require.Equal(t, []byte("foo"), buf)
}

func TestBGReaderErrorWhenStarted(t *testing.T) {
rr := &mockReader{
readFuncs: []mockReadFunc{
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
},
}

bgr := bgreader.New(rr)
bgr.Start()
buf, err := io.ReadAll(bgr)
require.Equal(t, []byte("foobarbaz"), buf)
require.EqualError(t, err, "oops")
}

func TestBGReaderErrorWhenStopped(t *testing.T) {
rr := &mockReader{
readFuncs: []mockReadFunc{
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
},
}

bgr := bgreader.New(rr)
buf, err := io.ReadAll(bgr)
require.Equal(t, []byte("foobarbaz"), buf)
require.EqualError(t, err, "oops")
}

type numberReader struct {
v uint8
rng *rand.Rand
}

func (nr *numberReader) Read(p []byte) (int, error) {
n := nr.rng.Intn(len(p))
for i := 0; i < n; i++ {
p[i] = nr.v
nr.v++
}

return n, nil
}

// TestBGReaderStress stress tests BGReader by reading a lot of bytes in random sizes while randomly starting and
// stopping the background worker from other goroutines.
func TestBGReaderStress(t *testing.T) {
nr := &numberReader{rng: rand.New(rand.NewSource(0))}
bgr := bgreader.New(nr)

bytesRead := 0
var expected uint8
buf := make([]byte, 10_000)
rng := rand.New(rand.NewSource(0))

for bytesRead < 1_000_000 {
randomNumber := rng.Intn(100)
switch {
case randomNumber < 10:
go bgr.Start()
case randomNumber < 20:
go bgr.Stop()
default:
n, err := bgr.Read(buf)
require.NoError(t, err)
for i := 0; i < n; i++ {
require.Equal(t, expected, buf[i])
expected++
}
bytesRead += n
}
}
}
2 changes: 1 addition & 1 deletion pgconn/krb5.go
@@ -63,7 +63,7 @@ func (c *PgConn) gssAuth() error {
Data: nextData,
}
c.frontend.Send(gssResponse)
err = c.frontend.Flush()
err = c.flushWithPotentialWriteReadDeadlock()
if err != nil {
return err
}
