Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jackc committed Jun 3, 2022
1 parent 2e7b46d commit ca22396
Show file tree
Hide file tree
Showing 5 changed files with 366 additions and 112 deletions.
76 changes: 76 additions & 0 deletions internal/nbbconn/bufferqueue.go.deleted
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package nbbconn

import (
"sync"

"github.com/jackc/pgx/v5/internal/iobufpool"
)

const minBufferQueueLen = 8

type bufferQueue struct {
lock sync.Mutex
queue [][]byte
r, w int
}

func (bq *bufferQueue) pushBack(buf []byte) {
bq.lock.Lock()
defer bq.lock.Unlock()

if bq.w >= len(bq.queue) {
bq.growQueue()
}
bq.queue[bq.w] = buf
bq.w++
}

func (bq *bufferQueue) pushFront(buf []byte) {
bq.lock.Lock()
defer bq.lock.Unlock()

if bq.w >= len(bq.queue) {
bq.growQueue()
}
copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w])
bq.queue[bq.r] = buf
bq.w++
}

func (bq *bufferQueue) popFront() []byte {
bq.lock.Lock()
defer bq.lock.Unlock()

if bq.r == bq.w {
return nil
}

buf := bq.queue[bq.r]
bq.queue[bq.r] = nil // Clear reference so it can be garbage collected.
bq.r++

if bq.r == bq.w {
bq.r = 0
bq.w = 0
if len(bq.queue) > minBufferQueueLen {
bq.queue = make([][]byte, minBufferQueueLen)
}
}

return buf
}

func (bq *bufferQueue) growQueue() {
desiredLen := (len(bq.queue) + 1) * 3 / 2
if desiredLen < minBufferQueueLen {
desiredLen = minBufferQueueLen
}

newQueue := make([][]byte, desiredLen)
copy(newQueue, bq.queue)
bq.queue = newQueue
}

func releaseBuf(buf []byte) {
iobufpool.Put(buf[:cap(buf)])
}
51 changes: 33 additions & 18 deletions internal/nbbconn/nbbconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import (
)

var errClosed = errors.New("closed")
var errWouldBlock = errors.New("would block")
var ErrWouldBlock = errors.New("would block")

const fakeNonblockingWaitDuration = 100 * time.Millisecond

var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)

// Conn is a non-blocking, buffered net.Conn wrapper. It implements net.Conn.
//
// It is designed to solve three problems.
Expand All @@ -37,6 +39,7 @@ type Conn struct {

readDeadlineLock sync.Mutex
readDeadline time.Time
readNonblocking bool

writeDeadlineLock sync.Mutex
writeDeadline time.Time
Expand Down Expand Up @@ -74,9 +77,19 @@ func (c *Conn) Read(b []byte) (n int, err error) {
releaseBuf(buf)
}
return n, nil
// TODO - must return error if n != len(b)
}

return c.netConn.Read(b)
var readNonblocking bool
c.readDeadlineLock.Lock()
readNonblocking = c.readNonblocking
c.readDeadlineLock.Unlock()

if readNonblocking {
return c.nonblockingRead(b)
} else {
return c.netConn.Read(b)
}
}

// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is
Expand Down Expand Up @@ -123,29 +136,31 @@ func (c *Conn) RemoteAddr() net.Addr {
return c.netConn.RemoteAddr()
}

// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
func (c *Conn) SetDeadline(t time.Time) error {
if c.isClosed() {
return errClosed
err := c.SetReadDeadline(t)
if err != nil {
return err
}

c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()
c.readDeadline = t

c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()
c.writeDeadline = t

return c.netConn.SetDeadline(t)
return c.SetWriteDeadline(t)
}

// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking.
func (c *Conn) SetReadDeadline(t time.Time) error {
if c.isClosed() {
return errClosed
}

c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()

if t == NonBlockingDeadline {
c.readNonblocking = true
t = time.Time{}
} else {
c.readNonblocking = false
}

c.readDeadline = t

return c.netConn.SetReadDeadline(t)
Expand Down Expand Up @@ -193,7 +208,7 @@ func (c *Conn) flush() error {
n, err := c.nonblockingWrite(remainingBuf)
remainingBuf = remainingBuf[n:]
if err != nil {
if !errors.Is(err, errWouldBlock) {
if !errors.Is(err, ErrWouldBlock) {
buf = buf[:len(remainingBuf)]
copy(buf, remainingBuf)
c.writeQueue.pushFront(buf)
Expand Down Expand Up @@ -234,7 +249,7 @@ func (c *Conn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan err
}

if err != nil {
if !errors.Is(err, errWouldBlock) {
if !errors.Is(err, ErrWouldBlock) {
errChan <- err
return
}
Expand Down Expand Up @@ -276,7 +291,7 @@ func (c *Conn) fakeNonblockingWrite(b []byte) (n int, err error) {

if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = errWouldBlock
err = ErrWouldBlock
}
}
}()
Expand Down Expand Up @@ -305,7 +320,7 @@ func (c *Conn) fakeNonblockingRead(b []byte) (n int, err error) {

if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = errWouldBlock
err = ErrWouldBlock
}
}
}()
Expand Down
129 changes: 129 additions & 0 deletions internal/nbbconn/nbbconn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package nbbconn_test

import (
"net"
"testing"
"time"

"github.com/jackc/pgx/v5/internal/nbbconn"
"github.com/stretchr/testify/require"
)

func TestWriteIsBuffered(t *testing.T) {
local, remote := net.Pipe()
defer func() {
local.Close()
remote.Close()
}()

conn := nbbconn.New(local)

// net.Pipe is synchronous so the Write would block if not buffered.
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)

errChan := make(chan error, 1)
go func() {
err := conn.Flush()
errChan <- err
}()

readBuf := make([]byte, len(writeBuf))
_, err = remote.Read(readBuf)
require.NoError(t, err)

require.NoError(t, <-errChan)
}

func TestReadFlushesWriteBuffer(t *testing.T) {
local, remote := net.Pipe()
defer func() {
local.Close()
remote.Close()
}()

conn := nbbconn.New(local)

writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)

errChan := make(chan error, 2)
go func() {
readBuf := make([]byte, len(writeBuf))
_, err := remote.Read(readBuf)
errChan <- err

_, err = remote.Write([]byte("okay"))
errChan <- err
}()

readBuf := make([]byte, 4)
_, err = conn.Read(readBuf)
require.NoError(t, err)
require.Equal(t, []byte("okay"), readBuf)

require.NoError(t, <-errChan)
require.NoError(t, <-errChan)
}

func TestCloseFlushesWriteBuffer(t *testing.T) {
local, remote := net.Pipe()
defer func() {
local.Close()
remote.Close()
}()

conn := nbbconn.New(local)

writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)

errChan := make(chan error, 1)
go func() {
readBuf := make([]byte, len(writeBuf))
_, err := remote.Read(readBuf)
errChan <- err
}()

err = conn.Close()
require.NoError(t, err)

require.NoError(t, <-errChan)
}

func TestNonBlockingRead(t *testing.T) {
local, remote := net.Pipe()
defer func() {
local.Close()
remote.Close()
}()

conn := nbbconn.New(local)

err := conn.SetReadDeadline(nbbconn.NonBlockingDeadline)
require.NoError(t, err)

buf := make([]byte, 4)
n, err := conn.Read(buf)
require.ErrorIs(t, err, nbbconn.ErrWouldBlock)
require.EqualValues(t, 0, n)

errChan := make(chan error, 1)
go func() {
_, err := remote.Write([]byte("okay"))
errChan <- err
}()

err = conn.SetReadDeadline(time.Time{})
require.NoError(t, err)

n, err = conn.Read(buf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
}
Loading

0 comments on commit ca22396

Please sign in to comment.