diff --git a/go.mod b/go.mod index 98bc141..8a2fb58 100644 --- a/go.mod +++ b/go.mod @@ -8,5 +8,5 @@ require ( github.com/lestrrat-go/jwx v0.9.0 github.com/pkg/errors v0.8.1 // indirect github.com/stretchr/testify v1.4.0 - github.com/urfave/cli/v2 v2.0.0 + github.com/urfave/cli/v2 v2.0.1-0.20191214051647-cae7b0c5e15e ) diff --git a/go.sum b/go.sum index 2ba0a06..d6f0e0e 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,8 @@ github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJy github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/urfave/cli/v2 v2.0.0 h1:+HU9SCbu8GnEUFtIBfuUNXN39ofWViIEJIp6SURMpCg= github.com/urfave/cli/v2 v2.0.0/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= +github.com/urfave/cli/v2 v2.0.1-0.20191214051647-cae7b0c5e15e h1:1wd+juynOAB6e0t+yw3gc9Pkz6N4QRwtrxzT1l2dEcU= +github.com/urfave/cli/v2 v2.0.1-0.20191214051647-cae7b0c5e15e/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= diff --git a/main.go b/main.go index 0c0645b..ff50851 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "io/ioutil" @@ -8,7 +9,9 @@ import ( "net/http" "net/url" "os" + "os/signal" "strconv" + "syscall" "github.com/caido/grafana-auth-proxy/pkg/extraction" "github.com/caido/grafana-auth-proxy/pkg/identity" @@ -123,11 +126,13 @@ func createRequestsHandler(c *cli.Context) (*RequestsHandler, error) { } func launchProxy(c *cli.Context) error { + // Build requests handler requestsHandler, err := createRequestsHandler(c) if err != nil { return err } + // Find port port := c.Int("port") if port == 0 { return errors.New("a port is required") @@ -135,7 +140,19 @@ func launchProxy(c *cli.Context) error { log.Printf("Proxy running on port : %d", port) - return http.ListenAndServe(":"+strconv.Itoa(port), requestsHandler) + // Start server + server := http.Server{Addr: ":" + strconv.Itoa(port), Handler: requestsHandler} + go func() { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatal(err) + } + }() + + // Handle shutdown + select { + case <-c.Context.Done(): + return server.Shutdown(c.Context) + } } func main() { @@ -147,7 +164,7 @@ func main() { log.Printf("Unable to load a .env file") } - // Launch app + // Build the app app := &cli.App{ Action: launchProxy, Flags: []cli.Flag{ @@ -225,8 +242,23 @@ func main() { }, } - err = app.Run(os.Args) + // Handle signals + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + + go func() { + <-sigs + os.Exit(1) + }() + + // Run the app + ctx, cancel := context.WithCancel(context.Background()) + err = app.RunContext(ctx, os.Args) if err != nil { log.Fatal(err) } + + // Handle graceful shutdown + <-sigs + cancel() }