Skip to content

Commit

Permalink
Restore pgx v4 style CopyFrom implementation
Browse files Browse the repository at this point in the history
This approach uses an extra goroutine to write while the main goroutine
continues to read. This avoids the need to use non-blocking I/O.
  • Loading branch information
jackc committed Jun 12, 2023
1 parent 4410fc0 commit 85136a8
Showing 1 changed file with 89 additions and 49 deletions.
138 changes: 89 additions & 49 deletions pgconn/pgconn.go
Expand Up @@ -13,6 +13,7 @@ import (
"net"
"strconv"
"strings"
"sync"
"time"

"github.com/jackc/pgx/v5/internal/iobufpool"
Expand Down Expand Up @@ -75,6 +76,11 @@ type PgConn struct {

status byte // One of connStatus* constants

bufferingReceive bool
bufferingReceiveMux sync.Mutex
bufferingReceiveMsg pgproto3.BackendMessage
bufferingReceiveErr error

peekedMsg pgproto3.BackendMessage

// Reusable / preallocated resources
Expand Down Expand Up @@ -419,6 +425,24 @@ func hexMD5(s string) string {
return hex.EncodeToString(hash.Sum(nil))
}

func (pgConn *PgConn) signalMessage() chan struct{} {
if pgConn.bufferingReceive {
panic("BUG: signalMessage when already in progress")
}

pgConn.bufferingReceive = true
pgConn.bufferingReceiveMux.Lock()

ch := make(chan struct{})
go func() {
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
pgConn.bufferingReceiveMux.Unlock()
close(ch)
}()

return ch
}

// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
Expand Down Expand Up @@ -458,7 +482,23 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
return pgConn.peekedMsg, nil
}

msg, err := pgConn.frontend.Receive()
var msg pgproto3.BackendMessage
var err error
if pgConn.bufferingReceive {
pgConn.bufferingReceiveMux.Lock()
msg = pgConn.bufferingReceiveMsg
err = pgConn.bufferingReceiveErr
pgConn.bufferingReceiveMux.Unlock()
pgConn.bufferingReceive = false

// If a timeout error happened in the background try the read again.
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
msg, err = pgConn.frontend.Receive()
}
} else {
msg, err = pgConn.frontend.Receive()
}

if err != nil {
// Close on anything other than timeout error - everything else is fatal
Expand Down Expand Up @@ -1155,83 +1195,83 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
defer pgConn.contextWatcher.Unwatch()
}

// Send copy to command
// Send copy from query
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
err := pgConn.frontend.Flush()
if err != nil {
pgConn.asyncClose()
return CommandTag{}, err
}

// err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline)
// if err != nil {
// pgConn.asyncClose()
// return CommandTag{}, err
// }
nonblocking := true
defer func() {
if nonblocking {
pgConn.conn.SetReadDeadline(time.Time{})
}
}()

buf := iobufpool.Get(65536)
defer iobufpool.Put(buf)
(*buf)[0] = 'd'
// Send copy data
abortCopyChan := make(chan struct{})
copyErrChan := make(chan error, 1)
signalMessageChan := pgConn.signalMessage()
var wg sync.WaitGroup
wg.Add(1)

var readErr, pgErr error
for pgErr == nil {
// Read chunk from r.
var n int
n, readErr = r.Read((*buf)[5:cap(*buf)])
go func() {
defer wg.Done()
buf := iobufpool.Get(65536)
defer iobufpool.Put(buf)
(*buf)[0] = 'd'

// Send chunk to PostgreSQL.
if n > 0 {
*buf = (*buf)[0 : n+5]
pgio.SetInt32((*buf)[1:], int32(n+4))
for {
n, readErr := r.Read((*buf)[5:cap(*buf)])
if n > 0 {
*buf = (*buf)[0 : n+5]
pgio.SetInt32((*buf)[1:], int32(n+4))

writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
if writeErr != nil {
// Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine.
pgConn.conn.Close()

writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
if writeErr != nil {
pgConn.asyncClose()
return CommandTag{}, err
copyErrChan <- writeErr
return
}
}
if readErr != nil {
copyErrChan <- readErr
return
}
}

// Abort loop if there was a read error.
if readErr != nil {
break
select {
case <-abortCopyChan:
return
default:
}
}
}()

// Read messages until error or none available.
for pgErr == nil {
var pgErr error
var copyErr error
for copyErr == nil && pgErr == nil {
select {
case copyErr = <-copyErrChan:
case <-signalMessageChan:
msg, err := pgConn.receiveMessage()
if err != nil {
// if errors.Is(err, nbconn.ErrWouldBlock) {
// break
// }
pgConn.asyncClose()
return CommandTag{}, normalizeTimeoutError(ctx, err)
}

switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
pgErr = ErrorResponseToPgError(msg)
break
default:
signalMessageChan = pgConn.signalMessage()
}
}
}
close(abortCopyChan)
// Make sure io goroutine finishes before writing.
wg.Wait()

err = pgConn.conn.SetReadDeadline(time.Time{})
if err != nil {
pgConn.asyncClose()
return CommandTag{}, err
}
nonblocking = false

if readErr == io.EOF || pgErr != nil {
if copyErr == io.EOF || pgErr != nil {
pgConn.frontend.Send(&pgproto3.CopyDone{})
} else {
pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()})
pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
}
err = pgConn.frontend.Flush()
if err != nil {
Expand Down

0 comments on commit 85136a8

Please sign in to comment.