diff --git a/hub/authorization.go b/hub/authorization.go index 9459a76f..7ea5fae3 100644 --- a/hub/authorization.go +++ b/hub/authorization.go @@ -31,6 +31,15 @@ const ( publisherRole ) +var ( + ErrInvalidAuthorizationHeader = errors.New(`invalid "Authorization" HTTP header`) + ErrNoOrigin = errors.New(`an "Origin" or a "Referer" HTTP header must be present to use the cookie-based authorization mechanism`) + ErrOriginNotAllowed = errors.New("origin not allowed to post updates") + ErrUnexpectedSigningMethod = errors.New("unexpected signing method") + ErrInvalidJWT = errors.New("invalid JWT") + ErrPublicKey = errors.New("public key error") +) + func (h *Hub) getJWTKey(r role) []byte { var configKey string switch r { @@ -79,7 +88,7 @@ func authorize(r *http.Request, jwtKey []byte, jwtSigningAlgorithm jwt.SigningMe authorizationHeaders, headerExists := r.Header["Authorization"] if headerExists { if len(authorizationHeaders) != 1 || len(authorizationHeaders[0]) < 48 || authorizationHeaders[0][:7] != "Bearer " { - return nil, errors.New("invalid \"Authorization\" HTTP header") + return nil, ErrInvalidAuthorizationHeader } return validateJWT(authorizationHeaders[0][7:], jwtKey, jwtSigningAlgorithm) @@ -101,7 +110,7 @@ func authorize(r *http.Request, jwtKey []byte, jwtSigningAlgorithm jwt.SigningMe // Try to extract the origin from the Referer, or return an error referer := r.Header.Get("Referer") if referer == "" { - return nil, errors.New("an \"Origin\" or a \"Referer\" HTTP header must be present to use the cookie-based authorization mechanism") + return nil, ErrNoOrigin } u, err := url.Parse(referer) @@ -118,7 +127,7 @@ func authorize(r *http.Request, jwtKey []byte, jwtSigningAlgorithm jwt.SigningMe } } - return nil, fmt.Errorf("the origin \"%s\" is not allowed to post updates", origin) + return nil, fmt.Errorf("%q: %w", origin, ErrOriginNotAllowed) } // validateJWT validates that the provided JWT token is a valid Mercure token. @@ -131,7 +140,7 @@ func validateJWT(encodedToken string, key []byte, signingAlgorithm jwt.SigningMe block, _ := pem.Decode(key) if block == nil { - return nil, errors.New("public key error") + return nil, ErrPublicKey } pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes) @@ -145,7 +154,7 @@ func validateJWT(encodedToken string, key []byte, signingAlgorithm jwt.SigningMe return pub, nil } - return nil, fmt.Errorf("unexpected signing method: %T", signingAlgorithm) + return nil, fmt.Errorf("%T: %w", signingAlgorithm, ErrUnexpectedSigningMethod) }) if err != nil { @@ -156,7 +165,7 @@ func validateJWT(encodedToken string, key []byte, signingAlgorithm jwt.SigningMe return claims, nil } - return nil, errors.New("invalid JWT") + return nil, ErrInvalidJWT } func authorizedTargets(claims *claims, publisher bool) (all bool, targets map[string]struct{}) { diff --git a/hub/authorization_test.go b/hub/authorization_test.go index 5ec2f7a9..371c8fad 100644 --- a/hub/authorization_test.go +++ b/hub/authorization_test.go @@ -163,7 +163,7 @@ func TestAuthorizeAuthorizationHeaderWrongAlgorithm(t *testing.T) { r.Header.Add("Authorization", "Bearer "+validFullHeaderRsa) claims, err := authorize(r, []byte(publicKeyRsa), nil, []string{}) - assert.EqualError(t, err, "unexpected signing method: ") + assert.EqualError(t, err, ": unexpected signing method") assert.Nil(t, claims) } @@ -267,7 +267,7 @@ func TestAuthorizeCookieOriginNotAllowed(t *testing.T) { r.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: validFullHeader}) claims, err := authorize(r, []byte("!ChangeMe!"), hmacSigningMethod, []string{"http://example.net"}) - assert.EqualError(t, err, "the origin \"http://example.com\" is not allowed to post updates") + assert.EqualError(t, err, `"http://example.com": origin not allowed to post updates`) assert.Nil(t, claims) } @@ -277,7 +277,7 @@ func TestAuthorizeCookieOriginNotAllowedRsa(t *testing.T) { r.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: validFullHeaderRsa}) claims, err := authorize(r, []byte(publicKeyRsa), rsaSigningMethod, []string{"http://example.net"}) - assert.EqualError(t, err, "the origin \"http://example.com\" is not allowed to post updates") + assert.EqualError(t, err, `"http://example.com": origin not allowed to post updates`) assert.Nil(t, claims) } @@ -287,7 +287,7 @@ func TestAuthorizeCookieRefererNotAllowed(t *testing.T) { r.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: validFullHeader}) claims, err := authorize(r, []byte("!ChangeMe!"), hmacSigningMethod, []string{"http://example.net"}) - assert.EqualError(t, err, "the origin \"http://example.com\" is not allowed to post updates") + assert.EqualError(t, err, `"http://example.com": origin not allowed to post updates`) assert.Nil(t, claims) } @@ -297,7 +297,7 @@ func TestAuthorizeCookieRefererNotAllowedRsa(t *testing.T) { r.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: validFullHeaderRsa}) claims, err := authorize(r, []byte(publicKeyRsa), rsaSigningMethod, []string{"http://example.net"}) - assert.EqualError(t, err, "the origin \"http://example.com\" is not allowed to post updates") + assert.EqualError(t, err, `"http://example.com": origin not allowed to post updates`) assert.Nil(t, claims) } diff --git a/hub/bolt_transport.go b/hub/bolt_transport.go index 20b5997a..b8930300 100644 --- a/hub/bolt_transport.go +++ b/hub/bolt_transport.go @@ -43,18 +43,20 @@ func NewBoltTransport(u *url.URL, bufferSize int, bufferFullTimeout time.Duratio } size := uint64(0) - if q.Get("size") != "" { - size, err = strconv.ParseUint(q.Get("size"), 10, 64) + sizeParameter := q.Get("size") + if sizeParameter != "" { + size, err = strconv.ParseUint(sizeParameter, 10, 64) if err != nil { - return nil, fmt.Errorf(`invalid bolt "%s" dsn: parameter size: %w`, u, err) + return nil, fmt.Errorf(`%q: invalid "size" parameter %q: %s: %w`, u, sizeParameter, err, ErrInvalidTransportDSN) } } cleanupFrequency := 0.3 - if q.Get("cleanup_frequency") != "" { - cleanupFrequency, err = strconv.ParseFloat(q.Get("cleanup_frequency"), 64) + cleanupFrequencyParameter := q.Get("cleanup_frequency") + if cleanupFrequencyParameter != "" { + cleanupFrequency, err = strconv.ParseFloat(cleanupFrequencyParameter, 64) if err != nil { - return nil, fmt.Errorf(`invalid bolt "%s" dsn: parameter cleanup_frequency: %w`, u, err) + return nil, fmt.Errorf(`%q: invalid "cleanup_frequency" parameter %q: %w`, u, cleanupFrequencyParameter, ErrInvalidTransportDSN) } } @@ -63,12 +65,12 @@ func NewBoltTransport(u *url.URL, bufferSize int, bufferFullTimeout time.Duratio path = u.Host // relative path (bolt://path.db) } if path == "" { - return nil, fmt.Errorf(`invalid bolt DSN "%s": missing path`, u) + return nil, fmt.Errorf(`%q: missing path: %w`, u, ErrInvalidTransportDSN) } db, err := bolt.Open(path, 0600, nil) if err != nil { - return nil, fmt.Errorf(`invalid bolt DSN "%s": %w`, u, err) + return nil, fmt.Errorf(`%q: %s: %w`, u, err, ErrInvalidTransportDSN) } return &BoltTransport{ diff --git a/hub/bolt_transport_test.go b/hub/bolt_transport_test.go index 3ff6b94f..847fcec5 100644 --- a/hub/bolt_transport_test.go +++ b/hub/bolt_transport_test.go @@ -107,21 +107,21 @@ func TestNewBoltTransport(t *testing.T) { u, _ = url.Parse("bolt://") _, err = NewBoltTransport(u, 5, time.Second) - assert.EqualError(t, err, `invalid bolt DSN "bolt:": missing path`) + assert.EqualError(t, err, `"bolt:": missing path: invalid transport DSN`) u, _ = url.Parse("bolt:///test.db") _, err = NewBoltTransport(u, 5, time.Second) // The exact error message depends of the OS - assert.Contains(t, err.Error(), `invalid bolt DSN "bolt:///test.db": open /test.db: `) + assert.Contains(t, err.Error(), "open /test.db:") u, _ = url.Parse("bolt://test.db?cleanup_frequency=invalid") _, err = NewBoltTransport(u, 5, time.Second) - assert.EqualError(t, err, `invalid bolt "bolt://test.db?cleanup_frequency=invalid" dsn: parameter cleanup_frequency: strconv.ParseFloat: parsing "invalid": invalid syntax`) + assert.EqualError(t, err, `"bolt://test.db?cleanup_frequency=invalid": invalid "cleanup_frequency" parameter "invalid": invalid transport DSN`) u, _ = url.Parse("bolt://test.db?size=invalid") _, err = NewBoltTransport(u, 5, time.Second) - assert.EqualError(t, err, `invalid bolt "bolt://test.db?size=invalid" dsn: parameter size: strconv.ParseUint: parsing "invalid": invalid syntax`) + assert.EqualError(t, err, `"bolt://test.db?size=invalid": invalid "size" parameter "invalid": strconv.ParseUint: parsing "invalid": invalid syntax: invalid transport DSN`) } func TestBoltTransportWriteIsNotDispatchedUntilListen(t *testing.T) { diff --git a/hub/config.go b/hub/config.go index cb2141f6..e9896070 100644 --- a/hub/config.go +++ b/hub/config.go @@ -1,6 +1,7 @@ package hub import ( + "errors" "fmt" "os" "strings" @@ -10,6 +11,8 @@ import ( "github.com/spf13/viper" ) +var ErrInvalidConfig = errors.New("invalid config") + // SetConfigDefaults sets defaults on a Viper instance. func SetConfigDefaults(v *viper.Viper) { v.SetDefault("debug", false) @@ -33,13 +36,13 @@ func SetConfigDefaults(v *viper.Viper) { // ValidateConfig validates a Viper instance. func ValidateConfig(v *viper.Viper) error { if v.GetString("publisher_jwt_key") == "" && v.GetString("jwt_key") == "" { - return fmt.Errorf(`one of "jwt_key" or "publisher_jwt_key" configuration parameter must be defined`) + return fmt.Errorf(`%w: one of "jwt_key" or "publisher_jwt_key" configuration parameter must be defined`, ErrInvalidConfig) } if v.GetString("cert_file") != "" && v.GetString("key_file") == "" { - return fmt.Errorf(`if the "cert_file" configuration parameter is defined, "key_file" must be defined too`) + return fmt.Errorf(`%w: if the "cert_file" configuration parameter is defined, "key_file" must be defined too`, ErrInvalidConfig) } if v.GetString("key_file") != "" && v.GetString("cert_file") == "" { - return fmt.Errorf(`if the "key_file" configuration parameter is defined, "cert_file" must be defined too`) + return fmt.Errorf(`%w: if the "key_file" configuration parameter is defined, "cert_file" must be defined too`, ErrInvalidConfig) } return nil } diff --git a/hub/config_test.go b/hub/config_test.go index c604e54c..cb377b17 100644 --- a/hub/config_test.go +++ b/hub/config_test.go @@ -11,7 +11,7 @@ import ( func TestMissingConfig(t *testing.T) { err := ValidateConfig(viper.New()) - assert.EqualError(t, err, `one of "jwt_key" or "publisher_jwt_key" configuration parameter must be defined`) + assert.EqualError(t, err, `invalid config: one of "jwt_key" or "publisher_jwt_key" configuration parameter must be defined`) } func TestMissingKeyFile(t *testing.T) { @@ -20,7 +20,7 @@ func TestMissingKeyFile(t *testing.T) { v.Set("cert_file", "foo") err := ValidateConfig(v) - assert.EqualError(t, err, `if the "cert_file" configuration parameter is defined, "key_file" must be defined too`) + assert.EqualError(t, err, `invalid config: if the "cert_file" configuration parameter is defined, "key_file" must be defined too`) } func TestMissingCertFile(t *testing.T) { @@ -29,7 +29,7 @@ func TestMissingCertFile(t *testing.T) { v.Set("key_file", "foo") err := ValidateConfig(v) - assert.EqualError(t, err, `if the "key_file" configuration parameter is defined, "cert_file" must be defined too`) + assert.EqualError(t, err, `invalid config: if the "key_file" configuration parameter is defined, "cert_file" must be defined too`) } func TestSetFlags(t *testing.T) { diff --git a/hub/publish.go b/hub/publish.go index 885a0dec..f583e067 100644 --- a/hub/publish.go +++ b/hub/publish.go @@ -2,6 +2,7 @@ package hub import ( "errors" + "fmt" "io" "net/http" "strconv" @@ -10,6 +11,8 @@ import ( log "github.com/sirupsen/logrus" ) +var ErrTargetNotAuthorized = errors.New("target not authorized") + func (h *Hub) dispatch(u *Update) error { if u.ID == "" { u.ID = uuid.Must(uuid.NewV4()).String() @@ -84,7 +87,7 @@ func getAuthorizedTargets(claims *claims, t []string) (map[string]struct{}, erro if !authorizedAlltargets { _, ok := authorizedTargets[t] if !ok { - return nil, errors.New("Target " + t + " is not authorized") + return nil, fmt.Errorf("%q: %w", t, ErrTargetNotAuthorized) } } targets[t] = struct{}{} diff --git a/hub/server.go b/hub/server.go index e9c74030..486df301 100644 --- a/hub/server.go +++ b/hub/server.go @@ -2,6 +2,7 @@ package hub import ( "context" + "errors" "fmt" "net/http" "os" @@ -60,7 +61,7 @@ func (h *Hub) Serve() { err = h.server.ListenAndServeTLS(certFile, keyFile) } - if err != http.ErrServerClosed { + if !errors.Is(err, http.ErrServerClosed) { log.Fatal(err) } diff --git a/hub/subscribe.go b/hub/subscribe.go index d8efed9d..1eb23078 100644 --- a/hub/subscribe.go +++ b/hub/subscribe.go @@ -3,6 +3,7 @@ package hub import ( "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -55,7 +56,7 @@ func (h *Hub) SubscribeHandler(w http.ResponseWriter, r *http.Request) { // Listen to the closing of the http connection via the Request's Context return case <-ctx.Done(): - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { // Send a SSE comment as a heartbeat, to prevent issues with some proxies and old browsers fmt.Fprint(w, ":\n") f.Flush() diff --git a/hub/subscribe_test.go b/hub/subscribe_test.go index a0c64fc8..88fadaa0 100644 --- a/hub/subscribe_test.go +++ b/hub/subscribe_test.go @@ -2,7 +2,7 @@ package hub import ( "context" - "fmt" + "errors" "io/ioutil" "net/http" "net/http/httptest" @@ -154,6 +154,8 @@ func TestSubscribeNoTopic(t *testing.T) { assert.Equal(t, "Missing \"topic\" parameter.\n", w.Body.String()) } +var errFailedToCreatePipe = errors.New("failed to create a pipe") + type createPipeErrorTransport struct { } @@ -162,7 +164,7 @@ func (*createPipeErrorTransport) Write(update *Update) error { } func (*createPipeErrorTransport) CreatePipe(fromID string) (*Pipe, error) { - return nil, fmt.Errorf("Failed to create a pipe") + return nil, errFailedToCreatePipe } func (*createPipeErrorTransport) Close() error { diff --git a/hub/transport.go b/hub/transport.go index b4b48982..cbf6baa1 100644 --- a/hub/transport.go +++ b/hub/transport.go @@ -22,8 +22,12 @@ type Transport interface { Close() error } -// ErrClosedTransport is returned by the Transport's Write and CreatePipe methods after a call to Close. -var ErrClosedTransport = errors.New("hub: read/write on closed Transport") +var ( + // ErrInvalidTransportDSN is returned when the Transport's DSN is invalid + ErrInvalidTransportDSN = errors.New("invalid transport DSN") + // ErrClosedTransport is returned by the Transport's Dispatch and AddSubscriber methods after a call to Close. + ErrClosedTransport = errors.New("hub: read/write on closed Transport") +) // NewTransport create a transport using the backend matching the given TransportURL. func NewTransport(config *viper.Viper) (Transport, error) { @@ -47,7 +51,7 @@ func NewTransport(config *viper.Viper) (Transport, error) { return NewBoltTransport(u, bs, bt) } - return nil, fmt.Errorf(`no Transport available for "%s"`, tu) + return nil, fmt.Errorf("%q: no such transport available: %w", tu, ErrInvalidTransportDSN) } // LocalTransport implements the TransportInterface without database and simply broadcast the live Updates. diff --git a/hub/transport_test.go b/hub/transport_test.go index a5c2c38c..89b3fa84 100644 --- a/hub/transport_test.go +++ b/hub/transport_test.go @@ -175,7 +175,7 @@ func TestNewTransport(t *testing.T) { transport, err = NewTransport(v) assert.Nil(t, transport) assert.NotNil(t, err) - assert.EqualError(t, err, `no Transport available for "nothing:"`) + assert.EqualError(t, err, `"nothing:": no such transport available: invalid transport DSN`) v = viper.New() v.Set("transport_url", "http://[::1]%23")