Skip to content

Commit

Permalink
More tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dunglas committed Nov 7, 2019
1 parent 2c200f8 commit 79dc536
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 66 deletions.
1 change: 0 additions & 1 deletion .golangci.yml
Expand Up @@ -8,7 +8,6 @@ linters:
- errcheck
- lll
- wsl
- gochecknoinits

issues:
exclude-rules:
Expand Down
13 changes: 7 additions & 6 deletions cmd/root.go
Expand Up @@ -9,7 +9,7 @@ import (
)

// rootCmd represents the base command when called without any subcommands
var rootCmd = &cobra.Command{
var rootCmd = &cobra.Command{ //nolint:gochecknoglobals
Use: "mercure",
Short: "Start the Mercure Hub",
Long: `Mercure is a protocol allowing to push data updates to web browsers and
Expand All @@ -22,18 +22,19 @@ Go to https://mercure.rocks for more information!`,
},
}

// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
if err := rootCmd.Execute(); err != nil {
log.Fatalln(err)
}
}

func init() {
cobra.OnInitialize(hub.InitConfig)
func init() { //nolint:gochecknoinits
v := viper.GetViper()
cobra.OnInitialize(func() {
hub.InitConfig(v)
})
fs := rootCmd.Flags()
hub.SetFlags(fs, viper.GetViper())
hub.SetFlags(fs, v)

hub.InitLogrus()
}
63 changes: 11 additions & 52 deletions hub/config.go
Expand Up @@ -3,6 +3,7 @@ package hub
import (
"fmt"
"os"
"strings"
"time"

"github.com/spf13/pflag"
Expand Down Expand Up @@ -30,89 +31,47 @@ func ValidateConfig(v *viper.Viper) error {
return fmt.Errorf(`one of "jwt_key" or "publisher_jwt_key" configuration parameter must be defined`)
}
if v.GetString("cert_file") != "" && v.GetString("key_file") == "" {
return fmt.Errorf(`ff the "cert_file" configuration parameter is defined, "key_file" must be defined too`)
return fmt.Errorf(`if the "cert_file" configuration parameter is defined, "key_file" must be defined too`)
}
if v.GetString("key_file") != "" && v.GetString("cert_file") == "" {
return fmt.Errorf(`ff the "key_file" configuration parameter is defined, "cert_file" must be defined too`)
return fmt.Errorf(`if the "key_file" configuration parameter is defined, "cert_file" must be defined too`)
}
return nil
}

// SetFlags creates flags and bind them to Viper
func SetFlags(fs *pflag.FlagSet, v *viper.Viper) {
fs.BoolP("debug", "d", false, "enable the debug mode")
v.BindPFlag("debug", fs.Lookup("debug"))

fs.StringP("transport-url", "t", "", "transport and history system to use")
v.BindPFlag("transport_url", fs.Lookup("transport-url"))

fs.StringP("jwt-key", "k", "", "JWT key")
v.BindPFlag("jwt_key", fs.Lookup("jwt-key"))

fs.StringP("jwt-algorithm", "O", "", "JWT algorithm")
v.BindPFlag("jwt_algorithm", fs.Lookup("jwt-algorithm"))

fs.StringP("publisher-jwt-key", "K", "", "publisher JWT key")
v.BindPFlag("publisher_jwt_key", fs.Lookup("publisher-jwt-key"))

fs.StringP("publisher-jwt-algorithm", "A", "", "publisher JWT algorithm")
v.BindPFlag("publisher_jwt_algorithm", fs.Lookup("publisher-jwt-algorithm"))

fs.StringP("subscriber-jwt-key", "L", "", "subscriber JWT key")
v.BindPFlag("subscriber_jwt_key", fs.Lookup("subscriber-jwt-key"))

fs.StringP("subscriber-jwt-algorithm", "B", "", "subscriber JWT algorithm")
v.BindPFlag("subscriber_jwt_algorithm", fs.Lookup("subscriber-jwt-algorithm"))

fs.BoolP("allow-anonymous", "X", false, "allow subscribers with no valid JWT to connect")
v.BindPFlag("allow_anonymous", fs.Lookup("allow-anonymous"))

fs.StringSliceP("cors-allowed-origins", "c", []string{}, "list of allowed CORS origins")
v.BindPFlag("cors_allowed_origins", fs.Lookup("cors-allowed-origins"))

fs.StringSliceP("publish-allowed-origins", "p", []string{}, "list of origins allowed to publish")
v.BindPFlag("publish_allowed_origins", fs.Lookup("publish-allowed-origins"))

fs.StringP("addr", "a", "", "the address to listen on")
v.BindPFlag("addr", fs.Lookup("addr"))

fs.StringSliceP("acme-hosts", "o", []string{}, "list of hosts for which Let's Encrypt certificates must be issued")
v.BindPFlag("acme_hosts", fs.Lookup("acme-hosts"))

fs.StringP("acme-cert-dir", "E", "", "the directory where to store Let's Encrypt certificates")
v.BindPFlag("acme_cert_dir", fs.Lookup("acme-cert-dir"))

fs.StringP("cert-file", "C", "", "a cert file (to use a custom certificate)")
v.BindPFlag("cert_file", fs.Lookup("cert-file"))

fs.StringP("key-file", "J", "", "a key file (to use a custom certificate)")
v.BindPFlag("key_file", fs.Lookup("key-file"))

fs.StringP("heartbeat-interval", "i", "", "interval between heartbeats (0s to disable)")
v.BindPFlag("heartbeat_interval", fs.Lookup("heartbeat-interval"))

fs.StringP("read-timeout", "R", "", "maximum duration for reading the entire request, including the body")
v.BindPFlag("read_timeout", fs.Lookup("read-timeout"))

fs.StringP("write-timeout", "W", "", "maximum duration before timing out writes of the response")
v.BindPFlag("write_timeout", fs.Lookup("write-timeout"))

fs.DurationP("heartbeat-interval", "i", 15*time.Second, "interval between heartbeats (0s to disable)")
fs.DurationP("read-timeout", "R", time.Duration(0), "maximum duration for reading the entire request, including the body")
fs.DurationP("write-timeout", "W", time.Duration(0), "maximum duration before timing out writes of the response")
fs.BoolP("compress", "Z", false, "enable or disable HTTP compression support")
v.BindPFlag("compress", fs.Lookup("compress"))

fs.BoolP("use-forwarded-headers", "f", false, "enable headers forwarding")
v.BindPFlag("use_forwarded_headers", fs.Lookup("use-forwarded-headers"))

fs.BoolP("demo", "D", false, "enable the demo mode")
v.BindPFlag("demo", fs.Lookup("demo"))

fs.StringP("log-format", "l", "", "the log format (JSON, FLUENTD or TEXT)")
v.BindPFlag("log_format", fs.Lookup("log-format"))

fs.VisitAll(func(f *pflag.Flag) {
v.BindPFlag(strings.ReplaceAll(f.Name, "-", "_"), fs.Lookup(f.Name))
})
}

// InitConfig reads in config file and ENV variables if set.
func InitConfig() {
v := viper.GetViper()
func InitConfig(v *viper.Viper) {
SetConfigDefaults(v)

v.SetConfigName("mercure")
Expand Down
11 changes: 11 additions & 0 deletions hub/config_test.go
@@ -1,6 +1,7 @@
package hub

import (
"os"
"testing"

"github.com/spf13/pflag"
Expand Down Expand Up @@ -38,3 +39,13 @@ func TestSetFlags(t *testing.T) {

assert.Subset(t, []string{"cert_file", "compress", "demo", "jwt_algorithm", "transport_url", "acme_hosts", "acme_cert_dir", "subscriber_jwt_key", "log_format", "jwt_key", "allow_anonymous", "debug", "read_timeout", "publisher_jwt_algorithm", "write_timeout", "key_file", "use_forwarded_headers", "subscriber_jwt_algorithm", "addr", "publisher_jwt_key", "heartbeat_interval", "cors_allowed_origins", "publish_allowed_origins"}, v.AllKeys())
}

func TestInitConfig(t *testing.T) {
os.Setenv("JWT_KEY", "foo")
defer os.Unsetenv("JWT_KEY")

v := viper.New()
InitConfig(v)

assert.Equal(t, "foo", v.GetString("jwt_key"))
}
4 changes: 2 additions & 2 deletions hub/demo.go
Expand Up @@ -9,11 +9,11 @@ import (
"time"
)

// demo exposes UNSECURE demo endpoints to test discovery and authorization mechanisms
// Demo exposes INSECURE Demo endpoints to test discovery and authorization mechanisms
// add a query parameter named "body" to define the content to return in the response's body
// add a query parameter named "jwt" set a "mercureAuthorization" cookie containing this token
// the Content-Type header will automatically be set according to the URL's extension
func demo(w http.ResponseWriter, r *http.Request) {
func Demo(w http.ResponseWriter, r *http.Request) {
// JSON-LD is the preferred format
mime.AddExtensionType(".jsonld", "application/ld+json")

Expand Down
4 changes: 2 additions & 2 deletions hub/demo_test.go
Expand Up @@ -12,7 +12,7 @@ import (
func TestEmptyBodyAndJWT(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/demo/foo.jsonld", nil)
w := httptest.NewRecorder()
demo(w, req)
Demo(w, req)

resp := w.Result()
assert.Equal(t, "application/ld+json", resp.Header.Get("Content-Type"))
Expand All @@ -30,7 +30,7 @@ func TestEmptyBodyAndJWT(t *testing.T) {
func TestBodyAndJWT(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/demo/foo/bar.xml?body=<hello/>&jwt=token", nil)
w := httptest.NewRecorder()
demo(w, req)
Demo(w, req)

resp := w.Result()
assert.Equal(t, "application/xml", resp.Header.Get("Content-Type"))
Expand Down
16 changes: 16 additions & 0 deletions hub/hub_test.go
@@ -1,6 +1,8 @@
package hub

import (
"os"
"os/exec"
"testing"
"time"

Expand Down Expand Up @@ -46,6 +48,20 @@ func TestNewHubTransportValidationError(t *testing.T) {
assert.Error(t, err)
}

func TestStartCrash(t *testing.T) {
if os.Getenv("BE_START_CRASH") == "1" {
Start()
return
}
cmd := exec.Command(os.Args[0], "-test.run=TestStartCrash") //nolint:gosec
cmd.Env = append(os.Environ(), "BE_START_CRASH=1")
err := cmd.Run()

e, ok := err.(*exec.ExitError)
require.True(t, ok)
assert.False(t, e.Success())
}

func createDummy() *Hub {
v := viper.New()
SetConfigDefaults(v)
Expand Down
2 changes: 0 additions & 2 deletions hub/log.go
Expand Up @@ -51,9 +51,7 @@ func InitLogrus() {
switch viper.GetString("log_format") {
case "JSON":
log.SetFormatter(&log.JSONFormatter{})
break
case "FLUENTD":
log.SetFormatter(fluentd.NewFormatter())
break
}
}
2 changes: 1 addition & 1 deletion hub/server.go
Expand Up @@ -105,7 +105,7 @@ func (h *Hub) chainHandlers(acmeHosts []string) http.Handler {
r.HandleFunc(defaultHubURL, h.SubscribeHandler).Methods("GET", "HEAD")
r.HandleFunc(defaultHubURL, h.PublishHandler).Methods("POST")
if h.config.GetBool("demo") {
r.PathPrefix("/demo").HandlerFunc(demo).Methods("GET", "HEAD")
r.PathPrefix("/demo").HandlerFunc(Demo).Methods("GET", "HEAD")
r.PathPrefix("/").Handler(http.FileServer(http.Dir("public")))
} else {
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit 79dc536

Please sign in to comment.