Skip to content

Commit

Permalink
add unit testing and documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
mudhireddy committed May 14, 2024
1 parent 6653a37 commit ce997a6
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 1 deletion.
9 changes: 8 additions & 1 deletion security/advancedtls/advancedtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,11 @@ func (c *advancedTLSCreds) OverrideServerName(serverNameOverride string) error {
// 1. does not have a good support on root cert reloading.
// 2. will ignore basic certificate check when setting InsecureSkipVerify
// to true.
//
// peerVerifiedChains(output param): verified chain of certs from leaf to the
// trust cert that the peer trusts.
// 1. For server it is, client certs + Root ca that the server trusts
// 2. For client it is, server certs + Root ca that the client trusts
func buildVerifyFunc(c *advancedTLSCreds,
serverName string,
rawConn net.Conn,
Expand Down Expand Up @@ -637,7 +642,9 @@ func buildVerifyFunc(c *advancedTLSCreds,
VerifiedChains: chains,
Leaf: leafCert,
})
return err
if err != nil {
return err
}
}
*peerVerifiedChains = chains
return nil
Expand Down
71 changes: 71 additions & 0 deletions security/advancedtls/advancedtls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package advancedtls

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
Expand Down Expand Up @@ -896,6 +897,76 @@ func (s) TestClientServerHandshake(t *testing.T) {
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr,
clientAuthInfo, serverAuthInfo)
}
serverVerifiedChains := serverAuthInfo.(credentials.TLSInfo).State.VerifiedChains
if test.serverMutualTLS && !test.serverExpectError {
if len(serverVerifiedChains) == 0 {
t.Fatalf("server verified chains is empty")
}
var clientCert *tls.Certificate
if len(test.clientCert) > 0 {
clientCert = &test.clientCert[0]
} else if test.clientGetCert != nil {
cert, _ := test.clientGetCert(&tls.CertificateRequestInfo{})
clientCert = cert
} else if test.clientIdentityProvider != nil {
km, _ := test.clientIdentityProvider.KeyMaterial(nil)
clientCert = &km.Certs[0]
}
if !bytes.Equal((*serverVerifiedChains[0][0]).Raw, clientCert.Certificate[0]) {
t.Fatal("server verifiedChains leaf cert doesn't match client cert")
}

var serverRoot *x509.CertPool
if test.serverRoot != nil {
serverRoot = test.serverRoot
} else if test.serverGetRoot != nil {
result, _ := test.serverGetRoot(&GetRootCAsParams{})
serverRoot = result.TrustCerts
} else if test.serverRootProvider != nil {
km, _ := test.serverRootProvider.KeyMaterial(nil)
serverRoot = km.Roots
}
serverVerifiedChainsCp := x509.NewCertPool()
serverVerifiedChainsCp.AddCert(serverVerifiedChains[0][len(serverVerifiedChains[0])-1])
if !serverVerifiedChainsCp.Equal(serverRoot) {
t.Fatalf("server verified chain hierarchy doesn't match")
}
}
clientVerifiedChains := clientAuthInfo.(credentials.TLSInfo).State.VerifiedChains
if test.serverMutualTLS && !test.clientExpectHandshakeError {
if len(clientVerifiedChains) == 0 {
t.Fatalf("client verified chains is empty")
}
var serverCert *tls.Certificate
if len(test.serverCert) > 0 {
serverCert = &test.serverCert[0]
} else if test.serverGetCert != nil {
cert, _ := test.serverGetCert(&tls.ClientHelloInfo{})
serverCert = cert[0]
} else if test.serverIdentityProvider != nil {
km, _ := test.serverIdentityProvider.KeyMaterial(nil)
serverCert = &km.Certs[0]
}
if !bytes.Equal((*clientVerifiedChains[0][0]).Raw, serverCert.Certificate[0]) {
t.Fatal("client verifiedChains leaf cert doesn't match server cert")
}

var clientRoot *x509.CertPool
if test.clientRoot != nil {
clientRoot = test.clientRoot
} else if test.clientGetRoot != nil {
result, _ := test.clientGetRoot(&GetRootCAsParams{})
clientRoot = result.TrustCerts
} else if test.clientRootProvider != nil {
km, _ := test.clientRootProvider.KeyMaterial(nil)
clientRoot = km.Roots
}
clientVerifiedChainsCp := x509.NewCertPool()
clientVerifiedChainsCp.AddCert(clientVerifiedChains[0][len(clientVerifiedChains[0])-1])
if !clientVerifiedChainsCp.Equal(clientRoot) {
t.Fatalf("client verified chain hierarchy doesn't match")
}
}
})
}
}
Expand Down

0 comments on commit ce997a6

Please sign in to comment.