Skip to content

Commit

Permalink
Steps toward TLS
Browse files Browse the repository at this point in the history
  • Loading branch information
jackc committed Jun 4, 2022
1 parent 2b80beb commit 51655bf
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
20 changes: 12 additions & 8 deletions internal/nbconn/nbconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
//
// The third is to efficiently check if a connection has been closed via a non-blocking read.
type Conn struct {
netConn net.Conn
netConn net.Conn
tlsConn *tls.Conn
maybeTLSConn net.Conn

readQueue bufferQueue
writeQueue bufferQueue
Expand All @@ -51,13 +53,15 @@ type Conn struct {

func New(conn net.Conn) *Conn {
return &Conn{
netConn: conn,
netConn: conn,
maybeTLSConn: conn,
}
}

// StartTLS starts using TLS. It must not be called concurrently with any other method and must only be called once.
func (c *Conn) StartTLS(config *tls.Config) {
c.netConn = tls.Client(c.netConn, config)
c.tlsConn = tls.Client(c.netConn, config)
c.maybeTLSConn = c.tlsConn
}

// Read implements io.Reader.
Expand Down Expand Up @@ -102,7 +106,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
if readNonblocking {
readN, err = c.nonblockingRead(b[n:])
} else {
readN, err = c.netConn.Read(b[n:])
readN, err = c.maybeTLSConn.Read(b[n:])
}
n += readN
return n, err
Expand All @@ -128,7 +132,7 @@ func (c *Conn) Close() (err error) {
}

defer func() {
closeErr := c.netConn.Close()
closeErr := c.maybeTLSConn.Close()
if err == nil {
err = closeErr
}
Expand All @@ -145,11 +149,11 @@ func (c *Conn) Close() (err error) {
}

func (c *Conn) LocalAddr() net.Addr {
return c.netConn.LocalAddr()
return c.maybeTLSConn.LocalAddr()
}

func (c *Conn) RemoteAddr() net.Addr {
return c.netConn.RemoteAddr()
return c.maybeTLSConn.RemoteAddr()
}

// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
Expand Down Expand Up @@ -179,7 +183,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {

c.readDeadline = t

return c.netConn.SetReadDeadline(t)
return c.maybeTLSConn.SetReadDeadline(t)
}

func (c *Conn) SetWriteDeadline(t time.Time) error {
Expand Down
12 changes: 12 additions & 0 deletions internal/nbconn/nbconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ func TestWriteIsBuffered(t *testing.T) {
})
}

func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) {
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
err := conn.SetWriteDeadline(time.Now())
require.NoError(t, err)

writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
})
}

func TestReadFlushesWriteBuffer(t *testing.T) {
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
writeBuf := []byte("test")
Expand Down

0 comments on commit 51655bf

Please sign in to comment.