Skip to content

Commit

Permalink
[v15] Wrap diag service listener with multiplexer so it can work behi…
Browse files Browse the repository at this point in the history
…nd PROXY enabled loadbalancer/proxy. (#40138)

* Wrap diag service listener with multiplexer so it can work behind PROXY enabled loadbalancer/proxy.

It accept simultaneously connections that are prepended with PROXY line or not.
We also don't issue warnings about unspecified PROXY protocol mode for this listener.

* Fix wording.

Co-authored-by: Gus Luxton <gus@goteleport.com>

* Use ExitContext instead of GracefulExitContext

Co-authored-by: Edoardo Spadolini <edoardo.spadolini@goteleport.com>

* Close diag multiplexer listener during diagnostic.shutdown event.

* Refactor server.Serve() call

* Move creation of muxListener outside of diagnostic.service event.

* Combine declaration and usage

Co-authored-by: Edoardo Spadolini <edoardo.spadolini@goteleport.com>

---------

Co-authored-by: Gus Luxton <gus@goteleport.com>
Co-authored-by: Edoardo Spadolini <edoardo.spadolini@goteleport.com>
  • Loading branch information
3 people committed Apr 2, 2024
1 parent 453a6bc commit f54fcf6
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 11 deletions.
35 changes: 26 additions & 9 deletions lib/multiplexer/multiplexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ type Config struct {
Clock clockwork.Clock
// PROXYProtocolMode controls behavior related to unsigned PROXY protocol headers.
PROXYProtocolMode PROXYProtocolMode
// SuppressUnexpectedPROXYWarning makes multiplexer not issue warnings if it receives PROXY
// line when running in PROXYProtocolMode=PROXYProtocolUnspecified
SuppressUnexpectedPROXYWarning bool
// ID is an identifier used for debugging purposes
ID string
// CertAuthorityGetter is used to get CA to verify singed PROXY headers sent internally by teleport
Expand Down Expand Up @@ -171,13 +174,14 @@ type Mux struct {
sync.RWMutex
*log.Entry
Config
sshListener *Listener
tlsListener *Listener
dbListener *Listener
context context.Context
cancel context.CancelFunc
waitContext context.Context
waitCancel context.CancelFunc
sshListener *Listener
tlsListener *Listener
dbListener *Listener
httpListener *Listener
context context.Context
cancel context.CancelFunc
waitContext context.Context
waitCancel context.CancelFunc
// logLimiter is a goroutine responsible for deduplicating multiplexer errors
// (over a 1min window) that occur when detecting the types of new connections.
// This ensures that health checkers / malicious actors cannot overpower /
Expand Down Expand Up @@ -216,6 +220,16 @@ func (m *Mux) DB() net.Listener {
return m.dbListener
}

// HTTP returns listener that receives plain HTTP connections
func (m *Mux) HTTP() net.Listener {
m.Lock()
defer m.Unlock()
if m.httpListener == nil {
m.httpListener = newListener(m.context, m.Config.Listener.Addr())
}
return m.httpListener
}

func (m *Mux) closeListener() {
m.Lock()
defer m.Unlock()
Expand Down Expand Up @@ -287,6 +301,9 @@ func (m *Mux) protocolListener(proto Protocol) *Listener {
return m.sshListener
case ProtoPostgres:
return m.dbListener
case ProtoHTTP:
return m.httpListener

}
return nil
}
Expand Down Expand Up @@ -511,7 +528,7 @@ func (m *Mux) detect(conn net.Conn) (*Conn, error) {
}
unsignedPROXYLineReceived = true

if m.PROXYProtocolMode == PROXYProtocolUnspecified {
if m.PROXYProtocolMode == PROXYProtocolUnspecified && !m.SuppressUnexpectedPROXYWarning {
m.logLimiter.Log(m.WithFields(log.Fields{
"direct_src_addr": conn.RemoteAddr(),
"direct_dst_addr": conn.LocalAddr(),
Expand Down Expand Up @@ -589,7 +606,7 @@ func (m *Mux) detect(conn net.Conn) (*Conn, error) {
}
unsignedPROXYLineReceived = true

if m.PROXYProtocolMode == PROXYProtocolUnspecified {
if m.PROXYProtocolMode == PROXYProtocolUnspecified && !m.SuppressUnexpectedPROXYWarning {
m.logLimiter.Log(m.WithFields(log.Fields{
"direct_src_addr": conn.RemoteAddr(),
"direct_dst_addr": conn.LocalAddr(),
Expand Down
42 changes: 42 additions & 0 deletions lib/multiplexer/multiplexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,48 @@ func TestMux(t *testing.T) {
}
require.Error(t, err)
})
t.Run("HTTP", func(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

mux, err := New(Config{
Listener: listener,
})
require.NoError(t, err)
go mux.Serve()
defer mux.Close()

backend1 := &httptest.Server{
Listener: mux.HTTP(),
Config: &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "backend 1")
}),
},
}
backend1.Start()
defer backend1.Close()

re, err := http.Get(backend1.URL)
require.NoError(t, err)
defer re.Body.Close()
bytes, err := io.ReadAll(re.Body)
require.NoError(t, err)
require.Equal(t, "backend 1", string(bytes))

// Close mux, new requests should fail
mux.Close()
mux.Wait()

// Use new client to use new connection pool
client := &http.Client{Transport: &http.Transport{}}
re, err = client.Get(backend1.URL)
if err == nil {
re.Body.Close()
}
require.Error(t, err)
})
// ProxyLine tests proxy line protocol
t.Run("ProxyLines", func(t *testing.T) {
t.Parallel()
Expand Down
22 changes: 20 additions & 2 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3238,15 +3238,33 @@ func (process *TeleportProcess) initDiagnosticService() error {

logger.InfoContext(process.ExitContext(), "Starting diagnostic service.", "listen_address", process.Config.DiagnosticAddr.Addr)

muxListener, err := multiplexer.New(multiplexer.Config{
Context: process.ExitContext(),
Listener: listener,
PROXYProtocolMode: multiplexer.PROXYProtocolUnspecified,
SuppressUnexpectedPROXYWarning: true,
ID: teleport.Component(teleport.ComponentDiagnostic),
})
if err != nil {
return trace.Wrap(err)
}

process.RegisterFunc("diagnostic.service", func() error {
err := server.Serve(listener)
if err != nil && err != http.ErrServerClosed {
listenerHTTP := muxListener.HTTP()
go func() {
if err := muxListener.Serve(); err != nil && !utils.IsOKNetworkError(err) {
muxListener.Entry.WithError(err).Error("Mux encountered err serving")
}
}()

if err := server.Serve(listenerHTTP); !errors.Is(err, http.ErrServerClosed) {
logger.WarnContext(process.ExitContext(), "Diagnostic server exited with error.", "error", err)
}
return nil
})

process.OnExit("diagnostic.shutdown", func(payload interface{}) {
warnOnErr(process.ExitContext(), muxListener.Close(), logger)
if payload == nil {
logger.InfoContext(process.ExitContext(), "Shutting down immediately.")
warnOnErr(process.ExitContext(), server.Close(), logger)
Expand Down

0 comments on commit f54fcf6

Please sign in to comment.