From 8728b186fa6874e93bbc051dc73d1596dd63d390 Mon Sep 17 00:00:00 2001 From: Matt Holt Date: Thu, 11 May 2023 12:36:44 -0600 Subject: [PATCH] Refactor Managers into on-demand config (#231) --- certificates.go | 3 +++ certmagic.go | 9 +++++++ config.go | 12 ---------- handshake.go | 62 +++++++++++++++++++++---------------------------- 4 files changed, 39 insertions(+), 47 deletions(-) diff --git a/certificates.go b/certificates.go index 5f9d0c89..9e983406 100644 --- a/certificates.go +++ b/certificates.go @@ -113,6 +113,9 @@ func (cert Certificate) HasTag(tag string) bool { // resolution of ASN.1 UTCTime/GeneralizedTime by including the extra fraction // of a second of certificate validity beyond the NotAfter value. func expiresAt(cert *x509.Certificate) time.Time { + if cert == nil { + return time.Time{} + } return cert.NotAfter.Truncate(time.Second).Add(1 * time.Second) } diff --git a/certmagic.go b/certmagic.go index 1c6471c4..8200431f 100644 --- a/certmagic.go +++ b/certmagic.go @@ -270,6 +270,15 @@ type OnDemandConfig struct { // request will be denied. DecisionFunc func(name string) error + // Sources for getting new, unmanaged certificates. + // They will be invoked only during TLS handshakes + // before on-demand certificate management occurs, + // for certificates that are not already loaded into + // the in-memory cache. + // + // TODO: EXPERIMENTAL: subject to change and/or removal. + Managers []Manager + // List of allowed hostnames (SNI values) for // deferred (on-demand) obtaining of certificates. // Used only by higher-level functions in this diff --git a/config.go b/config.go index a02f4335..ddafa87e 100644 --- a/config.go +++ b/config.go @@ -95,15 +95,6 @@ type Config struct { // turn until one succeeds. Issuers []Issuer - // Sources for getting new, unmanaged certificates. - // They will be invoked only during TLS handshakes - // before on-demand certificate management occurs, - // for certificates that are not already loaded into - // the in-memory cache. - // - // TODO: EXPERIMENTAL: subject to change and/or removal. - Managers []Manager - // The source of new private keys for certificates; // the default KeySource is StandardKeyGenerator. KeySource KeyGenerator @@ -234,9 +225,6 @@ func newWithCache(certCache *Cache, cfg Config) *Config { cfg.Issuers = []Issuer{NewACMEIssuer(&cfg, DefaultACME)} } } - if cfg.Managers == nil { - cfg.Managers = Default.Managers - } if cfg.RenewalWindowRatio == 0 { cfg.RenewalWindowRatio = Default.RenewalWindowRatio } diff --git a/handshake.go b/handshake.go index 65bae2ac..6d28ddfc 100644 --- a/handshake.go +++ b/handshake.go @@ -81,7 +81,7 @@ func (cfg *Config) GetCertificateWithContext(ctx context.Context, clientHello *t } // get the certificate and serve it up - cert, err := cfg.getCertDuringHandshake(ctx, clientHello, true, true) + cert, err := cfg.getCertDuringHandshake(ctx, clientHello, true) return &cert.Certificate, err } @@ -253,19 +253,19 @@ func DefaultCertificateSelector(hello *tls.ClientHelloInfo, choices []Certificat // An error will be returned if and only if no certificate is available. // // This function is safe for concurrent use. -func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.ClientHelloInfo, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { - log := logWithRemote(cfg.Logger.Named("handshake"), hello) +func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.ClientHelloInfo, loadOrObtainIfNecessary bool) (Certificate, error) { + logger := logWithRemote(cfg.Logger.Named("handshake"), hello) name := cfg.getNameFromClientHello(hello) // First check our in-memory cache to see if we've already loaded it cert, matched, defaulted := cfg.getCertificateFromCache(hello) if matched { - log.Debug("matched certificate in cache", + logger.Debug("matched certificate in cache", zap.Strings("subjects", cert.Names), zap.Bool("managed", cert.managed), zap.Time("expiration", expiresAt(cert.Leaf)), zap.String("hash", cert.hash)) - if cert.managed && cfg.OnDemand != nil && obtainIfNecessary { + if cert.managed && cfg.OnDemand != nil && loadOrObtainIfNecessary { // On-demand certificates are maintained in the background, but // maintenance is triggered by handshakes instead of by a timer // as in maintain.go. @@ -294,7 +294,7 @@ func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.Client timeout.Stop() } - return cfg.getCertDuringHandshake(ctx, hello, false, false) + return cfg.getCertDuringHandshake(ctx, hello, false) } else { // no other goroutine is currently trying to load this cert wait = make(chan struct{}) @@ -319,7 +319,7 @@ func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.Client // If an external Manager is configured, try to get it from them. // Only continue to use our own logic if it returns empty+nil. - externalCert, err := cfg.getCertFromAnyCertManager(ctx, hello, log) + externalCert, err := cfg.getCertFromAnyCertManager(ctx, hello, logger) if err != nil { return Certificate{}, err } @@ -345,24 +345,25 @@ func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.Client cacheAlmostFull := cacheCapacity > 0 && float64(cacheSize) >= cacheCapacity*.9 loadDynamically := cfg.OnDemand != nil || cacheAlmostFull - if loadDynamically && loadIfNecessary { + if loadDynamically && loadOrObtainIfNecessary { // Check to see if we have one on disk - loadedCert, err := cfg.loadCertFromStorage(ctx, log, hello) + loadedCert, err := cfg.loadCertFromStorage(ctx, logger, hello) if err == nil { return loadedCert, nil } - log.Debug("did not load cert from storage", + logger.Debug("did not load cert from storage", zap.String("server_name", hello.ServerName), zap.Error(err)) if cfg.OnDemand != nil { // By this point, we need to ask the CA for a certificate return cfg.obtainOnDemandCertificate(ctx, hello) } + return loadedCert, nil } // Fall back to another certificate if there is one (either DefaultServerName or FallbackServerName) if defaulted { - log.Debug("fell back to other certificate", + logger.Debug("fell back to default certificate", zap.Strings("subjects", cert.Names), zap.Bool("managed", cert.managed), zap.Time("expiration", expiresAt(cert.Leaf)), @@ -370,20 +371,19 @@ func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.Client return cert, nil } - log.Debug("no certificate matching TLS ClientHello", + logger.Debug("no certificate matching TLS ClientHello", zap.String("server_name", hello.ServerName), zap.String("remote", hello.Conn.RemoteAddr().String()), zap.String("identifier", name), zap.Uint16s("cipher_suites", hello.CipherSuites), zap.Float64("cert_cache_fill", float64(cacheSize)/cacheCapacity), // may be approximate! because we are not within the lock - zap.Bool("load_if_necessary", loadIfNecessary), - zap.Bool("obtain_if_necessary", obtainIfNecessary), + zap.Bool("load_or_obtain_if_necessary", loadOrObtainIfNecessary), zap.Bool("on_demand", cfg.OnDemand != nil)) return Certificate{}, fmt.Errorf("no certificate available for '%s'", name) } -func (cfg *Config) loadCertFromStorage(ctx context.Context, log *zap.Logger, hello *tls.ClientHelloInfo) (Certificate, error) { +func (cfg *Config) loadCertFromStorage(ctx context.Context, logger *zap.Logger, hello *tls.ClientHelloInfo) (Certificate, error) { name := normalizedName(hello.ServerName) loadedCert, err := cfg.CacheManagedCertificate(ctx, name) if errors.Is(err, fs.ErrNotExist) { @@ -395,14 +395,14 @@ func (cfg *Config) loadCertFromStorage(ctx context.Context, log *zap.Logger, hel if err != nil { return Certificate{}, fmt.Errorf("no matching certificate to load for %s: %w", name, err) } - log.Debug("loaded certificate from storage", + logger.Debug("loaded certificate from storage", zap.Strings("subjects", loadedCert.Names), zap.Bool("managed", loadedCert.managed), zap.Time("expiration", expiresAt(loadedCert.Leaf)), zap.String("hash", loadedCert.hash)) loadedCert, err = cfg.handshakeMaintenance(ctx, hello, loadedCert) if err != nil { - log.Error("maintaining newly-loaded certificate", + logger.Error("maintaining newly-loaded certificate", zap.String("server_name", name), zap.Error(err)) } @@ -465,10 +465,6 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli name := cfg.getNameFromClientHello(hello) - getCertWithoutReobtaining := func() (Certificate, error) { - return cfg.loadCertFromStorage(ctx, log, hello) - } - // We must protect this process from happening concurrently, so synchronize. obtainCertWaitChansMu.Lock() wait, ok := obtainCertWaitChans[name] @@ -486,7 +482,7 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli timeout.Stop() } - return getCertWithoutReobtaining() + return cfg.loadCertFromStorage(ctx, log, hello) } // looks like it's up to us to do all the work and obtain the cert. @@ -525,7 +521,7 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli // success; certificate was just placed on disk, so // we need only restart serving the certificate - return getCertWithoutReobtaining() + return cfg.loadCertFromStorage(ctx, log, hello) } // handshakeMaintenance performs a check on cert for expiration and OCSP validity. @@ -613,10 +609,6 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien timeLeft := time.Until(expiresAt(currentCert.Leaf)) revoked := currentCert.ocsp != nil && currentCert.ocsp.Status == ocsp.Revoked - getCertWithoutReobtaining := func() (Certificate, error) { - return cfg.loadCertFromStorage(ctx, log, hello) - } - // see if another goroutine is already working on this certificate obtainCertWaitChansMu.Lock() wait, ok := obtainCertWaitChans[name] @@ -651,7 +643,7 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien timeout.Stop() } - return getCertWithoutReobtaining() + return cfg.loadCertFromStorage(ctx, log, hello) } // looks like it's up to us to do all the work and renew the cert @@ -726,7 +718,7 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien return newCert, err } - return getCertWithoutReobtaining() + return cfg.loadCertFromStorage(ctx, log, hello) } // if the certificate hasn't expired, we can serve what we have and renew in the background @@ -744,20 +736,20 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien // getCertFromAnyCertManager gets a certificate from cfg's Managers. If there are no Managers defined, this is // a no-op that returns empty values. Otherwise, it gets a certificate for hello from the first Manager that // returns a certificate and no error. -func (cfg *Config) getCertFromAnyCertManager(ctx context.Context, hello *tls.ClientHelloInfo, log *zap.Logger) (Certificate, error) { +func (cfg *Config) getCertFromAnyCertManager(ctx context.Context, hello *tls.ClientHelloInfo, logger *zap.Logger) (Certificate, error) { // fast path if nothing to do - if len(cfg.Managers) == 0 { + if cfg.OnDemand == nil || len(cfg.OnDemand.Managers) == 0 { return Certificate{}, nil } var upstreamCert *tls.Certificate // try all the GetCertificate methods on external managers; use first one that returns a certificate - for i, certManager := range cfg.Managers { + for i, certManager := range cfg.OnDemand.Managers { var err error upstreamCert, err = certManager.GetCertificate(ctx, hello) if err != nil { - log.Error("getting certificate from external certificate manager", + logger.Error("getting certificate from external certificate manager", zap.String("sni", hello.ServerName), zap.Int("cert_manager", i), zap.Error(err)) @@ -768,7 +760,7 @@ func (cfg *Config) getCertFromAnyCertManager(ctx context.Context, hello *tls.Cli } } if upstreamCert == nil { - log.Debug("all external certificate managers yielded no certificates and no errors", zap.String("sni", hello.ServerName)) + logger.Debug("all external certificate managers yielded no certificates and no errors", zap.String("sni", hello.ServerName)) return Certificate{}, nil } @@ -778,7 +770,7 @@ func (cfg *Config) getCertFromAnyCertManager(ctx context.Context, hello *tls.Cli return Certificate{}, fmt.Errorf("external certificate manager: %s: filling cert from leaf: %v", hello.ServerName, err) } - log.Debug("using externally-managed certificate", + logger.Debug("using externally-managed certificate", zap.String("sni", hello.ServerName), zap.Strings("names", cert.Names), zap.Time("expiration", expiresAt(cert.Leaf)))