Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v13] Fix an issue ALPN handshake test does not respect "HTTPS_PROXY" #27810

Merged
merged 2 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 15 additions & 5 deletions api/client/alpn_conn_upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net/url"
"os"
"strings"
"time"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
Expand All @@ -48,19 +49,28 @@ import (
// In those cases, the Teleport client should make a HTTP "upgrade" call to the
// Proxy Service to establish a tunnel for the originally planned traffic to
// preserve the ALPN and SNI information.
func IsALPNConnUpgradeRequired(addr string, insecure bool) bool {
func IsALPNConnUpgradeRequired(ctx context.Context, addr string, insecure bool, opts ...DialOption) bool {
if result, ok := OverwriteALPNConnUpgradeRequirementByEnv(addr); ok {
return result
}

netDialer := &net.Dialer{
Timeout: defaults.DefaultIOTimeout,
}
// Use NewDialer which takes care of ProxyURL, and use a shorter I/O
// timeout to avoid blocking caller.
baseDialer := NewDialer(
ctx,
defaults.DefaultIdleTimeout,
5*time.Second,
append(opts,
WithInsecureSkipVerify(insecure),
WithALPNConnUpgrade(false),
)...,
)

tlsConfig := &tls.Config{
NextProtos: []string{string(constants.ALPNSNIProtocolReverseTunnel)},
InsecureSkipVerify: insecure,
}
testConn, err := tls.DialWithDialer(netDialer, "tcp", addr, tlsConfig)
testConn, err := tlsutils.TLSDial(ctx, baseDialer, "tcp", addr, tlsConfig)
if err != nil {
if isRemoteNoALPNError(err) {
logrus.Debugf("ALPN connection upgrade required for %q: %v. No ALPN protocol is negotiated by the server.", addr, true)
Expand Down
85 changes: 70 additions & 15 deletions api/client/alpn_conn_upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (

"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/fixtures"
"github.com/gravitational/teleport/api/testhelpers"
"github.com/gravitational/teleport/api/utils/pingconn"
)

Expand Down Expand Up @@ -70,10 +71,21 @@ func TestIsALPNConnUpgradeRequired(t *testing.T) {
},
}

ctx := context.Background()
forwardProxy, forwardProxyURL := mustStartForwardProxy(t)

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
server := mustStartMockALPNServer(t, test.serverProtos)
require.Equal(t, test.expectedResult, IsALPNConnUpgradeRequired(server.Addr().String(), test.insecure))
t.Run("direct", func(t *testing.T) {
require.Equal(t, test.expectedResult, IsALPNConnUpgradeRequired(ctx, server.Addr().String(), test.insecure))
})

t.Run("with ProxyURL", func(t *testing.T) {
countBeforeTest := forwardProxy.Count()
require.Equal(t, test.expectedResult, IsALPNConnUpgradeRequired(ctx, server.Addr().String(), test.insecure, withProxyURL(forwardProxyURL)))
require.Equal(t, countBeforeTest+1, forwardProxy.Count())
})
})
}
}
Expand Down Expand Up @@ -160,24 +172,50 @@ func TestALPNConnUpgradeDialer(t *testing.T) {
pool.AddCert(server.Certificate())

tlsConfig := &tls.Config{RootCAs: pool}
preDialer := newDirectDialer(0, 5*time.Second)
dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig, test.withPing)
conn, err := dialer.DialContext(ctx, "tcp", addr.Host)
if test.wantError {
require.Error(t, err)
return
}
require.NoError(t, err)
defer conn.Close()

data := make([]byte, 100)
n, err := conn.Read(data)
require.NoError(t, err)
require.Equal(t, string(data[:n]), "hello")
directDialer := newDirectDialer(0, 5*time.Second)

t.Run("direct", func(t *testing.T) {
dialer := newALPNConnUpgradeDialer(directDialer, tlsConfig, test.withPing)
conn, err := dialer.DialContext(ctx, "tcp", addr.Host)
if test.wantError {
require.Error(t, err)
return
}
require.NoError(t, err)
defer conn.Close()

mustReadConnData(t, conn, "hello")
})

t.Run("with ProxyURL", func(t *testing.T) {
forwardProxy, forwardProxyURL := mustStartForwardProxy(t)
countBeforeTest := forwardProxy.Count()

proxyURLDialer := newProxyURLDialer(forwardProxyURL, directDialer)
dialer := newALPNConnUpgradeDialer(proxyURLDialer, tlsConfig, test.withPing)
conn, err := dialer.DialContext(ctx, "tcp", addr.Host)
if test.wantError {
require.Error(t, err)
return
}
require.NoError(t, err)
defer conn.Close()

mustReadConnData(t, conn, "hello")
require.Equal(t, countBeforeTest+1, forwardProxy.Count())
})
})
}
}

func mustReadConnData(t *testing.T, conn net.Conn, wantText string) {
data := make([]byte, len(wantText)*2)
n, err := conn.Read(data)
require.NoError(t, err)
require.Equal(t, len(wantText), n)
require.Equal(t, string(data[:n]), wantText)
}

type mockALPNServer struct {
net.Listener
cert tls.Certificate
Expand Down Expand Up @@ -273,3 +311,20 @@ func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http
}
})
}

func mustStartForwardProxy(t *testing.T) (*testhelpers.ProxyHandler, *url.URL) {
t.Helper()

listener, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
t.Cleanup(func() {
listener.Close()
})

url, err := url.Parse("http://" + listener.Addr().String())
require.NoError(t, err)

handler := &testhelpers.ProxyHandler{}
go http.Serve(listener, handler)
return handler, url
}
24 changes: 22 additions & 2 deletions api/client/contextdialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ type dialConfig struct {
// proxyHeaderGetter is used if present to get signed PROXY headers to propagate client's IP.
// Used by proxy's web server to make calls on behalf of connected clients.
proxyHeaderGetter PROXYHeaderGetter
// proxyURLFunc is a function used to get ProxyURL. Defaults to
// utils.GetProxyURL if not specified. Currently only used in tests to
// overwrite the ProxyURL as httpproxy.FromEnvironment skips localhost
// proxies.
proxyURLFunc func(dialAddr string) *url.URL
}

func (c *dialConfig) getProxyURL(dialAddr string) *url.URL {
if c.proxyURLFunc != nil {
return c.proxyURLFunc(dialAddr)
}
return utils.GetProxyURL(dialAddr)
}

// WithInsecureSkipVerify specifies if dialing insecure when using an HTTPS proxy.
Expand All @@ -74,6 +86,14 @@ func WithALPNConnUpgradePing(alpnConnUpgradeWithPing bool) DialOption {
}
}

func withProxyURL(proxyURL *url.URL) DialProxyOption {
return func(cfg *dialProxyConfig) {
cfg.proxyURLFunc = func(_ string) *url.URL {
return proxyURL
}
}
}

// WithPROXYHeaderGetter provides PROXY headers signer so client's real IP could be propagated.
// Used by proxy's web server to make calls on behalf of connected clients.
func WithPROXYHeaderGetter(proxyHeaderGetter PROXYHeaderGetter) DialProxyOption {
Expand Down Expand Up @@ -179,7 +199,7 @@ func NewDialer(ctx context.Context, keepAlivePeriod, dialTimeout time.Duration,
}

// Wrap with proxy URL dialer if proxy URL is detected.
if proxyURL := utils.GetProxyURL(addr); proxyURL != nil {
if proxyURL := cfg.getProxyURL(addr); proxyURL != nil {
dialer = newProxyURLDialer(proxyURL, dialer, opts...)
}

Expand Down Expand Up @@ -327,7 +347,7 @@ func newTLSRoutingWithConnUpgradeDialer(ssh ssh.ClientConfig, params connectPara
InsecureSkipVerify: insecure,
ServerName: host,
},
ALPNConnUpgradeRequired: IsALPNConnUpgradeRequired(params.addr, insecure),
ALPNConnUpgradeRequired: IsALPNConnUpgradeRequired(ctx, params.addr, insecure),
GetClusterCAs: func(_ context.Context) (*x509.CertPool, error) {
tlsConfig, err := params.cfg.Credentials[0].TLSConfig()
if err != nil {
Expand Down
100 changes: 100 additions & 0 deletions api/testhelpers/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright 2023 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package testhelpers

import (
"io"
"net"
"net/http"
"sync"
"time"

"github.com/gravitational/trace"
)

// ProxyHandler is a http.Handler that implements a simple HTTP proxy server.
type ProxyHandler struct {
sync.Mutex
count int
}

// ServeHTTP only accepts the CONNECT verb and will tunnel your connection to
// the specified host. Also tracks the number of connections that it proxies for
// debugging purposes.
func (p *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Validate http connect parameters.
if r.Method != http.MethodConnect {
trace.WriteError(w, trace.BadParameter("%v not supported", r.Method))
return
}
if r.Host == "" {
trace.WriteError(w, trace.BadParameter("host not set"))
return
}

// Dial to the target host, this is done before hijacking the connection to
// ensure the target host is accessible.
dialer := net.Dialer{}
dconn, err := dialer.DialContext(r.Context(), "tcp", r.Host)
if err != nil {
trace.WriteError(w, err)
return
}
defer dconn.Close()

// Once the client receives 200 OK, the rest of the data will no longer be
// http, but whatever protocol is being tunneled.
w.WriteHeader(http.StatusOK)

// Hijack request so we can get underlying connection.
hj, ok := w.(http.Hijacker)
if !ok {
trace.WriteError(w, trace.AccessDenied("unable to hijack connection"))
return
}
sconn, _, err := hj.Hijack()
if err != nil {
trace.WriteError(w, err)
return
}
defer sconn.Close()

// Success, we're proxying data now.
p.Lock()
p.count++
p.Unlock()

// Copy from src to dst and dst to src.
errc := make(chan error, 2)
replicate := func(dst io.Writer, src io.Reader) {
_, err := io.Copy(dst, src)
errc <- err
}
go replicate(sconn, dconn)
go replicate(dconn, sconn)

// Wait until done, error, or 10 second.
select {
case <-time.After(10 * time.Second):
case <-errc:
}
}

// Count returns the number of requests that have been proxied.
func (p *ProxyHandler) Count() int {
p.Lock()
defer p.Unlock()
return p.count
}