Skip to content

Commit

Permalink
Add true non-blocking IO
Browse files Browse the repository at this point in the history
  • Loading branch information
jackc committed Jun 18, 2022
1 parent 7dd26a3 commit 60ecdda
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 19 deletions.
80 changes: 75 additions & 5 deletions internal/nbconn/nbconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"os"
"sync"
"sync/atomic"
"syscall"
"time"

"github.com/jackc/pgx/v5/internal/iobufpool"
Expand Down Expand Up @@ -54,7 +55,8 @@ type Conn interface {

// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
type NetConn struct {
conn net.Conn
conn net.Conn
rawConn syscall.RawConn

readQueue bufferQueue
writeQueue bufferQueue
Expand All @@ -72,10 +74,20 @@ type NetConn struct {
closed int64 // 0 = not closed, 1 = closed
}

func NewNetConn(conn net.Conn) *NetConn {
return &NetConn{
func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
nc := &NetConn{
conn: conn,
}

if !fakeNonBlockingIO {
if sc, ok := conn.(syscall.Conn); ok {
if rawConn, err := sc.SyscallConn(); err == nil {
nc.rawConn = rawConn
}
}
}

return nc
}

// Read implements io.Reader.
Expand Down Expand Up @@ -323,7 +335,11 @@ func (c *NetConn) isClosed() bool {
}

func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
return c.fakeNonblockingWrite(b)
if c.rawConn == nil {
return c.fakeNonblockingWrite(b)
} else {
return c.realNonblockingWrite(b)
}
}

func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
Expand Down Expand Up @@ -351,8 +367,37 @@ func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
return c.conn.Write(b)
}

func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
var funcErr error
err = c.rawConn.Write(func(fd uintptr) (done bool) {
n, funcErr = syscall.Write(int(fd), b)
return true
})
if err == nil && funcErr != nil {
if errors.Is(funcErr, syscall.EWOULDBLOCK) {
err = ErrWouldBlock
} else {
err = funcErr
}
}
if err != nil {
// n may be -1 when an error occurs.
if n < 0 {
n = 0
}

return n, err
}

return n, nil
}

func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
return c.fakeNonblockingRead(b)
if c.rawConn == nil {
return c.fakeNonblockingRead(b)
} else {
return c.realNonblockingRead(b)
}
}

func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
Expand Down Expand Up @@ -380,6 +425,31 @@ func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
return c.conn.Read(b)
}

func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
var funcErr error
err = c.rawConn.Read(func(fd uintptr) (done bool) {
n, funcErr = syscall.Read(int(fd), b)
return true
})
if err == nil && funcErr != nil {
if errors.Is(funcErr, syscall.EWOULDBLOCK) {
err = ErrWouldBlock
} else {
err = funcErr
}
}
if err != nil {
// n may be -1 when an error occurs.
if n < 0 {
n = 0
}

return n, err
}

return n, nil
}

// syscall.Conn is interface

// TLSClient establishes a TLS connection as a client over conn using config.
Expand Down
102 changes: 89 additions & 13 deletions internal/nbconn/nbconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,31 +67,53 @@ pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG

func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) {
for _, tt := range []struct {
name string
makeConns func(t *testing.T) (local, remote net.Conn)
useTLS bool
name string
makeConns func(t *testing.T) (local, remote net.Conn)
useTLS bool
fakeNonBlockingIO bool
}{
{
name: "Pipe",
makeConns: makePipeConns,
useTLS: false,
name: "Pipe",
makeConns: makePipeConns,
useTLS: false,
fakeNonBlockingIO: true,
},
{
name: "TCP",
makeConns: makeTCPConns,
useTLS: false,
name: "TCP with Fake Non-blocking IO",
makeConns: makeTCPConns,
useTLS: false,
fakeNonBlockingIO: true,
},
{
name: "TLS over TCP",
makeConns: makeTCPConns,
useTLS: true,
name: "TLS over TCP with Fake Non-blocking IO",
makeConns: makeTCPConns,
useTLS: true,
fakeNonBlockingIO: true,
},
{
name: "TCP with Real Non-blocking IO",
makeConns: makeTCPConns,
useTLS: false,
fakeNonBlockingIO: false,
},
{
name: "TLS over TCP with Real Non-blocking IO",
makeConns: makeTCPConns,
useTLS: true,
fakeNonBlockingIO: false,
},
} {
t.Run(tt.name, func(t *testing.T) {
local, remote := tt.makeConns(t)

// Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get
// garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never
// uses remote it may be garbage collected leading to the connection being closed.
defer local.Close()
defer remote.Close()

var conn nbconn.Conn
netConn := nbconn.NewNetConn(local)
netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO)

if tt.useTLS {
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
Expand Down Expand Up @@ -244,6 +266,60 @@ func TestCloseFlushesWriteBuffer(t *testing.T) {
})
}

// This test exercises the non-blocking write path. Because writes are buffered it is difficult trigger this with
// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by both sides writing
// large values.
func TestInternalNonBlockingWrite(t *testing.T) {
const deadlockSize = 4 * 1024 * 1024

testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := make([]byte, deadlockSize)
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, deadlockSize, n)

errChan := make(chan error, 1)
go func() {
remoteWriteBuf := make([]byte, deadlockSize)
_, err := remote.Write(remoteWriteBuf)
if err != nil {
errChan <- err
return
}

readBuf := make([]byte, deadlockSize)
_, err = io.ReadFull(remote, readBuf)
errChan <- err
}()

readBuf := make([]byte, deadlockSize)
_, err = conn.Read(readBuf)
require.NoError(t, err)

err = conn.Close()
require.NoError(t, err)

require.NoError(t, <-errChan)
})
}

func TestInternalNonBlockingWriteWithDeadline(t *testing.T) {
const deadlockSize = 4 * 1024 * 1024

testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := make([]byte, deadlockSize)
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, deadlockSize, n)

err = conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
require.NoError(t, err)

err = conn.Flush()
require.Error(t, err)
})
}

func TestNonBlockingRead(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
err := conn.SetReadDeadline(nbconn.NonBlockingDeadline)
Expand Down
2 changes: 1 addition & 1 deletion pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
}
return nil, &connectError{config: config, msg: "dial error", err: err}
}
netConn = nbconn.NewNetConn(netConn)
netConn = nbconn.NewNetConn(netConn, false)

pgConn.conn = netConn
pgConn.contextWatcher = newContextWatcher(netConn)
Expand Down

0 comments on commit 60ecdda

Please sign in to comment.