From 02ab92d7588b2bada24c35b7074a806f57eee1b6 Mon Sep 17 00:00:00 2001 From: Andy Zhao Date: Fri, 23 Oct 2020 15:05:49 -0700 Subject: [PATCH 1/4] feat(transport): Add default certificate caching support Cache and return the cached certificate as long as it has not expired. This avoids having to exec the cert provider command multiple times in the same session, especially when using a dial pool. --- transport/cert/default_cert.go | 42 +++++++++++++++------ transport/cert/default_cert_test.go | 57 +++++++++++++++++++++++++++-- 2 files changed, 84 insertions(+), 15 deletions(-) diff --git a/transport/cert/default_cert.go b/transport/cert/default_cert.go index c03af65fd73..839350e3084 100644 --- a/transport/cert/default_cert.go +++ b/transport/cert/default_cert.go @@ -14,6 +14,7 @@ package cert import ( "crypto/tls" + "crypto/x509" "encoding/json" "errors" "fmt" @@ -23,6 +24,7 @@ import ( "os/user" "path/filepath" "sync" + "time" ) const ( @@ -31,9 +33,11 @@ const ( ) var ( - defaultSourceOnce sync.Once - defaultSource Source - defaultSourceErr error + defaultSourceOnce sync.Once + defaultSource Source + defaultSourceErr error + defaultSourceCachedCertMutex sync.Mutex + defaultSourceCachedCert *tls.Certificate ) // Source is a function that can be passed into crypto/tls.Config.GetClientCertificate. @@ -95,16 +99,30 @@ func validateMetadata(metadata secureConnectMetadata) error { } func (s *secureConnectSource) getClientCertificate(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - // TODO(cbro): consider caching valid certificates rather than exec'ing every time. - command := s.metadata.Cmd - data, err := exec.Command(command[0], command[1:]...).Output() - if err != nil { - // TODO(cbro): read stderr for error message? Might contain sensitive info. - return nil, err + defaultSourceCachedCertMutex.Lock() + defer defaultSourceCachedCertMutex.Unlock() + if defaultSourceCachedCert != nil && !isCertificateExpired(defaultSourceCachedCert) { + return defaultSourceCachedCert, nil + } else { + command := s.metadata.Cmd + data, err := exec.Command(command[0], command[1:]...).Output() + if err != nil { + // TODO(cbro): read stderr for error message? Might contain sensitive info. + return nil, err + } + cert, err := tls.X509KeyPair(data, data) + if err != nil { + return nil, err + } + defaultSourceCachedCert = &cert + return &cert, nil } - cert, err := tls.X509KeyPair(data, data) +} + +func isCertificateExpired(cert *tls.Certificate) bool { + parsed, err := x509.ParseCertificate(cert.Certificate[0]) if err != nil { - return nil, err + return true } - return &cert, nil + return time.Now().After(parsed.NotAfter) } diff --git a/transport/cert/default_cert_test.go b/transport/cert/default_cert_test.go index 0ec3c44b144..feee586a15d 100644 --- a/transport/cert/default_cert_test.go +++ b/transport/cert/default_cert_test.go @@ -5,31 +5,34 @@ package cert import ( + "bytes" "testing" ) func TestGetClientCertificateSuccess(t *testing.T) { + defaultSourceCachedCert = nil source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat", "testdata/testcert.pem"}}} cert, err := source.getClientCertificate(nil) if err != nil { t.Error(err) } if cert.Certificate == nil { - t.Error("want non-nil cert, got nil") + t.Error("getClientCertificate: want non-nil cert, got nil") } if cert.PrivateKey == nil { - t.Error("want non-nil PrivateKey, got nil") + t.Error("getClientCertificate: want non-nil PrivateKey, got nil") } } func TestGetClientCertificateFailure(t *testing.T) { + defaultSourceCachedCert = nil source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat"}}} _, err := source.getClientCertificate(nil) if err == nil { t.Error("Expecting error.") } if got, want := err.Error(), "tls: failed to find any PEM data in certificate input"; got != want { - t.Errorf("getClientCertificate, want %v err, got %v", want, got) + t.Errorf("getClientCertificate: want %v err, got %v", want, got) } } @@ -51,3 +54,51 @@ func TestValidateMetadataFailure(t *testing.T) { t.Errorf("validateMetadata: want %v err, got %v", want, got) } } + +func TestIsCertificateExpiredTrue(t *testing.T) { + defaultSourceCachedCert = nil + source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat", "testdata/testcert.pem"}}} + cert, err := source.getClientCertificate(nil) + if err != nil { + t.Error(err) + } + if !isCertificateExpired(cert) { + t.Error("isCertificateExpired: want true, got false") + } +} + +func TestIsCertificateExpiredFalse(t *testing.T) { + defaultSourceCachedCert = nil + source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat", "testdata/nonexpiringtestcert.pem"}}} + cert, err := source.getClientCertificate(nil) + if err != nil { + t.Error(err) + } + if isCertificateExpired(cert) { + t.Error("isCertificateExpired: want false, got true") + } +} + +func TestCertificateCaching(t *testing.T) { + defaultSourceCachedCert = nil + source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat", "testdata/nonexpiringtestcert.pem"}}} + cert, err := source.getClientCertificate(nil) + if err != nil { + t.Error(err) + } + if defaultSourceCachedCert == nil { + t.Error("getClientCertificate: want non-nil defaultSourceCachedCert, got nil") + } + + source = secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat", "testdata/testcert.pem"}}} + cert, err = source.getClientCertificate(nil) + if err != nil { + t.Error(err) + } + if bytes.Compare(cert.Certificate[0], defaultSourceCachedCert.Certificate[0]) != 0 { + t.Error("getClientCertificate: want cached Certificate, got different Certificate") + } + if cert.PrivateKey != defaultSourceCachedCert.PrivateKey { + t.Error("getClientCertificate: want cached PrivateKey, got different PrivateKey") + } +} From 5e5fdfd518679918bb9bb9d3bf44d6bbcf164b2e Mon Sep 17 00:00:00 2001 From: Andy Zhao Date: Tue, 27 Oct 2020 16:06:53 -0700 Subject: [PATCH 2/4] feat(transport): Fix formatting. Upload nonexpiringtestcert.pem --- transport/cert/default_cert.go | 25 +++++----- .../cert/testdata/nonexpiringtestcert.pem | 50 +++++++++++++++++++ 2 files changed, 62 insertions(+), 13 deletions(-) create mode 100644 transport/cert/testdata/nonexpiringtestcert.pem diff --git a/transport/cert/default_cert.go b/transport/cert/default_cert.go index 839350e3084..111c5e07590 100644 --- a/transport/cert/default_cert.go +++ b/transport/cert/default_cert.go @@ -103,20 +103,19 @@ func (s *secureConnectSource) getClientCertificate(info *tls.CertificateRequestI defer defaultSourceCachedCertMutex.Unlock() if defaultSourceCachedCert != nil && !isCertificateExpired(defaultSourceCachedCert) { return defaultSourceCachedCert, nil - } else { - command := s.metadata.Cmd - data, err := exec.Command(command[0], command[1:]...).Output() - if err != nil { - // TODO(cbro): read stderr for error message? Might contain sensitive info. - return nil, err - } - cert, err := tls.X509KeyPair(data, data) - if err != nil { - return nil, err - } - defaultSourceCachedCert = &cert - return &cert, nil } + command := s.metadata.Cmd + data, err := exec.Command(command[0], command[1:]...).Output() + if err != nil { + // TODO(cbro): read stderr for error message? Might contain sensitive info. + return nil, err + } + cert, err := tls.X509KeyPair(data, data) + if err != nil { + return nil, err + } + defaultSourceCachedCert = &cert + return &cert, nil } func isCertificateExpired(cert *tls.Certificate) bool { diff --git a/transport/cert/testdata/nonexpiringtestcert.pem b/transport/cert/testdata/nonexpiringtestcert.pem new file mode 100644 index 00000000000..43260a9c7ef --- /dev/null +++ b/transport/cert/testdata/nonexpiringtestcert.pem @@ -0,0 +1,50 @@ +-----BEGIN CERTIFICATE----- +MIIDujCCAqICCQD+yrCYuiC8djANBgkqhkiG9w0BAQsFADCBnTELMAkGA1UEBhMC +VVMxEzARBgNVBAgMCldhc2hpbmd0b24xETAPBgNVBAcMCEtpcmtsYW5kMQ8wDQYD +VQQKDAZHb29nbGUxDjAMBgNVBAsMBUNsb3VkMRswGQYDVQQDDBJnb29nbGVhcGlz +dGVzdC5jb20xKDAmBgkqhkiG9w0BCQEWGWdvb2dsZWFwaXN0ZXN0QGdvb2dsZS5j +b20wIBcNMjAxMDIzMjEyNTU1WhgPMjEyMDA5MjkyMTI1NTVaMIGdMQswCQYDVQQG +EwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjERMA8GA1UEBwwIS2lya2xhbmQxDzAN +BgNVBAoMBkdvb2dsZTEOMAwGA1UECwwFQ2xvdWQxGzAZBgNVBAMMEmdvb2dsZWFw +aXN0ZXN0LmNvbTEoMCYGCSqGSIb3DQEJARYZZ29vZ2xlYXBpc3Rlc3RAZ29vZ2xl +LmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKnzFX97VP4XSQ8l +4/Z08eajnAiGpK+ZQTV9k7Qy2tpo5+iFFiL0JLGP9+GRILuDGQufYlPLDhLLho9V +YXIR9UOhhapmQJqUAUFhvZlBEixLxcfwa2LecNiJ6+8gvJCoRbrPIrz91crY+t59 +aY/09vmsCbFDX8d8WWVnww4285dfKwE2IDinqZ1VuT4zYR66f4lL8qj6t5TXeGAW +Nkd6O3yuAVO8RLiXBRRABP5217mq0jNL+kJUormzhuKgvP+oxRsi56XHPGiq7l2e +54PS/cqa4atjqbhZI1xV27y0sVr0/CmBsfeM3TwLbCSjv7r0lCz64xtCJa8R45MA +22or9z8CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAnwLY9qBIQ2IYDLNLx16av8C6 +9vca8gOzMpYZ4UKHDN+Qk2CidpmFamXWDXqmOLNZYlmEoGY5n8zg8rwYK+vauqwb +o94HzxLmQcQ4kmAI4xJnMqKZAbukRdWw2GCuvdVqG4Osngz4WBIHrAsl4btogdJy +ACU/YUA3K0tLjwe6wUYYF6eu5sb6zJkF4cfLpqECWtF9XG6nkJbo2GomHFuHm+6t +gOj7YiqU/cHCyU4FQF9/2jDLzFHxt2Bb30zi602YjuIZhYp35ktI66XwsE4kFmwo +iHCEG0fXMNN7OMFmNg2YVLhaHxrQNFxbzOQdfKg2gi2qzX4AiCo1tx5LCg6aGw== +-----END CERTIFICATE----- +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCp8xV/e1T+F0kP +JeP2dPHmo5wIhqSvmUE1fZO0MtraaOfohRYi9CSxj/fhkSC7gxkLn2JTyw4Sy4aP +VWFyEfVDoYWqZkCalAFBYb2ZQRIsS8XH8Gti3nDYievvILyQqEW6zyK8/dXK2Pre +fWmP9Pb5rAmxQ1/HfFllZ8MONvOXXysBNiA4p6mdVbk+M2Eeun+JS/Ko+reU13hg +FjZHejt8rgFTvES4lwUUQAT+dte5qtIzS/pCVKK5s4bioLz/qMUbIuelxzxoqu5d +nueD0v3KmuGrY6m4WSNcVdu8tLFa9PwpgbH3jN08C2wko7+69JQs+uMbQiWvEeOT +ANtqK/c/AgMBAAECggEAYjeE3hb1yJ7Gb0WzmDR/tI4rV9YQiRcl03cOjJ6zUnQ8 +SmnXoD2+kwuj8y1/YD7kk436MnjwWjZbPqzWUylDuGE5sX/EqFEO5K1K+K3dhdII +rIMqXIo3Zz1WJ+2gbG2DVvHsnpKIIuIBIeISxsqIjUQ6mcJZMR2RQISV+roRTxIU +1Ga0xWrExcKL8FSjs8ih0DWU4vHoSYH4DFXB1/ViyLn+DEljnOlo8Q+7DG0uQQnX +ixfYMbXSJcZxFm1iwuZv8SESjqbTsogNny5Wi6H9Vp0JFasAPUjnc+QuD/U1HTDn +PCX3eBNMcxvVJDhu/7nnO7kcU1Cx0gJeN+1bklrAcQKBgQDURl0Ac8N94I82n4Lg +wjGLWj3AMxSEHNcZuomCvoYcLTmJdd2tOnunXhh1jANnx6q8P8aR5fiTthokIUdx +bOmWwFAbP6kMe0WFWQhXjX4mXLRmJ4mWayWCE7hstnDb3/Fr7LuJeg5L3OU4ss3b +j4UvhtuQ9Qh8piVhKwFkQh3tOQKBgQDM9NSkRDVW3Q37lMUdyn8B2FBF78e/9ck+ +5bHOs52G2hXJ4tyLYNjBoLXPpMp9VWRTXxUaii+gHSa4DkHTkFwIg34hLgrCX7Gc +a0rldvkpX0xWSANfvO9bvavPgKnLSP8j3mjDiwqJuy3L5TBThIHDvPV9F/akpLne +bdcywa4ANwKBgHlvAzcGAniZJPRXjfRrwxH3/slbr0nggcDLMG0l9uxZhse3MKgv +g5t8PbvI7A3LcEWeqka+a1R84Tl3/DnL11kRDQJ5iYiFYIDnLNmBLQBfGigySAhP +pTZjd6ZhO/DcjGx0EdiUhWcqp8qmpxMKaGOG30ZulntQRKPwiSxEkoApAoGBAJ1o +h4ulawXMfnmyt3T62XJ0TKp5zoKqZSYuSNIEdr5j7goAdvuApNiI8jmISY/arlOt +mcqpSIyC9wKyyHGQ1G4hdxRKhS7lScZlTL9REWlp7HnzksvLklV2JWcXXNBovrMw +lGth9PT00eZfni72fKb1D+FEL0Qh0zJ2T6mGwHkfAoGAMOy8bbyCASCYG9MYzqaP +Lf+AKKNEYUvUGspyJUqu5ERudr5stmei6PrchxFiKjm5Qg7B/M1VnKsCtL9kk8Z9 +lHgwU5mOATZvd9k/5oiuRxzXyrWqFoT/mivI2rZE+g5cLTLytCTnyLjHm5B/aTy8 +1AmbAh5hvWYs+EMKZAlQ5GM= +-----END PRIVATE KEY----- From 117fe426bb7f0e0a163f963fa334937dc5ddc777 Mon Sep 17 00:00:00 2001 From: Andy Zhao Date: Wed, 28 Oct 2020 10:36:29 -0700 Subject: [PATCH 3/4] feat(transport): Fix style for go 115 --- transport/cert/default_cert_test.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/transport/cert/default_cert_test.go b/transport/cert/default_cert_test.go index feee586a15d..6f11dcd5a61 100644 --- a/transport/cert/default_cert_test.go +++ b/transport/cert/default_cert_test.go @@ -17,7 +17,7 @@ func TestGetClientCertificateSuccess(t *testing.T) { t.Error(err) } if cert.Certificate == nil { - t.Error("getClientCertificate: want non-nil cert, got nil") + t.Error("getClientCertificate: want non-nil Certificate, got nil") } if cert.PrivateKey == nil { t.Error("getClientCertificate: want non-nil PrivateKey, got nil") @@ -86,6 +86,9 @@ func TestCertificateCaching(t *testing.T) { if err != nil { t.Error(err) } + if cert == nil { + t.Error("getClientCertificate: want non-nil cert, got nil") + } if defaultSourceCachedCert == nil { t.Error("getClientCertificate: want non-nil defaultSourceCachedCert, got nil") } @@ -95,7 +98,7 @@ func TestCertificateCaching(t *testing.T) { if err != nil { t.Error(err) } - if bytes.Compare(cert.Certificate[0], defaultSourceCachedCert.Certificate[0]) != 0 { + if !bytes.Equal(cert.Certificate[0], defaultSourceCachedCert.Certificate[0]) { t.Error("getClientCertificate: want cached Certificate, got different Certificate") } if cert.PrivateKey != defaultSourceCachedCert.PrivateKey { From 63bb6d8cb59aa96f2b4e1816dd942bf76c91670d Mon Sep 17 00:00:00 2001 From: Andy Zhao Date: Mon, 23 Nov 2020 11:30:27 -0800 Subject: [PATCH 4/4] feat(transport): added defaultCertData struct --- transport/cert/default_cert.go | 36 ++++++++++++++++++----------- transport/cert/default_cert_test.go | 16 ++++++------- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/transport/cert/default_cert.go b/transport/cert/default_cert.go index 111c5e07590..141ae457936 100644 --- a/transport/cert/default_cert.go +++ b/transport/cert/default_cert.go @@ -32,12 +32,18 @@ const ( metadataFile = "context_aware_metadata.json" ) +// defaultCertData holds all the variables pertaining to +// the default certficate source created by DefaultSource. +type defaultCertData struct { + once sync.Once + source Source + err error + cachedCertMutex sync.Mutex + cachedCert *tls.Certificate +} + var ( - defaultSourceOnce sync.Once - defaultSource Source - defaultSourceErr error - defaultSourceCachedCertMutex sync.Mutex - defaultSourceCachedCert *tls.Certificate + defaultCert defaultCertData ) // Source is a function that can be passed into crypto/tls.Config.GetClientCertificate. @@ -48,10 +54,10 @@ type Source func(*tls.CertificateRequestInfo) (*tls.Certificate, error) // // If that file does not exist, a nil source is returned. func DefaultSource() (Source, error) { - defaultSourceOnce.Do(func() { - defaultSource, defaultSourceErr = newSecureConnectSource() + defaultCert.once.Do(func() { + defaultCert.source, defaultCert.err = newSecureConnectSource() }) - return defaultSource, defaultSourceErr + return defaultCert.source, defaultCert.err } type secureConnectSource struct { @@ -99,10 +105,10 @@ func validateMetadata(metadata secureConnectMetadata) error { } func (s *secureConnectSource) getClientCertificate(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - defaultSourceCachedCertMutex.Lock() - defer defaultSourceCachedCertMutex.Unlock() - if defaultSourceCachedCert != nil && !isCertificateExpired(defaultSourceCachedCert) { - return defaultSourceCachedCert, nil + defaultCert.cachedCertMutex.Lock() + defer defaultCert.cachedCertMutex.Unlock() + if defaultCert.cachedCert != nil && !isCertificateExpired(defaultCert.cachedCert) { + return defaultCert.cachedCert, nil } command := s.metadata.Cmd data, err := exec.Command(command[0], command[1:]...).Output() @@ -114,11 +120,15 @@ func (s *secureConnectSource) getClientCertificate(info *tls.CertificateRequestI if err != nil { return nil, err } - defaultSourceCachedCert = &cert + defaultCert.cachedCert = &cert return &cert, nil } +// isCertificateExpired returns true if the given cert is expired or invalid. func isCertificateExpired(cert *tls.Certificate) bool { + if len(cert.Certificate) == 0 { + return true + } parsed, err := x509.ParseCertificate(cert.Certificate[0]) if err != nil { return true diff --git a/transport/cert/default_cert_test.go b/transport/cert/default_cert_test.go index 6f11dcd5a61..2d7e333f332 100644 --- a/transport/cert/default_cert_test.go +++ b/transport/cert/default_cert_test.go @@ -10,7 +10,7 @@ import ( ) func TestGetClientCertificateSuccess(t *testing.T) { - defaultSourceCachedCert = nil + defaultCert.cachedCert = nil source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat", "testdata/testcert.pem"}}} cert, err := source.getClientCertificate(nil) if err != nil { @@ -25,7 +25,7 @@ func TestGetClientCertificateSuccess(t *testing.T) { } func TestGetClientCertificateFailure(t *testing.T) { - defaultSourceCachedCert = nil + defaultCert.cachedCert = nil source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat"}}} _, err := source.getClientCertificate(nil) if err == nil { @@ -56,7 +56,7 @@ func TestValidateMetadataFailure(t *testing.T) { } func TestIsCertificateExpiredTrue(t *testing.T) { - defaultSourceCachedCert = nil + defaultCert.cachedCert = nil source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat", "testdata/testcert.pem"}}} cert, err := source.getClientCertificate(nil) if err != nil { @@ -68,7 +68,7 @@ func TestIsCertificateExpiredTrue(t *testing.T) { } func TestIsCertificateExpiredFalse(t *testing.T) { - defaultSourceCachedCert = nil + defaultCert.cachedCert = nil source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat", "testdata/nonexpiringtestcert.pem"}}} cert, err := source.getClientCertificate(nil) if err != nil { @@ -80,7 +80,7 @@ func TestIsCertificateExpiredFalse(t *testing.T) { } func TestCertificateCaching(t *testing.T) { - defaultSourceCachedCert = nil + defaultCert.cachedCert = nil source := secureConnectSource{metadata: secureConnectMetadata{Cmd: []string{"cat", "testdata/nonexpiringtestcert.pem"}}} cert, err := source.getClientCertificate(nil) if err != nil { @@ -89,7 +89,7 @@ func TestCertificateCaching(t *testing.T) { if cert == nil { t.Error("getClientCertificate: want non-nil cert, got nil") } - if defaultSourceCachedCert == nil { + if defaultCert.cachedCert == nil { t.Error("getClientCertificate: want non-nil defaultSourceCachedCert, got nil") } @@ -98,10 +98,10 @@ func TestCertificateCaching(t *testing.T) { if err != nil { t.Error(err) } - if !bytes.Equal(cert.Certificate[0], defaultSourceCachedCert.Certificate[0]) { + if !bytes.Equal(cert.Certificate[0], defaultCert.cachedCert.Certificate[0]) { t.Error("getClientCertificate: want cached Certificate, got different Certificate") } - if cert.PrivateKey != defaultSourceCachedCert.PrivateKey { + if cert.PrivateKey != defaultCert.cachedCert.PrivateKey { t.Error("getClientCertificate: want cached PrivateKey, got different PrivateKey") } }