diff --git a/lib/proxy/router.go b/lib/proxy/router.go index c2243eaae03f2..e825090ccc34c 100644 --- a/lib/proxy/router.go +++ b/lib/proxy/router.go @@ -82,24 +82,24 @@ func init() { metrics.RegisterPrometheusCollectors(proxiedSessions, failedConnectingToNode, connectingToNode) } -// proxiedMetricConn wraps [net.Conn] opened by +// ProxiedMetricConn wraps [net.Conn] opened by // the [Router] so that the proxiedSessions counter // can be decremented when it is closed. -type proxiedMetricConn struct { +type ProxiedMetricConn struct { // once ensures that proxiedSessions is only decremented // a single time per [net.Conn] once sync.Once net.Conn } -// newProxiedMetricConn increments proxiedSessions and creates -// a proxiedMetricConn that defers to the provided [net.Conn]. -func newProxiedMetricConn(conn net.Conn) *proxiedMetricConn { +// NewProxiedMetricConn increments proxiedSessions and creates +// a ProxiedMetricConn that defers to the provided [net.Conn]. +func NewProxiedMetricConn(conn net.Conn) *ProxiedMetricConn { proxiedSessions.Inc() - return &proxiedMetricConn{Conn: conn} + return &ProxiedMetricConn{Conn: conn} } -func (c *proxiedMetricConn) Close() error { +func (c *ProxiedMetricConn) Close() error { c.once.Do(proxiedSessions.Dec) return trace.Wrap(c.Conn.Close()) } @@ -334,7 +334,7 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net. return nil, trace.Wrap(err) } - return newProxiedMetricConn(conn), trace.Wrap(err) + return NewProxiedMetricConn(conn), trace.Wrap(err) } // getRemoteCluster looks up the provided clusterName to determine if a remote site exists with @@ -496,7 +496,7 @@ func (r *Router) DialSite(ctx context.Context, clusterName string, clientSrcAddr return nil, trace.Wrap(err) } - return newProxiedMetricConn(conn), trace.Wrap(err) + return NewProxiedMetricConn(conn), trace.Wrap(err) } // GetSiteClient returns an auth client for the provided cluster. diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index 1cfb26e8e2e07..6c58bbb72793a 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -514,7 +514,7 @@ func (p *AgentPool) getVersion(ctx context.Context) (string, error) { // transport creates a new transport instance. func (p *AgentPool) transport(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request, conn sshutils.Conn) *transport { - return &transport{ + t := &transport{ closeContext: ctx, component: p.Component, localClusterName: p.LocalCluster, @@ -531,6 +531,17 @@ func (p *AgentPool) transport(ctx context.Context, channel ssh.Channel, requests proxySigner: p.PROXYSigner, forwardClientAddress: true, } + + // If the AgentPool is being used for Proxy to Proxy communication between two clusters, then + // we check if the reverse tunnel server is capable of tracking user connections. This allows + // the leaf proxy to track sessions that are initiated via the root cluster. Without providing + // the user tracker the leaf cluster metrics will be incorrect and graceful shutdown will not + // wait for user sessions to be terminated prior to proceeding with the shutdown operation. + if p.IsRemoteCluster && p.ReverseTunnelServer != nil { + t.trackUserConnection = p.ReverseTunnelServer.TrackUserConnection + } + + return t } // agentPoolRuntimeConfig contains configurations dynamically set and updated diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 8a588ca129e5d..5875440cd322a 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -1084,6 +1084,13 @@ func (s *server) rejectRequest(ch ssh.NewChannel, reason ssh.RejectionReason, ms } } +// TrackUserConnection tracks a user connection that should prevent +// the server from being terminated if active. The returned function +// should be called when the connection is terminated. +func (s *server) TrackUserConnection() (release func()) { + return s.srv.TrackUserConnection() +} + // newRemoteSite helper creates and initializes 'remoteSite' instance func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, error) { connInfo, err := types.NewTunnelConnection( diff --git a/lib/reversetunnel/transport.go b/lib/reversetunnel/transport.go index e7f26196aa25f..77ac4ff09ca31 100644 --- a/lib/reversetunnel/transport.go +++ b/lib/reversetunnel/transport.go @@ -36,6 +36,7 @@ import ( "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/multiplexer" + "github.com/gravitational/teleport/lib/proxy" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/utils" ) @@ -93,6 +94,9 @@ type transport struct { // preventing users connecting to the proxy tunnel listener spoofing their address; but we are still able to // correctly propagate client address in reverse tunnel agents of nodes/services. forwardClientAddress bool + + // trackUserConnection is an optional mechanism used to count active user sessions. + trackUserConnection func() (release func()) } // start will start the transporting data over the tunnel. This function will @@ -246,6 +250,10 @@ func (p *transport) start() { // tunnel from the SSH node by dreq.ServerID. We'll need to forward // dreq.Address as well. directAddress = dreq.Address + + if p.trackUserConnection != nil { + defer p.trackUserConnection()() + } default: // Not a special address; could be empty. directAddress = dreq.Address @@ -395,6 +403,11 @@ func (p *transport) getConn(addr string, r *sshutils.DialReq) (net.Conn, bool, e } p.log.Debugf("Returning connection dialed through tunnel with server ID %v.", r.ServerID) + + if r.ConnType == types.NodeTunnel { + return proxy.NewProxiedMetricConn(conn), true, nil + } + return conn, true, nil } diff --git a/lib/reversetunnelclient/api.go b/lib/reversetunnelclient/api.go index 175d06be705b3..461b209e4a38c 100644 --- a/lib/reversetunnelclient/api.go +++ b/lib/reversetunnelclient/api.go @@ -162,6 +162,10 @@ type Server interface { Wait(ctx context.Context) // GetProxyPeerClient returns the proxy peer client GetProxyPeerClient() *peer.Client + // TrackUserConnection tracks a user connection that should prevent + // the server from being terminated if active. The returned function + // should be called when the connection is terminated. + TrackUserConnection() (release func()) } const ( diff --git a/lib/service/service.go b/lib/service/service.go index 3351a713b1b97..1ed14d77412dd 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4657,12 +4657,12 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { // really guaranteed to be capable to serve new requests if we're // halfway through a shutdown, and double closing a listener is fine. listeners.Close() - rcWatcher.Close() if payload == nil { log.Infof("Shutting down immediately.") if tsrv != nil { warnOnErr(tsrv.Close(), log) } + warnOnErr(rcWatcher.Close(), log) if proxyServer != nil { warnOnErr(proxyServer.Close(), log) } @@ -4709,6 +4709,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { if tsrv != nil { warnOnErr(tsrv.Shutdown(ctx), log) } + warnOnErr(rcWatcher.Close(), log) if proxyServer != nil { warnOnErr(proxyServer.Shutdown(), log) } diff --git a/lib/sshutils/server.go b/lib/sshutils/server.go index d9ecdfd9ef157..9af9f4b769f02 100644 --- a/lib/sshutils/server.go +++ b/lib/sshutils/server.go @@ -442,6 +442,17 @@ func (s *Server) trackUserConnections(delta int32) int32 { return atomic.AddInt32(&s.userConns, delta) } +// TrackUserConnection tracks a user connection that should prevent +// the server from being terminated if active. The returned function +// should be called when the connection is terminated. +func (s *Server) TrackUserConnection() (release func()) { + s.trackUserConnections(1) + + return sync.OnceFunc(func() { + s.trackUserConnections(-1) + }) +} + // ActiveConnections returns the number of connections that are // being served. func (s *Server) ActiveConnections() int32 {