Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a timeout for TLS handshake #176

Merged
merged 1 commit into from Sep 5, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
101 changes: 62 additions & 39 deletions rpc/connection.go
Expand Up @@ -197,14 +197,15 @@ type ConnectionTransportTLS struct {
dialable Dialable

// Protects everything below.
mutex sync.Mutex
transport Transporter
stagedTransport Transporter
conn net.Conn
dialerTimeout time.Duration
logFactory LogFactory
wef WrapErrorFunc
log ConnectionLog
mutex sync.Mutex
transport Transporter
stagedTransport Transporter
conn net.Conn
dialerTimeout time.Duration
handshakeTimeout time.Duration
logFactory LogFactory
wef WrapErrorFunc
log ConnectionLog
}

// Test that ConnectionTransportTLS fully implements the ConnectionTransport interface.
Expand Down Expand Up @@ -274,8 +275,23 @@ func (ct *ConnectionTransportTLS) Dial(ctx context.Context) (
LogField{Key: "local-addr", Value: baseConn.LocalAddr()},
LogField{Key: ConnectionLogMsgKey, Value: "Handshake"})
conn := tls.Client(baseConn, config)
if err := conn.Handshake(); err != nil {
return nil, err

// run TLS handshake with a timeout
errCh := make(chan error, 1)
go func() {
errCh <- conn.Handshake()
}()
handshakeTimeout := ct.handshakeTimeout
if handshakeTimeout == 0 {
handshakeTimeout = time.Minute
}
select {
case err := <-errCh:
if err != nil {
return nil, err
}
case <-time.After(handshakeTimeout):
return nil, errors.New("handshake timeout")
}
ct.log.Debug("%s", LogField{Key: ConnectionLogMsgKey, Value: "Handshaken"})

Expand Down Expand Up @@ -391,6 +407,9 @@ type ConnectionOpts struct {
// connections. Zero value is passed as-is to net.Dialer, which means no
// timeout. Note that OS may impose its own timeout.
DialerTimeout time.Duration
// HandshakeTimeout is a timeout on how long we wait for TLS handshake to
// complete. If no value specified, we default to time.Minute.
HandshakeTimeout time.Duration
}

// NewTLSConnectionWithConnectionLogFactory is like NewTLSConnection,
Expand All @@ -406,13 +425,14 @@ func NewTLSConnectionWithConnectionLogFactory(
opts ConnectionOpts,
) *Connection {
transport := &ConnectionTransportTLS{
rootCerts: rootCerts,
srvRemote: srvRemote,
maxFrameLength: maxFrameLength,
logFactory: logFactory,
wef: opts.WrapErrorFunc,
dialerTimeout: opts.DialerTimeout,
log: connectionLogFactory.Make("conn_tspt"),
rootCerts: rootCerts,
srvRemote: srvRemote,
maxFrameLength: maxFrameLength,
logFactory: logFactory,
wef: opts.WrapErrorFunc,
dialerTimeout: opts.DialerTimeout,
handshakeTimeout: opts.HandshakeTimeout,
log: connectionLogFactory.Make("conn_tspt"),
}
connLog := connectionLogFactory.Make("conn")
return newConnectionWithTransportAndProtocolsWithLog(
Expand All @@ -432,13 +452,14 @@ func NewTLSConnection(
opts ConnectionOpts,
) *Connection {
transport := &ConnectionTransportTLS{
rootCerts: rootCerts,
srvRemote: srvRemote,
maxFrameLength: maxFrameLength,
logFactory: logFactory,
wef: opts.WrapErrorFunc,
dialerTimeout: opts.DialerTimeout,
log: newConnectionLogUnstructured(logOutput, "CONNTSPT"),
rootCerts: rootCerts,
srvRemote: srvRemote,
maxFrameLength: maxFrameLength,
logFactory: logFactory,
wef: opts.WrapErrorFunc,
dialerTimeout: opts.DialerTimeout,
handshakeTimeout: opts.HandshakeTimeout,
log: newConnectionLogUnstructured(logOutput, "CONNTSPT"),
}
return newConnectionWithTransportAndProtocols(handler, transport, errorUnwrapper, logOutput, opts)
}
Expand All @@ -456,13 +477,14 @@ func NewTLSConnectionWithTLSConfig(
opts ConnectionOpts,
) *Connection {
transport := &ConnectionTransportTLS{
srvRemote: srvRemote,
tlsConfig: copyTLSConfig(tlsConfig),
maxFrameLength: maxFrameLength,
logFactory: logFactory,
wef: opts.WrapErrorFunc,
dialerTimeout: opts.DialerTimeout,
log: newConnectionLogUnstructured(logOutput, "CONNTSPT"),
srvRemote: srvRemote,
tlsConfig: copyTLSConfig(tlsConfig),
maxFrameLength: maxFrameLength,
logFactory: logFactory,
wef: opts.WrapErrorFunc,
dialerTimeout: opts.DialerTimeout,
handshakeTimeout: opts.HandshakeTimeout,
log: newConnectionLogUnstructured(logOutput, "CONNTSPT"),
}
return newConnectionWithTransportAndProtocols(handler, transport, errorUnwrapper, logOutput, opts)
}
Expand All @@ -481,14 +503,15 @@ func NewTLSConnectionWithDialable(
dialable Dialable,
) *Connection {
transport := &ConnectionTransportTLS{
rootCerts: rootCerts,
srvRemote: srvRemote,
maxFrameLength: maxFrameLength,
logFactory: logFactory,
wef: opts.WrapErrorFunc,
dialerTimeout: opts.DialerTimeout,
log: newConnectionLogUnstructured(logOutput, "CONNTSPT"),
dialable: dialable,
rootCerts: rootCerts,
srvRemote: srvRemote,
maxFrameLength: maxFrameLength,
logFactory: logFactory,
wef: opts.WrapErrorFunc,
dialerTimeout: opts.DialerTimeout,
handshakeTimeout: opts.HandshakeTimeout,
log: newConnectionLogUnstructured(logOutput, "CONNTSPT"),
dialable: dialable,
}
return newConnectionWithTransportAndProtocols(handler, transport, errorUnwrapper, logOutput, opts)
}
Expand Down