diff --git a/pkg/registry/core/pod/rest/subresources.go b/pkg/registry/core/pod/rest/subresources.go index 1ffc85ff2763..cc4073db5a6b 100644 --- a/pkg/registry/core/pod/rest/subresources.go +++ b/pkg/registry/core/pod/rest/subresources.go @@ -70,7 +70,7 @@ func (r *ProxyREST) Connect(ctx api.Context, id string, opts runtime.Object, res } location.Path = path.Join(location.Path, proxyOpts.Path) // Return a proxy handler that uses the desired transport, wrapped with additional proxy handling (to get URL rewriting, X-Forwarded-* headers, etc) - return newThrottledUpgradeAwareProxyHandler(location, transport, true, false, responder), nil + return newThrottledUpgradeAwareProxyHandler(location, transport, true, false, false, responder), nil } // Support both GET and POST methods. We must support GET for browsers that want to use WebSockets. @@ -100,7 +100,7 @@ func (r *AttachREST) Connect(ctx api.Context, name string, opts runtime.Object, if err != nil { return nil, err } - return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil + return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, true, responder), nil } // NewConnectOptions returns the versioned object that represents exec parameters @@ -137,7 +137,7 @@ func (r *ExecREST) Connect(ctx api.Context, name string, opts runtime.Object, re if err != nil { return nil, err } - return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil + return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, true, responder), nil } // NewConnectOptions returns the versioned object that represents exec parameters @@ -180,11 +180,12 @@ func (r *PortForwardREST) Connect(ctx api.Context, name string, opts runtime.Obj if err != nil { return nil, err } - return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil + return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, true, responder), nil } -func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder rest.Responder) *genericrest.UpgradeAwareProxyHandler { +func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired, interceptRedirects bool, responder rest.Responder) *genericrest.UpgradeAwareProxyHandler { handler := genericrest.NewUpgradeAwareProxyHandler(location, transport, wrapTransport, upgradeRequired, responder) + handler.InterceptRedirects = interceptRedirects handler.MaxBytesPerSec = capabilities.Get().PerConnectionBandwidthLimitBytesPerSec return handler } diff --git a/pkg/registry/generic/rest/BUILD b/pkg/registry/generic/rest/BUILD index d75240ba662c..70a40749f12a 100644 --- a/pkg/registry/generic/rest/BUILD +++ b/pkg/registry/generic/rest/BUILD @@ -23,9 +23,11 @@ go_library( "//pkg/api/errors:go_default_library", "//pkg/api/rest:go_default_library", "//pkg/api/unversioned:go_default_library", + "//pkg/util/config:go_default_library", "//pkg/util/httpstream:go_default_library", "//pkg/util/net:go_default_library", "//pkg/util/proxy:go_default_library", + "//pkg/util/runtime:go_default_library", "//vendor:github.com/golang/glog", "//vendor:github.com/mxk/go-flowrate/flowrate", ], @@ -43,8 +45,12 @@ go_test( deps = [ "//pkg/api:go_default_library", "//pkg/api/errors:go_default_library", + "//pkg/util/config:go_default_library", + "//pkg/util/httpstream:go_default_library", "//pkg/util/net:go_default_library", "//pkg/util/proxy:go_default_library", + "//vendor:github.com/stretchr/testify/assert", + "//vendor:github.com/stretchr/testify/require", "//vendor:golang.org/x/net/websocket", ], ) diff --git a/pkg/registry/generic/rest/proxy.go b/pkg/registry/generic/rest/proxy.go index 81a297bc3071..02940502f7fd 100644 --- a/pkg/registry/generic/rest/proxy.go +++ b/pkg/registry/generic/rest/proxy.go @@ -17,7 +17,11 @@ limitations under the License. package rest import ( + "bufio" + "bytes" + "fmt" "io" + "net" "net/http" "net/http/httputil" "net/url" @@ -26,9 +30,11 @@ import ( "time" "k8s.io/kubernetes/pkg/api/errors" + utilconfig "k8s.io/kubernetes/pkg/util/config" "k8s.io/kubernetes/pkg/util/httpstream" - "k8s.io/kubernetes/pkg/util/net" + utilnet "k8s.io/kubernetes/pkg/util/net" "k8s.io/kubernetes/pkg/util/proxy" + utilruntime "k8s.io/kubernetes/pkg/util/runtime" "github.com/golang/glog" "github.com/mxk/go-flowrate/flowrate" @@ -41,10 +47,13 @@ type UpgradeAwareProxyHandler struct { // Transport provides an optional round tripper to use to proxy. If nil, the default proxy transport is used Transport http.RoundTripper // WrapTransport indicates whether the provided Transport should be wrapped with default proxy transport behavior (URL rewriting, X-Forwarded-* header setting) - WrapTransport bool - FlushInterval time.Duration - MaxBytesPerSec int64 - Responder ErrorResponder + WrapTransport bool + // InterceptRedirects determines whether the proxy should sniff backend responses for redirects, + // following them as necessary. + InterceptRedirects bool + FlushInterval time.Duration + MaxBytesPerSec int64 + Responder ErrorResponder } const defaultFlushInterval = 200 * time.Millisecond @@ -131,32 +140,44 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R return false } - backendConn, err := proxy.DialURL(h.Location, h.Transport) + var ( + backendConn net.Conn + rawResponse []byte + err error + ) + if h.InterceptRedirects && utilconfig.DefaultFeatureGate.StreamingProxyRedirects() { + backendConn, rawResponse, err = h.connectBackendWithRedirects(req) + } else { + backendConn, err = h.connectBackend(req.Method, h.Location, req.Header, req.Body) + } if err != nil { h.Responder.Error(err) return true } defer backendConn.Close() - requestHijackedConn, _, err := w.(http.Hijacker).Hijack() - if err != nil { - h.Responder.Error(err) + // Once the connection is hijacked, the ErrorResponder will no longer work, so + // hijacking should be the last step in the upgrade. + requestHijacker, ok := w.(http.Hijacker) + if !ok { + h.Responder.Error(fmt.Errorf("request connection cannot be hijacked: %T", w)) return true } - defer requestHijackedConn.Close() - - newReq, err := http.NewRequest(req.Method, h.Location.String(), req.Body) + requestHijackedConn, _, err := requestHijacker.Hijack() if err != nil { - h.Responder.Error(err) + h.Responder.Error(fmt.Errorf("error hijacking request connection: %v", err)) return true } - newReq.Header = req.Header + defer requestHijackedConn.Close() - if err = newReq.Write(backendConn); err != nil { - h.Responder.Error(err) - return true + // Forward raw response bytes back to client. + if len(rawResponse) > 0 { + if _, err = requestHijackedConn.Write(rawResponse); err != nil { + utilruntime.HandleError(fmt.Errorf("Error proxying response from backend to client: %v", err)) + } } + // Proxy the connection. wg := &sync.WaitGroup{} wg.Add(2) @@ -192,6 +213,113 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R return true } +// connectBackend dials the backend at location and forwards a copy of the client request. +func (h *UpgradeAwareProxyHandler) connectBackend(method string, location *url.URL, header http.Header, body io.Reader) (conn net.Conn, err error) { + defer func() { + if err != nil && conn != nil { + conn.Close() + conn = nil + } + }() + + beReq, err := http.NewRequest(method, location.String(), body) + if err != nil { + return nil, err + } + beReq.Header = header + + conn, err = proxy.DialURL(location, h.Transport) + if err != nil { + return conn, fmt.Errorf("error dialing backend: %v", err) + } + + if err = beReq.Write(conn); err != nil { + return conn, fmt.Errorf("error sending request: %v", err) + } + + return conn, err +} + +// connectBackendWithRedirects dials the backend and forwards a copy of the client request. If the +// client responds with a redirect, it is followed. The raw response bytes are returned, and should +// be forwarded back to the client. +func (h *UpgradeAwareProxyHandler) connectBackendWithRedirects(req *http.Request) (net.Conn, []byte, error) { + const ( + maxRedirects = 10 + maxResponseSize = 4096 + ) + var ( + initialReq = req + rawResponse = bytes.NewBuffer(make([]byte, 0, 256)) + location = h.Location + intermediateConn net.Conn + err error + ) + defer func() { + if intermediateConn != nil { + intermediateConn.Close() + } + }() + +redirectLoop: + for redirects := 0; ; redirects++ { + if redirects > maxRedirects { + return nil, nil, fmt.Errorf("too many redirects (%d)", redirects) + } + + if redirects == 0 { + intermediateConn, err = h.connectBackend(req.Method, location, req.Header, req.Body) + } else { + // Redirected requests switch to "GET" according to the HTTP spec: + // https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3 + intermediateConn, err = h.connectBackend("GET", location, initialReq.Header, nil) + } + + if err != nil { + return nil, nil, err + } + + // Peek at the backend response. + rawResponse.Reset() + respReader := bufio.NewReader(io.TeeReader( + io.LimitReader(intermediateConn, maxResponseSize), // Don't read more than maxResponseSize bytes. + rawResponse)) // Save the raw response. + resp, err := http.ReadResponse(respReader, req) + if err != nil { + // Unable to read the backend response; let the client handle it. + glog.Warningf("Error reading backend response: %v", err) + break redirectLoop + } + resp.Body.Close() // Unused. + + switch resp.StatusCode { + case http.StatusFound: + // Redirect, continue. + default: + // Don't redirect. + break redirectLoop + } + + // Reset the connection. + intermediateConn.Close() + intermediateConn = nil + + // Prepare to follow the redirect. + redirectStr := resp.Header.Get("Location") + if redirectStr == "" { + return nil, nil, fmt.Errorf("%d response missing Location header", resp.StatusCode) + } + location, err = h.Location.Parse(redirectStr) + if err != nil { + return nil, nil, fmt.Errorf("malformed Location header: %v", err) + } + } + + backendConn := intermediateConn + intermediateConn = nil // Don't close the connection when we return it. + return backendConn, rawResponse.Bytes(), nil +} + func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL, internalTransport http.RoundTripper) http.RoundTripper { scheme := url.Scheme host := url.Host @@ -213,12 +341,15 @@ func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL, internalT // corsRemovingTransport is a wrapper for an internal transport. It removes CORS headers // from the internal response. +// Implements pkg/util/net.RoundTripperWrapper type corsRemovingTransport struct { http.RoundTripper } -func (p *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, error) { - resp, err := p.RoundTripper.RoundTrip(req) +var _ = utilnet.RoundTripperWrapper(&corsRemovingTransport{}) + +func (rt *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := rt.RoundTripper.RoundTrip(req) if err != nil { return nil, err } @@ -226,8 +357,6 @@ func (p *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, er return resp, nil } -var _ = net.RoundTripperWrapper(&corsRemovingTransport{}) - func (rt *corsRemovingTransport) WrappedRoundTripper() http.RoundTripper { return rt.RoundTripper } diff --git a/pkg/registry/generic/rest/proxy_test.go b/pkg/registry/generic/rest/proxy_test.go index 644da05e5180..80454f18ab22 100644 --- a/pkg/registry/generic/rest/proxy_test.go +++ b/pkg/registry/generic/rest/proxy_test.go @@ -21,6 +21,7 @@ import ( "compress/gzip" "crypto/tls" "crypto/x509" + "errors" "fmt" "io" "io/ioutil" @@ -33,26 +34,59 @@ import ( "strconv" "strings" "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/websocket" + utilconfig "k8s.io/kubernetes/pkg/util/config" + "k8s.io/kubernetes/pkg/util/httpstream" utilnet "k8s.io/kubernetes/pkg/util/net" "k8s.io/kubernetes/pkg/util/proxy" ) +const fakeStatusCode = 567 + type fakeResponder struct { + t *testing.T called bool err error + // called chan error + w http.ResponseWriter } func (r *fakeResponder) Error(err error) { if r.called { - panic("called twice") + r.t.Errorf("Error responder called again!\nprevious error: %v\nnew error: %v", r.err, err) + } + + if r.w != nil { + r.w.WriteHeader(fakeStatusCode) + _, writeErr := r.w.Write([]byte(err.Error())) + assert.NoError(r.t, writeErr) + } else { + r.t.Logf("No ResponseWriter set") } + r.called = true r.err = err } +type fakeConn struct { + err error // The error to return when io is performed over the connection. +} + +func (f *fakeConn) Read([]byte) (int, error) { return 0, f.err } +func (f *fakeConn) Write([]byte) (int, error) { return 0, f.err } +func (f *fakeConn) Close() error { return nil } +func (fakeConn) LocalAddr() net.Addr { return nil } +func (fakeConn) RemoteAddr() net.Addr { return nil } +func (fakeConn) SetDeadline(t time.Time) error { return nil } +func (fakeConn) SetReadDeadline(t time.Time) error { return nil } +func (fakeConn) SetWriteDeadline(t time.Time) error { return nil } + type SimpleBackendHandler struct { requestURL url.URL requestHeader http.Header @@ -210,7 +244,7 @@ func TestServeHTTP(t *testing.T) { backendServer := httptest.NewServer(backendHandler) defer backendServer.Close() - responder := &fakeResponder{} + responder := &fakeResponder{t: t} backendURL, _ := url.Parse(backendServer.URL) backendURL.Path = test.requestPath proxyHandler := &UpgradeAwareProxyHandler{ @@ -367,43 +401,101 @@ func TestProxyUpgrade(t *testing.T) { }, } - for k, tc := range testcases { + // Enable StreamingProxyRedirects for test. + utilconfig.DefaultFeatureGate.Set("StreamingProxyRedirects=true") - backendServer := tc.ServerFunc(websocket.Handler(func(ws *websocket.Conn) { - defer ws.Close() - body := make([]byte, 5) - ws.Read(body) - ws.Write([]byte("hello " + string(body))) - })) - defer backendServer.Close() + for k, tc := range testcases { + for _, redirect := range []bool{false, true} { + tcName := k + backendPath := "/hello" + if redirect { + tcName += " with redirect" + backendPath = "/redirect" + } + func() { // Cleanup after each test case. + backend := http.NewServeMux() + backend.Handle("/hello", websocket.Handler(func(ws *websocket.Conn) { + defer ws.Close() + body := make([]byte, 5) + ws.Read(body) + ws.Write([]byte("hello " + string(body))) + })) + backend.Handle("/redirect", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/hello", http.StatusFound) + })) + backendServer := tc.ServerFunc(backend) + defer backendServer.Close() + + serverURL, _ := url.Parse(backendServer.URL) + serverURL.Path = backendPath + proxyHandler := &UpgradeAwareProxyHandler{ + Location: serverURL, + Transport: tc.ProxyTransport, + InterceptRedirects: redirect, + } + proxy := httptest.NewServer(proxyHandler) + defer proxy.Close() - serverURL, _ := url.Parse(backendServer.URL) - proxyHandler := &UpgradeAwareProxyHandler{ - Location: serverURL, - Transport: tc.ProxyTransport, - } - proxy := httptest.NewServer(proxyHandler) - defer proxy.Close() + ws, err := websocket.Dial("ws://"+proxy.Listener.Addr().String()+"/some/path", "", "http://127.0.0.1/") + if err != nil { + t.Fatalf("%s: websocket dial err: %s", tcName, err) + } + defer ws.Close() - ws, err := websocket.Dial("ws://"+proxy.Listener.Addr().String()+"/some/path", "", "http://127.0.0.1/") - if err != nil { - t.Fatalf("%s: websocket dial err: %s", k, err) - } - defer ws.Close() + if _, err := ws.Write([]byte("world")); err != nil { + t.Fatalf("%s: write err: %s", tcName, err) + } - if _, err := ws.Write([]byte("world")); err != nil { - t.Fatalf("%s: write err: %s", k, err) + response := make([]byte, 20) + n, err := ws.Read(response) + if err != nil { + t.Fatalf("%s: read err: %s", tcName, err) + } + if e, a := "hello world", string(response[0:n]); e != a { + t.Fatalf("%s: expected '%#v', got '%#v'", tcName, e, a) + } + }() } + } +} - response := make([]byte, 20) - n, err := ws.Read(response) - if err != nil { - t.Fatalf("%s: read err: %s", k, err) +func TestProxyUpgradeErrorResponse(t *testing.T) { + var ( + responder *fakeResponder + expectedErr = errors.New("EXPECTED") + ) + proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + transport := http.DefaultTransport.(*http.Transport) + transport.Dial = func(network, addr string) (net.Conn, error) { + return &fakeConn{err: expectedErr}, nil } - if e, a := "hello world", string(response[0:n]); e != a { - t.Fatalf("%s: expected '%#v', got '%#v'", k, e, a) + responder = &fakeResponder{t: t, w: w} + proxyHandler := &UpgradeAwareProxyHandler{ + Location: &url.URL{ + Host: "fake-backend", + }, + UpgradeRequired: true, + Responder: responder, + Transport: transport, } - } + proxyHandler.ServeHTTP(w, r) + })) + defer proxy.Close() + + // Send request to proxy server. + req, err := http.NewRequest("POST", "http://"+proxy.Listener.Addr().String()+"/some/path", nil) + require.NoError(t, err) + req.Header.Set(httpstream.HeaderConnection, httpstream.HeaderUpgrade) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Expect error response. + assert.True(t, responder.called) + assert.Equal(t, fakeStatusCode, resp.StatusCode) + msg, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + assert.Contains(t, string(msg), expectedErr.Error()) } func TestDefaultProxyTransport(t *testing.T) { @@ -415,7 +507,6 @@ func TestDefaultProxyTransport(t *testing.T) { expectedHost, expectedPathPrepend string }{ - { name: "simple path", url: "http://test.server:8080/a/test/location", @@ -619,7 +710,7 @@ func TestProxyRequestContentLengthAndTransferEncoding(t *testing.T) { })) defer downstreamServer.Close() - responder := &fakeResponder{} + responder := &fakeResponder{t: t} backendURL, _ := url.Parse(downstreamServer.URL) proxyHandler := &UpgradeAwareProxyHandler{ Location: backendURL, diff --git a/pkg/util/config/feature_gate.go b/pkg/util/config/feature_gate.go index df3fb57276fa..11996dabf14c 100644 --- a/pkg/util/config/feature_gate.go +++ b/pkg/util/config/feature_gate.go @@ -42,6 +42,7 @@ const ( appArmor = "AppArmor" dynamicKubeletConfig = "DynamicKubeletConfig" dynamicVolumeProvisioning = "DynamicVolumeProvisioning" + streamingProxyRedirects = "StreamingProxyRedirects" ) var ( @@ -53,6 +54,7 @@ var ( appArmor: {true, beta}, dynamicKubeletConfig: {false, alpha}, dynamicVolumeProvisioning: {true, alpha}, + streamingProxyRedirects: {false, alpha}, } // Special handling for a few gates. @@ -109,6 +111,10 @@ type FeatureGate interface { // owner: @mtaufen // alpha: v1.4 DynamicKubeletConfig() bool + + // owner: timstclair + // alpha: v1.5 + StreamingProxyRedirects() bool } // featureGate implements FeatureGate as well as pflag.Value for flag parsing. @@ -197,6 +203,12 @@ func (f *featureGate) DynamicVolumeProvisioning() bool { return f.lookup(dynamicVolumeProvisioning) } +// StreamingProxyRedirects controls whether the apiserver should intercept (and follow) +// redirects from the backend (Kubelet) for streaming requests (exec/attach/port-forward). +func (f *featureGate) StreamingProxyRedirects() bool { + return f.lookup(streamingProxyRedirects) +} + func (f *featureGate) lookup(key string) bool { defaultValue := f.known[key].enabled if f.enabled != nil {