From aeedfbf5a7fdd1c7a70a952ae628eac6dbb47f14 Mon Sep 17 00:00:00 2001 From: Vince Prignano Date: Thu, 4 May 2023 14:25:32 -0700 Subject: [PATCH] Add certwatcher callback Signed-off-by: Vince Prignano --- pkg/certwatcher/certwatcher.go | 22 ++++++++++++++++++++++ pkg/certwatcher/certwatcher_test.go | 8 ++++++++ 2 files changed, 30 insertions(+) diff --git a/pkg/certwatcher/certwatcher.go b/pkg/certwatcher/certwatcher.go index 515a13bcb4..2b9b60d8d7 100644 --- a/pkg/certwatcher/certwatcher.go +++ b/pkg/certwatcher/certwatcher.go @@ -44,6 +44,9 @@ type CertWatcher struct { certPath string keyPath string + + // callback is a function to be invoked when the certificate changes. + callback func(tls.Certificate) } // New returns a new CertWatcher watching the given certificate and key. @@ -68,6 +71,17 @@ func New(certPath, keyPath string) (*CertWatcher, error) { return cw, nil } +// RegisterCallback registers a callback to be invoked when the certificate changes. +func (cw *CertWatcher) RegisterCallback(callback func(tls.Certificate)) { + cw.Lock() + defer cw.Unlock() + // If the current certificate is not nil, invoke the callback immediately. + if cw.currentCert != nil { + callback(*cw.currentCert) + } + cw.callback = callback +} + // GetCertificate fetches the currently loaded certificate, which may be nil. func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { cw.RLock() @@ -146,6 +160,14 @@ func (cw *CertWatcher) ReadCertificate() error { log.Info("Updated current TLS certificate") + // If a callback is registered, invoke it with the new certificate. + cw.RLock() + defer cw.RUnlock() + if cw.callback != nil { + go func() { + cw.callback(cert) + }() + } return nil } diff --git a/pkg/certwatcher/certwatcher_test.go b/pkg/certwatcher/certwatcher_test.go index c7349ea80d..7eef9d8b0e 100644 --- a/pkg/certwatcher/certwatcher_test.go +++ b/pkg/certwatcher/certwatcher_test.go @@ -20,6 +20,7 @@ import ( "context" "crypto/rand" "crypto/rsa" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" @@ -27,6 +28,7 @@ import ( "math/big" "net" "os" + "sync/atomic" "time" . "github.com/onsi/ginkgo/v2" @@ -97,6 +99,11 @@ var _ = Describe("CertWatcher", func() { It("should reload currentCert when changed", func() { doneCh := startWatcher() + called := atomic.Int64{} + watcher.RegisterCallback(func(crt tls.Certificate) { + called.Add(1) + Expect(crt.Certificate).ToNot(BeEmpty()) + }) firstcert, _ := watcher.GetCertificate(nil) @@ -111,6 +118,7 @@ var _ = Describe("CertWatcher", func() { ctxCancel() Eventually(doneCh, "4s").Should(BeClosed()) + Expect(called.Load()).To(BeNumerically(">=", 1)) }) Context("prometheus metric read_certificate_total", func() {