Skip to content

Commit

Permalink
crypto/tls: set CipherSuite for VerifyConnection
Browse files Browse the repository at this point in the history
The ConnectionState's CipherSuite was not set prior
to the VerifyConnection callback in TLS 1.2 servers,
both for full handshakes and resumptions.

Change-Id: Iab91783eff84d1b42ca09c8df08e07861e18da30
Reviewed-on: https://go-review.googlesource.com/c/go/+/236558
Run-TryBot: Katie Hockman <katie@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
  • Loading branch information
katiehockman committed Jun 4, 2020
1 parent 07ced37 commit fb86c70
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 15 deletions.
31 changes: 17 additions & 14 deletions src/crypto/tls/handshake_client_test.go
Expand Up @@ -1470,25 +1470,28 @@ func TestVerifyConnection(t *testing.T) {
}

func testVerifyConnection(t *testing.T, version uint16) {
checkFields := func(c ConnectionState, called *int) error {
checkFields := func(c ConnectionState, called *int, errorType string) error {
if c.Version != version {
return fmt.Errorf("got Version %v, want %v", c.Version, version)
return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version)
}
if c.HandshakeComplete {
return fmt.Errorf("got HandshakeComplete, want false")
return fmt.Errorf("%s: got HandshakeComplete, want false", errorType)
}
if c.ServerName != "example.golang" {
return fmt.Errorf("got ServerName %s, want %s", c.ServerName, "example.golang")
return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang")
}
if c.NegotiatedProtocol != "protocol1" {
return fmt.Errorf("got NegotiatedProtocol %s, want %s", c.NegotiatedProtocol, "protocol1")
return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1")
}
if c.CipherSuite == 0 {
return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType)
}
wantDidResume := false
if *called == 2 { // if this is the second time, then it should be a resumption
wantDidResume = true
}
if c.DidResume != wantDidResume {
return fmt.Errorf("got DidResume %t, want %t", c.DidResume, wantDidResume)
return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume)
}
return nil
}
Expand All @@ -1510,7 +1513,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
if len(c.VerifiedChains) == 0 {
return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
}
return checkFields(c, called)
return checkFields(c, called, "server")
}
},
configureClient: func(config *Config, called *int) {
Expand All @@ -1533,7 +1536,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
if len(c.SignedCertificateTimestamps) == 0 {
return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
}
return checkFields(c, called)
return checkFields(c, called, "client")
}
},
},
Expand All @@ -1550,7 +1553,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
if c.VerifiedChains != nil {
return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
}
return checkFields(c, called)
return checkFields(c, called, "server")
}
},
configureClient: func(config *Config, called *int) {
Expand All @@ -1574,7 +1577,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
if len(c.SignedCertificateTimestamps) == 0 {
return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
}
return checkFields(c, called)
return checkFields(c, called, "client")
}
},
},
Expand All @@ -1584,13 +1587,13 @@ func testVerifyConnection(t *testing.T, version uint16) {
config.ClientAuth = NoClientCert
config.VerifyConnection = func(c ConnectionState) error {
*called++
return checkFields(c, called)
return checkFields(c, called, "server")
}
},
configureClient: func(config *Config, called *int) {
config.VerifyConnection = func(c ConnectionState) error {
*called++
return checkFields(c, called)
return checkFields(c, called, "client")
}
},
},
Expand All @@ -1600,7 +1603,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
config.ClientAuth = RequestClientCert
config.VerifyConnection = func(c ConnectionState) error {
*called++
return checkFields(c, called)
return checkFields(c, called, "server")
}
},
configureClient: func(config *Config, called *int) {
Expand All @@ -1624,7 +1627,7 @@ func testVerifyConnection(t *testing.T, version uint16) {
if len(c.SignedCertificateTimestamps) == 0 {
return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
}
return checkFields(c, called)
return checkFields(c, called, "client")
}
},
},
Expand Down
3 changes: 2 additions & 1 deletion src/crypto/tls/handshake_server.go
Expand Up @@ -308,6 +308,7 @@ func (hs *serverHandshakeState) pickCipherSuite() error {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: no cipher suite supported by both client and server")
}
c.cipherSuite = hs.suite.id

for _, id := range hs.clientHello.cipherSuites {
if id == TLS_FALLBACK_SCSV {
Expand Down Expand Up @@ -407,6 +408,7 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
c := hs.c

hs.hello.cipherSuite = hs.suite.id
c.cipherSuite = hs.suite.id
// We echo the client's session ID in the ServerHello to let it know
// that we're doing a resumption.
hs.hello.sessionId = hs.clientHello.sessionId
Expand Down Expand Up @@ -743,7 +745,6 @@ func (hs *serverHandshakeState) sendFinished(out []byte) error {
return err
}

c.cipherSuite = hs.suite.id
copy(out, finished.verifyData)

return nil
Expand Down

0 comments on commit fb86c70

Please sign in to comment.