From ce997a614017252a3fd8cbd7cd5258093d798074 Mon Sep 17 00:00:00 2001 From: ramesh Date: Tue, 14 May 2024 13:58:51 -0700 Subject: [PATCH] add unit testing and documentation --- security/advancedtls/advancedtls.go | 9 ++- security/advancedtls/advancedtls_test.go | 71 ++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/security/advancedtls/advancedtls.go b/security/advancedtls/advancedtls.go index 91886be4a6e..2fcc975906e 100644 --- a/security/advancedtls/advancedtls.go +++ b/security/advancedtls/advancedtls.go @@ -556,6 +556,11 @@ func (c *advancedTLSCreds) OverrideServerName(serverNameOverride string) error { // 1. does not have a good support on root cert reloading. // 2. will ignore basic certificate check when setting InsecureSkipVerify // to true. +// +// peerVerifiedChains(output param): verified chain of certs from leaf to the +// trust cert that the peer trusts. +// 1. For server it is, client certs + Root ca that the server trusts +// 2. For client it is, server certs + Root ca that the client trusts func buildVerifyFunc(c *advancedTLSCreds, serverName string, rawConn net.Conn, @@ -637,7 +642,9 @@ func buildVerifyFunc(c *advancedTLSCreds, VerifiedChains: chains, Leaf: leafCert, }) - return err + if err != nil { + return err + } } *peerVerifiedChains = chains return nil diff --git a/security/advancedtls/advancedtls_test.go b/security/advancedtls/advancedtls_test.go index e77e0f5e981..39e6f422da8 100644 --- a/security/advancedtls/advancedtls_test.go +++ b/security/advancedtls/advancedtls_test.go @@ -19,6 +19,7 @@ package advancedtls import ( + "bytes" "context" "crypto/tls" "crypto/x509" @@ -896,6 +897,76 @@ func (s) TestClientServerHandshake(t *testing.T) { t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo) } + serverVerifiedChains := serverAuthInfo.(credentials.TLSInfo).State.VerifiedChains + if test.serverMutualTLS && !test.serverExpectError { + if len(serverVerifiedChains) == 0 { + t.Fatalf("server verified chains is empty") + } + var clientCert *tls.Certificate + if len(test.clientCert) > 0 { + clientCert = &test.clientCert[0] + } else if test.clientGetCert != nil { + cert, _ := test.clientGetCert(&tls.CertificateRequestInfo{}) + clientCert = cert + } else if test.clientIdentityProvider != nil { + km, _ := test.clientIdentityProvider.KeyMaterial(nil) + clientCert = &km.Certs[0] + } + if !bytes.Equal((*serverVerifiedChains[0][0]).Raw, clientCert.Certificate[0]) { + t.Fatal("server verifiedChains leaf cert doesn't match client cert") + } + + var serverRoot *x509.CertPool + if test.serverRoot != nil { + serverRoot = test.serverRoot + } else if test.serverGetRoot != nil { + result, _ := test.serverGetRoot(&GetRootCAsParams{}) + serverRoot = result.TrustCerts + } else if test.serverRootProvider != nil { + km, _ := test.serverRootProvider.KeyMaterial(nil) + serverRoot = km.Roots + } + serverVerifiedChainsCp := x509.NewCertPool() + serverVerifiedChainsCp.AddCert(serverVerifiedChains[0][len(serverVerifiedChains[0])-1]) + if !serverVerifiedChainsCp.Equal(serverRoot) { + t.Fatalf("server verified chain hierarchy doesn't match") + } + } + clientVerifiedChains := clientAuthInfo.(credentials.TLSInfo).State.VerifiedChains + if test.serverMutualTLS && !test.clientExpectHandshakeError { + if len(clientVerifiedChains) == 0 { + t.Fatalf("client verified chains is empty") + } + var serverCert *tls.Certificate + if len(test.serverCert) > 0 { + serverCert = &test.serverCert[0] + } else if test.serverGetCert != nil { + cert, _ := test.serverGetCert(&tls.ClientHelloInfo{}) + serverCert = cert[0] + } else if test.serverIdentityProvider != nil { + km, _ := test.serverIdentityProvider.KeyMaterial(nil) + serverCert = &km.Certs[0] + } + if !bytes.Equal((*clientVerifiedChains[0][0]).Raw, serverCert.Certificate[0]) { + t.Fatal("client verifiedChains leaf cert doesn't match server cert") + } + + var clientRoot *x509.CertPool + if test.clientRoot != nil { + clientRoot = test.clientRoot + } else if test.clientGetRoot != nil { + result, _ := test.clientGetRoot(&GetRootCAsParams{}) + clientRoot = result.TrustCerts + } else if test.clientRootProvider != nil { + km, _ := test.clientRootProvider.KeyMaterial(nil) + clientRoot = km.Roots + } + clientVerifiedChainsCp := x509.NewCertPool() + clientVerifiedChainsCp.AddCert(clientVerifiedChains[0][len(clientVerifiedChains[0])-1]) + if !clientVerifiedChainsCp.Equal(clientRoot) { + t.Fatalf("client verified chain hierarchy doesn't match") + } + } }) } }