Skip to content

Commit

Permalink
Add idle connection timeouts to http clients and servers (#22885) (#2…
Browse files Browse the repository at this point in the history
…2916)

Sets `http.Server.IdleTimeout` and `http.Client.IdleConnTimeout`
on clients and servers which didn't have them set. A default of
360s was chosen to be on par with the default of an NLB without
being identical.

This was added as another safety measure to prevent leaking any
idle connections indefinitely as seen in #22757.

`apidefaults.DefaultDialTimeout` was also renamed to
`apidefaults.DefaultIOTimeout` to better reflect its usage.
  • Loading branch information
rosstimothy committed Mar 10, 2023
1 parent 9cfe32a commit c6d6c67
Show file tree
Hide file tree
Showing 23 changed files with 40 additions and 30 deletions.
2 changes: 1 addition & 1 deletion api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ func (c *Config) CheckAndSetDefaults() error {
c.KeepAliveCount = defaults.KeepAliveCountMax
}
if c.DialTimeout == 0 {
c.DialTimeout = defaults.DefaultDialTimeout
c.DialTimeout = defaults.DefaultIOTimeout
}
if c.CircuitBreakerConfig.Trip == nil || c.CircuitBreakerConfig.IsSuccessful == nil {
c.CircuitBreakerConfig = breaker.DefaultBreakerConfig(clockwork.NewRealClock())
Expand Down
3 changes: 2 additions & 1 deletion api/client/webclient/webclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (c *Config) CheckAndSetDefaults() error {
return trace.BadParameter(message, "missing parameter ProxyAddr")
}
if c.Timeout == 0 {
c.Timeout = defaults.DefaultDialTimeout
c.Timeout = defaults.DefaultIOTimeout
}
if c.TraceProvider == nil {
c.TraceProvider = tracing.DefaultProvider()
Expand All @@ -100,6 +100,7 @@ func newWebClient(cfg *Config) (*http.Client, error) {
Proxy: func(req *http.Request) (*url.URL, error) {
return httpproxy.FromEnvironment().ProxyFunc()(req.URL)
},
IdleConnTimeout: defaults.DefaultIOTimeout,
}, nil)

return &http.Client{
Expand Down
8 changes: 5 additions & 3 deletions api/defaults/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ const (
// Namespace is default namespace
Namespace = "default"

// DefaultDialTimeout is a default TCP dial timeout we set for our
// connection attempts
DefaultDialTimeout = 30 * time.Second
// DefaultIOTimeout is a default network IO timeout.
DefaultIOTimeout = 30 * time.Second

// DefaultIdleTimeout is a default idle connection timeout.
DefaultIdleTimeout = 360 * time.Second

// KeepAliveCountMax is the number of keep-alive messages that can be sent
// without receiving a response from the client before the client is
Expand Down
2 changes: 1 addition & 1 deletion api/utils/sshutils/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func ProxyClientSSHConfig(sshCert *ssh.Certificate, priv crypto.Signer, sshCAs .

cfg := &ssh.ClientConfig{
Auth: []ssh.AuthMethod{authMethod},
Timeout: defaults.DefaultDialTimeout,
Timeout: defaults.DefaultIOTimeout,
}

// The KeyId is not always a valid principal, so we use the first valid principal instead.
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/clt.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func NewHTTPClient(cfg client.Config, tls *tls.Config, params ...roundtrip.Clien
// custom DialContext overrides this DNS name to the real address.
// In addition this dialer tries multiple addresses if provided
DialContext: dialer.DialContext,
ResponseHeaderTimeout: apidefaults.DefaultDialTimeout,
ResponseHeaderTimeout: apidefaults.DefaultIOTimeout,
TLSClientConfig: tls,

// Increase the size of the connection pool. This substantially improves the
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ func (i *Identity) SSHClientConfig(fips bool) (*ssh.ClientConfig, error) {
User: i.ID.HostUUID,
Auth: []ssh.AuthMethod{ssh.PublicKeys(i.KeySigner)},
HostKeyCallback: callback,
Timeout: apidefaults.DefaultDialTimeout,
Timeout: apidefaults.DefaultIOTimeout,
}, nil
}

Expand Down
3 changes: 2 additions & 1 deletion lib/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ func NewTLSServer(cfg TLSServerConfig) (*TLSServer, error) {
cfg: cfg,
httpServer: &http.Server{
Handler: httplib.MakeTracingHandler(limiter, teleport.ComponentAuth),
ReadHeaderTimeout: apidefaults.DefaultDialTimeout,
ReadHeaderTimeout: apidefaults.DefaultIOTimeout,
IdleTimeout: apidefaults.DefaultIdleTimeout,
},
log: logrus.WithFields(logrus.Fields{
trace.Component: cfg.Component,
Expand Down
2 changes: 1 addition & 1 deletion lib/backend/etcdbk/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ func (cfg *Config) Validate() error {
cfg.BufferSize = backend.DefaultBufferCapacity
}
if cfg.DialTimeout == 0 {
cfg.DialTimeout = apidefaults.DefaultDialTimeout
cfg.DialTimeout = apidefaults.DefaultIOTimeout
}
if cfg.PasswordFile != "" {
out, err := os.ReadFile(cfg.PasswordFile)
Expand Down
2 changes: 1 addition & 1 deletion lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func (proxy *ProxyClient) GetSites(ctx context.Context) ([]types.Site, error) {
}()
select {
case <-done:
case <-time.After(apidefaults.DefaultDialTimeout):
case <-time.After(apidefaults.DefaultIOTimeout):
return nil, trace.ConnectionProblem(nil, "timeout")
}
log.Debugf("Found clusters: %v", stdout.String())
Expand Down
2 changes: 1 addition & 1 deletion lib/client/conntest/connection_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (r *TestConnectionRequest) CheckAndSetDefaults() error {
}

if r.DialTimeout <= 0 {
r.DialTimeout = defaults.DefaultDialTimeout
r.DialTimeout = defaults.DefaultIOTimeout
}

return nil
Expand Down
2 changes: 1 addition & 1 deletion lib/events/filesessions/fileasync.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ func (u *Uploader) upload(ctx context.Context, up *upload) error {
// before the files are closed to avoid async writes
// the timeout is a defensive measure to avoid blocking
// indefinitely in case of unforeseen error (e.g. write taking too long)
wctx, wcancel := context.WithTimeout(ctx, apidefaults.DefaultDialTimeout)
wctx, wcancel := context.WithTimeout(ctx, apidefaults.DefaultIOTimeout)
defer wcancel()

<-wctx.Done()
Expand Down
3 changes: 2 additions & 1 deletion lib/kube/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ func NewTLSServer(cfg TLSServerConfig) (*TLSServer, error) {
TLSServerConfig: cfg,
Server: &http.Server{
Handler: httplib.MakeTracingHandler(limiter, teleport.ComponentKube),
ReadHeaderTimeout: apidefaults.DefaultDialTimeout * 2,
ReadHeaderTimeout: apidefaults.DefaultIOTimeout * 2,
IdleTimeout: apidefaults.DefaultIdleTimeout,
TLSConfig: cfg.TLS,
},
}
Expand Down
4 changes: 2 additions & 2 deletions lib/reversetunnel/agent_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (d *agentDialer) DialContext(ctx context.Context, addr utils.NetAddr) (SSHC
for _, authMethod := range d.authMethods {
// Create a dialer (that respects HTTP proxies) and connect to remote host.
dialer := proxy.DialerFromEnvironment(addr.Addr, d.options...)
pconn, err := dialer.DialTimeout(ctx, addr.AddrNetwork, addr.Addr, apidefaults.DefaultDialTimeout)
pconn, err := dialer.DialTimeout(ctx, addr.AddrNetwork, addr.Addr, apidefaults.DefaultIOTimeout)
if err != nil {
d.log.WithError(err).Debugf("Failed to dial %s.", addr.Addr)
continue
Expand All @@ -75,7 +75,7 @@ func (d *agentDialer) DialContext(ctx context.Context, addr utils.NetAddr) (SSHC
User: d.username,
Auth: []ssh.AuthMethod{authMethod},
HostKeyCallback: callback,
Timeout: apidefaults.DefaultDialTimeout,
Timeout: apidefaults.DefaultIOTimeout,
})
if err != nil {
d.log.WithError(err).Debugf("Failed to create client to %v.", addr.Addr)
Expand Down
2 changes: 1 addition & 1 deletion lib/reversetunnel/agentpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ func (c *agentPoolRuntimeConfig) updateRemote(ctx context.Context, addr *utils.N

c.mu.RUnlock()

ctx, cancel := context.WithTimeout(ctx, defaults.DefaultDialTimeout)
ctx, cancel := context.WithTimeout(ctx, defaults.DefaultIOTimeout)
defer cancel()

tlsRoutingEnabled := false
Expand Down
4 changes: 2 additions & 2 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func (s *localSite) DialAuthServer() (net.Conn, error) {
}

addr := utils.ChooseRandomString(s.authServers)
conn, err := net.DialTimeout("tcp", addr, apidefaults.DefaultDialTimeout)
conn, err := net.DialTimeout("tcp", addr, apidefaults.DefaultIOTimeout)
if err != nil {
return nil, trace.ConnectionProblem(err, "unable to connect to auth server")
}
Expand Down Expand Up @@ -498,7 +498,7 @@ func (s *localSite) getConn(params DialParams) (conn net.Conn, useTunnel bool, e

// If no tunnel connection was found, dial to the target host.
dialer := proxyutils.DialerFromEnvironment(params.To.String())
conn, directErr = dialer.DialTimeout(s.srv.Context, params.To.Network(), params.To.String(), apidefaults.DefaultDialTimeout)
conn, directErr = dialer.DialTimeout(s.srv.Context, params.To.Network(), params.To.String(), apidefaults.DefaultIOTimeout)
if directErr != nil {
directMsg := getTunnelErrorMessage(params, "direct dial", directErr)
s.log.WithError(directErr).WithField("address", params.To.String()).Debug("Error occurred while dialing directly.")
Expand Down
4 changes: 2 additions & 2 deletions lib/reversetunnel/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func (p *transport) start() {
if req == nil {
return
}
case <-time.After(apidefaults.DefaultDialTimeout):
case <-time.After(apidefaults.DefaultIOTimeout):
p.log.Warnf("Transport request failed: timed out waiting for request.")
return
}
Expand Down Expand Up @@ -440,7 +440,7 @@ func (p *transport) directDial(addr string) (net.Conn, error) {
}

d := net.Dialer{
Timeout: apidefaults.DefaultDialTimeout,
Timeout: apidefaults.DefaultIOTimeout,
}
conn, err := d.DialContext(p.closeContext, "tcp", addr)
if err != nil {
Expand Down
10 changes: 7 additions & 3 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2666,6 +2666,7 @@ func (process *TeleportProcess) initMetricsService() error {
server := &http.Server{
Handler: mux,
ReadHeaderTimeout: defaults.ReadHeadersTimeout,
IdleTimeout: apidefaults.DefaultIdleTimeout,
ErrorLog: utils.NewStdlogger(log.Error, teleport.ComponentMetrics),
TLSConfig: tlsConfig,
}
Expand Down Expand Up @@ -2785,7 +2786,8 @@ func (process *TeleportProcess) initDiagnosticService() error {

server := &http.Server{
Handler: mux,
ReadHeaderTimeout: apidefaults.DefaultDialTimeout,
ReadHeaderTimeout: apidefaults.DefaultIOTimeout,
IdleTimeout: apidefaults.DefaultIdleTimeout,
ErrorLog: utils.NewStdlogger(log.Error, teleport.ComponentDiagnostic),
}

Expand Down Expand Up @@ -3627,7 +3629,8 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {

webServer = &http.Server{
Handler: httplib.MakeTracingHandler(proxyLimiter, teleport.ComponentProxy),
ReadHeaderTimeout: apidefaults.DefaultDialTimeout,
ReadHeaderTimeout: apidefaults.DefaultIOTimeout,
IdleTimeout: apidefaults.DefaultIdleTimeout,
ErrorLog: utils.NewStdlogger(log.Error, teleport.ComponentProxy),
}
process.RegisterCriticalFunc("proxy.web", func() error {
Expand Down Expand Up @@ -4124,7 +4127,8 @@ func (process *TeleportProcess) initMinimalReverseTunnel(listeners *proxyListene

minimalWebServer = &http.Server{
Handler: httplib.MakeTracingHandler(minimalProxyLimiter, teleport.ComponentProxy),
ReadHeaderTimeout: apidefaults.DefaultDialTimeout,
ReadHeaderTimeout: apidefaults.DefaultIOTimeout,
IdleTimeout: apidefaults.DefaultIdleTimeout,
ErrorLog: utils.NewStdlogger(log.Error, teleport.ComponentReverseTunnelServer),
}
process.RegisterCriticalFunc("proxy.reversetunnel.web", func() error {
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/alpnproxy/auth/auth_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (s *AuthProxyDialerService) dialLocalAuthServer(ctx context.Context) (net.C

addr := utils.ChooseRandomString(s.authServers)
d := &net.Dialer{
Timeout: defaults.DefaultDialTimeout,
Timeout: defaults.DefaultIOTimeout,
}
conn, err := d.DialContext(ctx, "tcp", addr)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/alpnproxy/conn_upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import (
// preserve the ALPN and SNI information.
func IsALPNConnUpgradeRequired(addr string, insecure bool) bool {
netDialer := &net.Dialer{
Timeout: defaults.DefaultDialTimeout,
Timeout: defaults.DefaultIOTimeout,
}
tlsConfig := &tls.Config{
NextProtos: []string{string(common.ProtocolReverseTunnel)},
Expand Down
3 changes: 2 additions & 1 deletion lib/srv/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,8 @@ func (s *Server) newHTTPServer() *http.Server {

return &http.Server{
Handler: httplib.MakeTracingHandler(s.authMiddleware, teleport.ComponentApp),
ReadHeaderTimeout: apidefaults.DefaultDialTimeout,
ReadHeaderTimeout: apidefaults.DefaultIOTimeout,
IdleTimeout: apidefaults.DefaultIdleTimeout,
ErrorLog: utils.NewStdlogger(s.log.Error, teleport.ComponentApp),
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
return context.WithValue(ctx, connContextKey, c)
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/app/tcpserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (s *tcpServer) handleConnection(ctx context.Context, clientConn net.Conn, i
return trace.BadParameter(`unexpected app %q address network, expected "tcp": %+v`, app.GetName(), addr)
}
dialer := net.Dialer{
Timeout: apidefaults.DefaultDialTimeout,
Timeout: apidefaults.DefaultIOTimeout,
}
serverConn, err := dialer.DialContext(ctx, addr.AddrNetwork, addr.String())
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ func (s *Server) newRemoteClient(ctx context.Context, systemLogin string) (*trac
authMethod,
},
HostKeyCallback: s.authHandlers.HostKeyAuth,
Timeout: apidefaults.DefaultDialTimeout,
Timeout: apidefaults.DefaultIOTimeout,
}

// Ciphers, KEX, and MACs preferences are honored by both the in-memory
Expand Down
2 changes: 1 addition & 1 deletion lib/tbot/identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ func (i *Identity) SSHClientConfig() (*ssh.ClientConfig, error) {
User: i.SSHCert.ValidPrincipals[0],
Auth: []ssh.AuthMethod{ssh.PublicKeys(i.KeySigner)},
HostKeyCallback: callback,
Timeout: apidefaults.DefaultDialTimeout,
Timeout: apidefaults.DefaultIOTimeout,
}, nil
}

Expand Down

0 comments on commit c6d6c67

Please sign in to comment.