diff --git a/tlsdialer.go b/tlsdialer.go index c359a70..6f72e67 100644 --- a/tlsdialer.go +++ b/tlsdialer.go @@ -78,17 +78,21 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, sendServerName boo serverName = hostname } - c := new(tls.Config) - *c = *config + // copy config so we can tweak it + configCopy := new(tls.Config) + *configCopy = *config if sendServerName { - config.ServerName = serverName + // Set the ServerName and rely on the usual logic in + // tls.Conn.Handshake() to do its verification + configCopy.ServerName = serverName } else { - // Don't verify, we'll verify manually after handshaking - config.InsecureSkipVerify = true + // Disable verification in tls.Conn.Handshake(). We'll verify manually + // after handshaking + configCopy.InsecureSkipVerify = true } - conn := tls.Client(rawConn, config) + conn := tls.Client(rawConn, configCopy) if timeout == 0 { err = conn.Handshake() @@ -96,13 +100,12 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, sendServerName boo go func() { errChannel <- conn.Handshake() }() - err = <-errChannel } - if !sendServerName && err == nil && !c.InsecureSkipVerify { + if !sendServerName && err == nil && !config.InsecureSkipVerify { // Manually verify certificates - err = verifyServerCerts(conn, serverName, config) + err = verifyServerCerts(conn, serverName, configCopy) } if err != nil { rawConn.Close() diff --git a/tlsdialer_test.go b/tlsdialer_test.go index 48a0acf..bba1bdb 100644 --- a/tlsdialer_test.go +++ b/tlsdialer_test.go @@ -76,8 +76,25 @@ func TestOKWithServerName(t *testing.T) { } func TestOKWithoutServerName(t *testing.T) { - _, err := Dial("tcp", ADDR, false, &tls.Config{ + config := &tls.Config{ RootCAs: cert.PoolContainingCert(), + } + _, err := Dial("tcp", ADDR, false, config) + if err != nil { + t.Errorf("Unable to dial: %s", err.Error()) + } + serverName := <-receivedServerNames + if serverName != "" { + t.Errorf("Unexpected ServerName on server: %s", serverName) + } + if config.InsecureSkipVerify { + t.Errorf("Original config shouldn't have been modified, but it was") + } +} + +func TestOKWithInsecureSkipVerify(t *testing.T) { + _, err := Dial("tcp", ADDR, false, &tls.Config{ + InsecureSkipVerify: true, }) if err != nil { t.Errorf("Unable to dial: %s", err.Error())