Skip to content

Commit

Permalink
Introduce barrier for past read deadline propagation of websocket alp…
Browse files Browse the repository at this point in the history
…n upgraded connection.

This fixes kubectl exec not working properly over L7 load balancer, because in that case connection was upgraded and hijacked twice.
And golang http server during hicjaking aborts pending reads by setting read deadline to past value.
Second hijack went to the bottom and aborted our upgraded base connection.
  • Loading branch information
AntonAM committed May 23, 2024
1 parent 27fd7c7 commit 213d1b5
Showing 1 changed file with 104 additions and 15 deletions.
119 changes: 104 additions & 15 deletions lib/web/conn_upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ package web

import (
"context"
"errors"
"io"
"net"
"net/http"
"os"
"slices"
"sync"
"time"
Expand Down Expand Up @@ -235,20 +237,65 @@ func (conn *waitConn) Close() error {
return trace.Wrap(err)
}

type readResult struct {
bytes []byte
err error
}

type websocketALPNServerConn struct {
*websocket.Conn
readBuffer []byte
readError error
readMutex sync.Mutex

writeMutex sync.Mutex

readMutex sync.Mutex
readBuffer []byte
readCh chan readResult
readError error
isReadingWSData bool

readDeadlineMutex sync.Mutex
// readAbortCtx is used to decouple our Read() function from the underlying websocket Read() function, so
// we could abort read for our callers without disturbing websocket connection. It is here to make sure that
// http connection hijacking doesn't close our underlying connection.
readAbortCtx context.Context
readAbortCancel context.CancelFunc
}

func newWebSocketALPNServerConn(wsConn *websocket.Conn) *websocketALPNServerConn {
abortCtx, abortCancel := context.WithCancel(context.Background())

return &websocketALPNServerConn{
Conn: wsConn,
Conn: wsConn,
readAbortCtx: abortCtx,
readAbortCancel: abortCancel,
readCh: make(chan readResult, 1), // Buffered to make sure when connection is closed readWSData() doesn't get stuck on writing.
}
}

func (c *websocketALPNServerConn) SetReadDeadline(t time.Time) error {
c.readDeadlineMutex.Lock()
defer c.readDeadlineMutex.Unlock()

// If new read deadline is in the past, trigger read abort context and exit.
// We don't propagate deadline to the underlying connection because it might be already hijacked connection
// and in that case it will be closed, leading to closure of the topmost connection as well.
if !t.Equal(time.Time{}) && t.Before(time.Now()) {
c.readAbortCancel()
return nil
}

// Reset readAbortCtx if it was done, since here we have a new read deadline, that is in the future.
select {
case <-c.readAbortCtx.Done():
abortCtx, abortCancel := context.WithCancel(context.Background())
c.readAbortCtx = abortCtx
c.readAbortCancel = abortCancel
default:
}

return c.Conn.SetReadDeadline(t)
}

func (c *websocketALPNServerConn) convertError(err error) error {
if websocket.IsCloseError(err,
websocket.CloseAbnormalClosure,
Expand All @@ -265,13 +312,26 @@ func (c *websocketALPNServerConn) Read(b []byte) (int, error) {
defer c.readMutex.Unlock()

n, err := c.readLocked(b)
// For http server to ignore this error as "pendingAbort error" we shouldn't wrap it.
if errors.Is(err, os.ErrDeadlineExceeded) {
return n, err
}
return n, trace.Wrap(err)
}

func (c *websocketALPNServerConn) readLocked(b []byte) (int, error) {
// Stop reading if any previous read err.
if c.readError != nil {
return 0, trace.Wrap(c.readError)
return 0, c.readError
}

c.readDeadlineMutex.Lock()
abortCtx := c.readAbortCtx
c.readDeadlineMutex.Unlock()

select {
case <-abortCtx.Done():
return 0, os.ErrDeadlineExceeded
default:
}

if len(c.readBuffer) > 0 {
Expand All @@ -284,19 +344,48 @@ func (c *websocketALPNServerConn) readLocked(b []byte) (int, error) {
return n, nil
}

for {
messageType, data, err := c.Conn.ReadMessage()
if err != nil {
c.readError = c.convertError(err)
return 0, trace.Wrap(c.readError)
// If websocket data reading goroutine isn't running - start it.
if !c.isReadingWSData {
go c.readWSData()
c.isReadingWSData = true
}

// Wait for websocket data to be read of abort of our pending read. If we abort websocket reading
// goroutine will continue
select {
case <-abortCtx.Done():
return 0, os.ErrDeadlineExceeded

case res := <-c.readCh:
c.isReadingWSData = false
if res.err != nil {
c.readError = res.err
return 0, res.err
}
if len(res.bytes) == 0 {
return 0, nil
}

c.readBuffer = res.bytes
return c.readLocked(b)
}
}

func (c *websocketALPNServerConn) readWSData() {
messageType, data, err := c.Conn.ReadMessage()
if err != nil {
c.readCh <- readResult{nil, trace.Wrap(c.convertError(err))}
return
}

for {
switch messageType {
case websocket.CloseMessage:
return 0, nil
c.readCh <- readResult{[]byte{}, nil}
return
case websocket.BinaryMessage:
c.readBuffer = data
return c.readLocked(b)
c.readCh <- readResult{data, nil}
return
case websocket.PongMessage:
// Receives Pong as response to Ping. Nothing to do.
}
Expand Down Expand Up @@ -326,7 +415,7 @@ func (c *websocketALPNServerConn) SetDeadline(t time.Time) error {
c.writeMutex.Lock()
defer c.writeMutex.Unlock()
return trace.NewAggregate(
c.Conn.SetReadDeadline(t),
c.SetReadDeadline(t),
c.Conn.SetWriteDeadline(t),
)
}
Expand Down

0 comments on commit 213d1b5

Please sign in to comment.