Skip to content

Commit

Permalink
Fix headerInterceptingConn handling
Browse files Browse the repository at this point in the history
  • Loading branch information
liggitt authored and seans3 committed Mar 4, 2024
1 parent 9e15462 commit 2443b3f
Show file tree
Hide file tree
Showing 2 changed files with 246 additions and 58 deletions.
65 changes: 43 additions & 22 deletions staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"bytes"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
Expand Down Expand Up @@ -203,59 +202,81 @@ type headerInterceptingConn struct {
// and initializableConn#InitializeWrite() has been called with the result.
initializableConn

lock sync.Mutex
headerBuffer []byte
initialized bool
lock sync.Mutex
headerBuffer []byte
initialized bool
initializeErr error
}

// initializableConn is a connection that will be initialized before any calls to Write are made
type initializableConn interface {
net.Conn
InitializeWrite(backendResponse *http.Response) error
// InitializeWrite is called when the backend response headers have been read.
// backendResponse contains the parsed headers.
// backendResponseBytes are the raw bytes the headers were parsed from.
InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) error
}

const maxHeaderBytes = 1 << 20

// token for normal header / body separation (\r\n\r\n, but go tolerates the leading \r being absent)
var lfCRLF = []byte("\n\r\n")

// token for header / body separation without \r (which go tolerates)
var lfLF = []byte("\n\n")

// Write intercepts to initially swallow the HTTP response, then
// delegate to the tunneling "net.Conn" once the response has been
// seen and processed.
func (h *headerInterceptingConn) Write(b []byte) (int, error) {
h.lock.Lock()
defer h.lock.Unlock()

if h.initializeErr != nil {
return 0, h.initializeErr
}
if h.initialized {
return h.initializableConn.Write(b)
}

// Write into the headerBuffer, then attempt to parse the bytes
// as an http response.
// Guard against excessive buffering
if len(h.headerBuffer)+len(b) > maxHeaderBytes {
return 0, fmt.Errorf("header size limit exceeded")
}

// Accumulate into headerBuffer
h.headerBuffer = append(h.headerBuffer, b...)
bufferedReader := bufio.NewReader(bytes.NewReader(h.headerBuffer))
resp, err := http.ReadResponse(bufferedReader, nil)
if errors.Is(err, io.ErrUnexpectedEOF) {
// don't yet have a complete set of headers

// Attempt to parse http response headers
var headerBytes, bodyBytes []byte
if i := bytes.Index(h.headerBuffer, lfCRLF); i != -1 {
// headers terminated with \n\r\n
headerBytes = h.headerBuffer[0 : i+len(lfCRLF)]
bodyBytes = h.headerBuffer[i+len(lfCRLF):]
} else if i := bytes.Index(h.headerBuffer, lfLF); i != -1 {
// headers terminated with \n\n (which go tolerates)
headerBytes = h.headerBuffer[0 : i+len(lfLF)]
bodyBytes = h.headerBuffer[i+len(lfLF):]
} else {
// don't yet have a complete set of headers yet
return len(b), nil
}
resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(headerBytes)), nil)
if err != nil {
klog.Errorf("invalid headers: %v", err)
h.initializeErr = err
return len(b), err
}
resp.Body.Close() //nolint:errcheck

h.headerBuffer = nil
err = h.initializableConn.InitializeWrite(resp)
h.initialized = true
if err != nil {
return len(b), err
h.initializeErr = h.initializableConn.InitializeWrite(resp, headerBytes)
if h.initializeErr != nil {
return len(b), h.initializeErr
}

// Copy any remaining buffered data to the underlying conn
remainingBuffer, _ := io.ReadAll(bufferedReader)
if len(remainingBuffer) > 0 {
_, err = h.initializableConn.Write(remainingBuffer)
if len(bodyBytes) > 0 {
_, err = h.initializableConn.Write(bodyBytes)
}
return len(b), err
}
Expand All @@ -274,7 +295,7 @@ type tunnelingWebsocketUpgraderConn struct {
err error
}

func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response) (err error) {
func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) (err error) {
// make sure we close a connection we open in error cases
var conn net.Conn
defer func() {
Expand Down Expand Up @@ -337,9 +358,9 @@ func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.R
u.err = err
return u.err
}
// replay the backend response to the hijacked conn
// replay the backend response bytes to the hijacked conn
conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) //nolint:errcheck
err = backendResponse.Write(conn)
_, err = conn.Write(backendResponseBytes)
if err != nil {
u.err = err
return u.err
Expand Down

0 comments on commit 2443b3f

Please sign in to comment.