diff --git a/registry/registry.go b/registry/registry.go index 96f68b1248..b7989f0217 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -81,9 +81,6 @@ var tlsVersions = map[string]uint16{ // defaultLogFormatter is the default formatter to use for logs. const defaultLogFormatter = "text" -// this channel gets notified when process receives signal. It is global to ease unit testing -var quit = make(chan os.Signal, 1) - // HandlerFunc defines an http middleware type HandlerFunc func(config *configuration.Configuration, handler http.Handler) http.Handler @@ -130,6 +127,7 @@ type Registry struct { config *configuration.Configuration app *handlers.App server *http.Server + quit chan os.Signal } // NewRegistry creates a new registry from a context and configuration struct. @@ -173,6 +171,7 @@ func NewRegistry(ctx context.Context, config *configuration.Configuration) (*Reg app: app, config: config, server: server, + quit: make(chan os.Signal, 1), }, nil } @@ -313,7 +312,7 @@ func (registry *Registry) ListenAndServe() error { } // setup channel to get notified on SIGTERM signal - signal.Notify(quit, syscall.SIGTERM) + signal.Notify(registry.quit, syscall.SIGTERM) serveErr := make(chan error) // Start serving in goroutine and listen for stop signal in main thread @@ -324,15 +323,20 @@ func (registry *Registry) ListenAndServe() error { select { case err := <-serveErr: return err - case <-quit: + case <-registry.quit: dcontext.GetLogger(registry.app).Info("stopping server gracefully. Draining connections for ", config.HTTP.DrainTimeout) // shutdown the server with a grace period of configured timeout c, cancel := context.WithTimeout(context.Background(), config.HTTP.DrainTimeout) defer cancel() - return registry.server.Shutdown(c) + return registry.Shutdown(c) } } +// Shutdown gracefully shuts down the registry's HTTP server. +func (registry *Registry) Shutdown(ctx context.Context) error { + return registry.server.Shutdown(ctx) +} + func configureDebugServer(config *configuration.Configuration) { if config.HTTP.Debug.Addr != "" { go func(addr string) { diff --git a/registry/registry_test.go b/registry/registry_test.go index 98dd6c94dd..ea23b9db1f 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -103,7 +103,7 @@ func TestGracefulShutdown(t *testing.T) { fmt.Fprintf(conn, "GET /v2/ ") // send stop signal - quit <- os.Interrupt + registry.quit <- os.Interrupt time.Sleep(100 * time.Millisecond) // try connecting again. it shouldn't @@ -325,7 +325,7 @@ func TestRegistrySupportedCipherSuite(t *testing.T) { } // send stop signal - quit <- os.Interrupt + registry.quit <- os.Interrupt time.Sleep(100 * time.Millisecond) } @@ -369,7 +369,7 @@ func TestRegistryUnsupportedCipherSuite(t *testing.T) { } // send stop signal - quit <- os.Interrupt + registry.quit <- os.Interrupt time.Sleep(100 * time.Millisecond) }