Skip to content

Commit

Permalink
server: Refactor TLS config loading
Browse files Browse the repository at this point in the history
Changes to improve readability.

Signed-off-by: Charlie Egan <charlie@styra.com>
  • Loading branch information
charlieegan3 committed Nov 21, 2023
1 parent a2fbe36 commit a621e2b
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 40 deletions.
10 changes: 8 additions & 2 deletions runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -564,9 +564,15 @@ func (rt *Runtime) Serve(ctx context.Context) error {
rt.server = rt.server.WithUnixSocketPermission(rt.Params.UnixSocketPerm)
}

// If a refresh period is set, then we will periodically reload the certificate and ca pool. Otherwise, we will only
// reload cert, key and ca pool files when they change on disk.
if rt.Params.CertificateRefresh > 0 {
rt.server = rt.server.WithCertificatePaths(rt.Params.CertificateFile, rt.Params.CertificateKeyFile, rt.Params.CertificateRefresh)
} else if rt.Params.Certificate != nil {
rt.server = rt.server.WithCertRefresh(rt.Params.CertificateRefresh)
}

// if either the cert or the ca pool file is set then these fields will be set on the server and reloaded when they
// change on disk.
if rt.Params.CertificateFile != "" || rt.Params.CertPoolFile != "" {
rt.server = rt.server.WithTLSConfig(&server.TLSConfig{
CertFile: rt.Params.CertificateFile,
KeyFile: rt.Params.CertificateKeyFile,
Expand Down
88 changes: 51 additions & 37 deletions server/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,54 +27,56 @@ func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error
}

func (s *Server) reloadTLSConfig(logger logging.Logger) error {
certHash, err := hash(s.certFile)
if err != nil {
return fmt.Errorf("failed to refresh server certificate: %w", err)
}
certKeyHash, err := hash(s.certKeyFile)
if err != nil {
return fmt.Errorf("failed to refresh server key: %w", err)
}

s.tlsConfigMtx.Lock()
defer s.tlsConfigMtx.Unlock()

different := !bytes.Equal(s.certFileHash, certHash) ||
!bytes.Equal(s.certKeyFileHash, certKeyHash)
// if the server has a cert configured, then we need to check the cert and key for changes.
if s.certFile != "" {
certHash, err := hash(s.certFile)
if err != nil {
return fmt.Errorf("failed to check server certificate: %w", err)
}

if different { // load and store
newCert, err := tls.LoadX509KeyPair(s.certFile, s.certKeyFile)
certKeyHash, err := hash(s.certKeyFile)
if err != nil {
return fmt.Errorf("failed to refresh server certificate: %w", err)
return fmt.Errorf("failed to check server key: %w", err)
}
s.cert = &newCert
s.certFileHash = certHash
s.certKeyFileHash = certKeyHash
logger.Debug("Refreshed server certificate.")
}

// do not attempt to reload the ca cert pool if it has not been configured
if s.certPoolFile == "" {
return nil
}
different := !bytes.Equal(s.certFileHash, certHash) ||
!bytes.Equal(s.certKeyFileHash, certKeyHash)

certPoolHash, err := hash(s.certPoolFile)
if err != nil {
return fmt.Errorf("failed to refresh CA cert pool: %w", err)
if different { // load and store
newCert, err := tls.LoadX509KeyPair(s.certFile, s.certKeyFile)
if err != nil {
return fmt.Errorf("failed to refresh server certificate: %w", err)
}
s.cert = &newCert
s.certFileHash = certHash
s.certKeyFileHash = certKeyHash
logger.Debug("Refreshed server certificate.")
}
}

if !bytes.Equal(s.certPoolFileHash, certPoolHash) {
caCertPEM, err := os.ReadFile(s.certPoolFile)
// do not attempt to reload the ca cert pool if it has not been configured.
if s.certPoolFile != "" {
certPoolHash, err := hash(s.certPoolFile)
if err != nil {
return fmt.Errorf("failed to read CA cert pool file: %w", err)
return fmt.Errorf("failed to refresh CA cert pool: %w", err)
}

pool := x509.NewCertPool()
if ok := pool.AppendCertsFromPEM(caCertPEM); !ok {
return fmt.Errorf("failed to parse CA cert pool file %q", s.certPoolFile)
}
if !bytes.Equal(s.certPoolFileHash, certPoolHash) {
caCertPEM, err := os.ReadFile(s.certPoolFile)
if err != nil {
return fmt.Errorf("failed to read CA cert pool file: %w", err)
}

s.certPool = pool
pool := x509.NewCertPool()
if ok := pool.AppendCertsFromPEM(caCertPEM); !ok {
return fmt.Errorf("failed to parse CA cert pool file %q", s.certPoolFile)
}

s.certPool = pool
}
}

return nil
Expand All @@ -95,9 +97,21 @@ func (s *Server) certLoopPolling(logger logging.Logger) Loop {

func (s *Server) certLoopNotify(logger logging.Logger) Loop {
return func() error {
watcher, err := pathwatcher.CreatePathWatcher([]string{
s.certFile, s.certKeyFile, s.certPoolFile,
})

var paths []string

// if a cert file is set, then we want to watch the cert and key
if s.certFile != "" {
paths = append(paths, s.certFile, s.certKeyFile)
}

// if a cert pool file is set, then we want to watch the cert pool. This might be set without the cert and key
// being set too.
if s.certPoolFile != "" {
paths = append(paths, s.certPoolFile)
}

watcher, err := pathwatcher.CreatePathWatcher(paths)
if err != nil {
return fmt.Errorf("failed to create tls path watcher: %w", err)
}
Expand Down
8 changes: 7 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,12 @@ func (s *Server) WithTLSConfig(tlsConfig *TLSConfig) *Server {
return s
}

// WithCertRefresh sets the period on which certs, keys and cert pools are reloaded from disk.
func (s *Server) WithCertRefresh(refresh time.Duration) *Server {
s.certRefresh = refresh
return s
}

// WithStore sets the storage used by the server.
func (s *Server) WithStore(store storage.Store) *Server {
s.store = store
Expand Down Expand Up @@ -595,7 +601,7 @@ func (s *Server) getListener(addr string, h http.Handler, t httpListenerType) ([
// otherwise use the fsnotify default behavior
if s.certRefresh > 0 {
loops = []Loop{loop, s.certLoopPolling(logger)}
} else if s.certFile != "" {
} else if s.certFile != "" || s.certPoolFile != "" {
loops = []Loop{loop, s.certLoopNotify(logger)}
}
default:
Expand Down

0 comments on commit a621e2b

Please sign in to comment.