Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Close connection and stop listening when port forwarding errors occur so that kubectl can exit #103526

Merged
merged 1 commit into from
Nov 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 13 additions & 7 deletions staging/src/k8s.io/client-go/tools/portforward/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,20 @@ func (pf *PortForwarder) getListener(protocol string, hostname string, port *For
// the background.
func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
for {
conn, err := listener.Accept()
if err != nil {
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err))
}
select {
case <-pf.streamConn.CloseChan():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good improvement.

return
default:
conn, err := listener.Accept()
if err != nil {
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err))
}
return
}
go pf.handleConnection(conn, port)
Comment on lines +312 to +314
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pre-existng and not blocking for this PR, but this code is so strange. it seems ripe for cleanup to have fewer go funcs launch go funcs and infinite loops.

}
go pf.handleConnection(conn, port)
}
}

Expand Down Expand Up @@ -398,6 +403,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
err = <-errorChan
if err != nil {
runtime.HandleError(err)
brianpursley marked this conversation as resolved.
Show resolved Hide resolved
pf.streamConn.Close()
}
}

Expand Down
177 changes: 172 additions & 5 deletions staging/src/k8s.io/client-go/tools/portforward/portforward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package portforward

import (
"bytes"
"fmt"
"net"
"net/http"
Expand All @@ -27,6 +28,9 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"

v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/httpstream"
)

Expand All @@ -43,18 +47,29 @@ func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, e
}

type fakeConnection struct {
closed bool
closeChan chan bool
closed bool
closeChan chan bool
dataStream *fakeStream
errorStream *fakeStream
}

func newFakeConnection() httpstream.Connection {
func newFakeConnection() *fakeConnection {
return &fakeConnection{
closeChan: make(chan bool),
closeChan: make(chan bool),
dataStream: &fakeStream{},
errorStream: &fakeStream{},
}
}

func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
return nil, nil
switch headers.Get(v1.StreamType) {
case v1.StreamTypeData:
return c.dataStream, nil
case v1.StreamTypeError:
return c.errorStream, nil
default:
return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType))
}
}

func (c *fakeConnection) Close() error {
Expand All @@ -76,6 +91,65 @@ func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) {
// no-op
}

type fakeListener struct {
net.Listener
closeChan chan bool
}

func newFakeListener() fakeListener {
return fakeListener{
closeChan: make(chan bool),
}
}

func (l *fakeListener) Accept() (net.Conn, error) {
select {
case <-l.closeChan:
return nil, fmt.Errorf("listener closed")
}
}

func (l *fakeListener) Close() error {
close(l.closeChan)
return nil
}

func (l *fakeListener) Addr() net.Addr {
return fakeAddr{}
}

type fakeAddr struct{}

func (fakeAddr) Network() string { return "fake" }
func (fakeAddr) String() string { return "fake" }

type fakeStream struct {
headers http.Header
readFunc func(p []byte) (int, error)
writeFunc func(p []byte) (int, error)
}

func (s *fakeStream) Read(p []byte) (n int, err error) { return s.readFunc(p) }
func (s *fakeStream) Write(p []byte) (n int, err error) { return s.writeFunc(p) }
func (*fakeStream) Close() error { return nil }
func (*fakeStream) Reset() error { return nil }
func (s *fakeStream) Headers() http.Header { return s.headers }
func (*fakeStream) Identifier() uint32 { return 0 }

type fakeConn struct {
sendBuffer *bytes.Buffer
receiveBuffer *bytes.Buffer
}

func (f fakeConn) Read(p []byte) (int, error) { return f.sendBuffer.Read(p) }
func (f fakeConn) Write(p []byte) (int, error) { return f.receiveBuffer.Write(p) }
func (fakeConn) Close() error { return nil }
func (fakeConn) LocalAddr() net.Addr { return nil }
func (fakeConn) RemoteAddr() net.Addr { return nil }
func (fakeConn) SetDeadline(t time.Time) error { return nil }
func (fakeConn) SetReadDeadline(t time.Time) error { return nil }
func (fakeConn) SetWriteDeadline(t time.Time) error { return nil }

func TestParsePortsAndNew(t *testing.T) {
tests := []struct {
input []string
Expand Down Expand Up @@ -393,3 +467,96 @@ func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
t.Fatalf("local port is 0, expected != 0")
}
}

func TestHandleConnection(t *testing.T) {
out := bytes.NewBufferString("")

pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, nil)
if err != nil {
t.Fatalf("error while calling New: %s", err)
}

// Setup fake local connection
localConnection := &fakeConn{
sendBuffer: bytes.NewBufferString("test data from local"),
receiveBuffer: bytes.NewBufferString(""),
}

// Setup fake remote connection to send data on the data stream after it receives data from the local connection
remoteDataToSend := bytes.NewBufferString("test data from remote")
remoteDataReceived := bytes.NewBufferString("")
remoteErrorToSend := bytes.NewBufferString("")
blockRemoteSend := make(chan struct{})
remoteConnection := newFakeConnection()
remoteConnection.dataStream.readFunc = func(p []byte) (int, error) {
<-blockRemoteSend // Wait for the expected data to be received before responding
return remoteDataToSend.Read(p)
}
remoteConnection.dataStream.writeFunc = func(p []byte) (int, error) {
n, err := remoteDataReceived.Write(p)
if remoteDataReceived.String() == "test data from local" {
close(blockRemoteSend)
}
return n, err
}
remoteConnection.errorStream.readFunc = remoteErrorToSend.Read
pf.streamConn = remoteConnection

// Test handleConnection
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})

assert.Equal(t, "test data from local", remoteDataReceived.String())
assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String())
assert.Equal(t, "Handling connection for 1111\n", out.String())
}

func TestHandleConnectionSendsRemoteError(t *testing.T) {
out := bytes.NewBufferString("")
errOut := bytes.NewBufferString("")

pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut)
if err != nil {
t.Fatalf("error while calling New: %s", err)
}

// Setup fake local connection
localConnection := &fakeConn{
sendBuffer: bytes.NewBufferString(""),
receiveBuffer: bytes.NewBufferString(""),
}

// Setup fake remote connection to return an error message on the error stream
remoteDataToSend := bytes.NewBufferString("")
remoteDataReceived := bytes.NewBufferString("")
remoteErrorToSend := bytes.NewBufferString("error")
remoteConnection := newFakeConnection()
remoteConnection.dataStream.readFunc = remoteDataToSend.Read
remoteConnection.dataStream.writeFunc = remoteDataReceived.Write
remoteConnection.errorStream.readFunc = remoteErrorToSend.Read
pf.streamConn = remoteConnection

// Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})

assert.Equal(t, "", remoteDataReceived.String())
assert.Equal(t, "", localConnection.receiveBuffer.String())
assert.Equal(t, "Handling connection for 1111\n", out.String())
}

func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) {
out := bytes.NewBufferString("")
errOut := bytes.NewBufferString("")

pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut)
if err != nil {
t.Fatalf("error while calling New: %s", err)
}

listener := newFakeListener()

pf.streamConn = newFakeConnection()
pf.streamConn.Close()

port := ForwardedPort{}
pf.waitForConnection(&listener, port)
}