Skip to content

Commit

Permalink
Fix linter
Browse files Browse the repository at this point in the history
  • Loading branch information
dunglas committed May 6, 2020
1 parent d926204 commit 8cdb7e8
Show file tree
Hide file tree
Showing 12 changed files with 63 additions and 38 deletions.
21 changes: 15 additions & 6 deletions hub/authorization.go
Expand Up @@ -31,6 +31,15 @@ const (
publisherRole
)

var (
ErrInvalidAuthorizationHeader = errors.New(`invalid "Authorization" HTTP header`)
ErrNoOrigin = errors.New(`an "Origin" or a "Referer" HTTP header must be present to use the cookie-based authorization mechanism`)
ErrOriginNotAllowed = errors.New("origin not allowed to post updates")
ErrUnexpectedSigningMethod = errors.New("unexpected signing method")
ErrInvalidJWT = errors.New("invalid JWT")
ErrPublicKey = errors.New("public key error")
)

func (h *Hub) getJWTKey(r role) []byte {
var configKey string
switch r {
Expand Down Expand Up @@ -79,7 +88,7 @@ func authorize(r *http.Request, jwtKey []byte, jwtSigningAlgorithm jwt.SigningMe
authorizationHeaders, headerExists := r.Header["Authorization"]
if headerExists {
if len(authorizationHeaders) != 1 || len(authorizationHeaders[0]) < 48 || authorizationHeaders[0][:7] != "Bearer " {
return nil, errors.New("invalid \"Authorization\" HTTP header")
return nil, ErrInvalidAuthorizationHeader
}

return validateJWT(authorizationHeaders[0][7:], jwtKey, jwtSigningAlgorithm)
Expand All @@ -101,7 +110,7 @@ func authorize(r *http.Request, jwtKey []byte, jwtSigningAlgorithm jwt.SigningMe
// Try to extract the origin from the Referer, or return an error
referer := r.Header.Get("Referer")
if referer == "" {
return nil, errors.New("an \"Origin\" or a \"Referer\" HTTP header must be present to use the cookie-based authorization mechanism")
return nil, ErrNoOrigin
}

u, err := url.Parse(referer)
Expand All @@ -118,7 +127,7 @@ func authorize(r *http.Request, jwtKey []byte, jwtSigningAlgorithm jwt.SigningMe
}
}

return nil, fmt.Errorf("the origin \"%s\" is not allowed to post updates", origin)
return nil, fmt.Errorf("%q: %w", origin, ErrOriginNotAllowed)
}

// validateJWT validates that the provided JWT token is a valid Mercure token.
Expand All @@ -131,7 +140,7 @@ func validateJWT(encodedToken string, key []byte, signingAlgorithm jwt.SigningMe
block, _ := pem.Decode(key)

if block == nil {
return nil, errors.New("public key error")
return nil, ErrPublicKey
}

pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes)
Expand All @@ -145,7 +154,7 @@ func validateJWT(encodedToken string, key []byte, signingAlgorithm jwt.SigningMe
return pub, nil
}

return nil, fmt.Errorf("unexpected signing method: %T", signingAlgorithm)
return nil, fmt.Errorf("%T: %w", signingAlgorithm, ErrUnexpectedSigningMethod)
})

if err != nil {
Expand All @@ -156,7 +165,7 @@ func validateJWT(encodedToken string, key []byte, signingAlgorithm jwt.SigningMe
return claims, nil
}

return nil, errors.New("invalid JWT")
return nil, ErrInvalidJWT
}

func authorizedTargets(claims *claims, publisher bool) (all bool, targets map[string]struct{}) {
Expand Down
10 changes: 5 additions & 5 deletions hub/authorization_test.go
Expand Up @@ -163,7 +163,7 @@ func TestAuthorizeAuthorizationHeaderWrongAlgorithm(t *testing.T) {
r.Header.Add("Authorization", "Bearer "+validFullHeaderRsa)

claims, err := authorize(r, []byte(publicKeyRsa), nil, []string{})
assert.EqualError(t, err, "unexpected signing method: <nil>")
assert.EqualError(t, err, "<nil>: unexpected signing method")
assert.Nil(t, claims)
}

Expand Down Expand Up @@ -267,7 +267,7 @@ func TestAuthorizeCookieOriginNotAllowed(t *testing.T) {
r.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: validFullHeader})

claims, err := authorize(r, []byte("!ChangeMe!"), hmacSigningMethod, []string{"http://example.net"})
assert.EqualError(t, err, "the origin \"http://example.com\" is not allowed to post updates")
assert.EqualError(t, err, `"http://example.com": origin not allowed to post updates`)
assert.Nil(t, claims)
}

Expand All @@ -277,7 +277,7 @@ func TestAuthorizeCookieOriginNotAllowedRsa(t *testing.T) {
r.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: validFullHeaderRsa})

claims, err := authorize(r, []byte(publicKeyRsa), rsaSigningMethod, []string{"http://example.net"})
assert.EqualError(t, err, "the origin \"http://example.com\" is not allowed to post updates")
assert.EqualError(t, err, `"http://example.com": origin not allowed to post updates`)
assert.Nil(t, claims)
}

Expand All @@ -287,7 +287,7 @@ func TestAuthorizeCookieRefererNotAllowed(t *testing.T) {
r.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: validFullHeader})

claims, err := authorize(r, []byte("!ChangeMe!"), hmacSigningMethod, []string{"http://example.net"})
assert.EqualError(t, err, "the origin \"http://example.com\" is not allowed to post updates")
assert.EqualError(t, err, `"http://example.com": origin not allowed to post updates`)
assert.Nil(t, claims)
}

Expand All @@ -297,7 +297,7 @@ func TestAuthorizeCookieRefererNotAllowedRsa(t *testing.T) {
r.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: validFullHeaderRsa})

claims, err := authorize(r, []byte(publicKeyRsa), rsaSigningMethod, []string{"http://example.net"})
assert.EqualError(t, err, "the origin \"http://example.com\" is not allowed to post updates")
assert.EqualError(t, err, `"http://example.com": origin not allowed to post updates`)
assert.Nil(t, claims)
}

Expand Down
18 changes: 10 additions & 8 deletions hub/bolt_transport.go
Expand Up @@ -43,18 +43,20 @@ func NewBoltTransport(u *url.URL, bufferSize int, bufferFullTimeout time.Duratio
}

size := uint64(0)
if q.Get("size") != "" {
size, err = strconv.ParseUint(q.Get("size"), 10, 64)
sizeParameter := q.Get("size")
if sizeParameter != "" {
size, err = strconv.ParseUint(sizeParameter, 10, 64)
if err != nil {
return nil, fmt.Errorf(`invalid bolt "%s" dsn: parameter size: %w`, u, err)
return nil, fmt.Errorf(`%q: invalid "size" parameter %q: %s: %w`, u, sizeParameter, err, ErrInvalidTransportDSN)
}
}

cleanupFrequency := 0.3
if q.Get("cleanup_frequency") != "" {
cleanupFrequency, err = strconv.ParseFloat(q.Get("cleanup_frequency"), 64)
cleanupFrequencyParameter := q.Get("cleanup_frequency")
if cleanupFrequencyParameter != "" {
cleanupFrequency, err = strconv.ParseFloat(cleanupFrequencyParameter, 64)
if err != nil {
return nil, fmt.Errorf(`invalid bolt "%s" dsn: parameter cleanup_frequency: %w`, u, err)
return nil, fmt.Errorf(`%q: invalid "cleanup_frequency" parameter %q: %w`, u, cleanupFrequencyParameter, ErrInvalidTransportDSN)
}
}

Expand All @@ -63,12 +65,12 @@ func NewBoltTransport(u *url.URL, bufferSize int, bufferFullTimeout time.Duratio
path = u.Host // relative path (bolt://path.db)
}
if path == "" {
return nil, fmt.Errorf(`invalid bolt DSN "%s": missing path`, u)
return nil, fmt.Errorf(`%q: missing path: %w`, u, ErrInvalidTransportDSN)
}

db, err := bolt.Open(path, 0600, nil)
if err != nil {
return nil, fmt.Errorf(`invalid bolt DSN "%s": %w`, u, err)
return nil, fmt.Errorf(`%q: %s: %w`, u, err, ErrInvalidTransportDSN)
}

return &BoltTransport{
Expand Down
8 changes: 4 additions & 4 deletions hub/bolt_transport_test.go
Expand Up @@ -107,21 +107,21 @@ func TestNewBoltTransport(t *testing.T) {

u, _ = url.Parse("bolt://")
_, err = NewBoltTransport(u, 5, time.Second)
assert.EqualError(t, err, `invalid bolt DSN "bolt:": missing path`)
assert.EqualError(t, err, `"bolt:": missing path: invalid transport DSN`)

u, _ = url.Parse("bolt:///test.db")
_, err = NewBoltTransport(u, 5, time.Second)

// The exact error message depends of the OS
assert.Contains(t, err.Error(), `invalid bolt DSN "bolt:///test.db": open /test.db: `)
assert.Contains(t, err.Error(), "open /test.db:")

u, _ = url.Parse("bolt://test.db?cleanup_frequency=invalid")
_, err = NewBoltTransport(u, 5, time.Second)
assert.EqualError(t, err, `invalid bolt "bolt://test.db?cleanup_frequency=invalid" dsn: parameter cleanup_frequency: strconv.ParseFloat: parsing "invalid": invalid syntax`)
assert.EqualError(t, err, `"bolt://test.db?cleanup_frequency=invalid": invalid "cleanup_frequency" parameter "invalid": invalid transport DSN`)

u, _ = url.Parse("bolt://test.db?size=invalid")
_, err = NewBoltTransport(u, 5, time.Second)
assert.EqualError(t, err, `invalid bolt "bolt://test.db?size=invalid" dsn: parameter size: strconv.ParseUint: parsing "invalid": invalid syntax`)
assert.EqualError(t, err, `"bolt://test.db?size=invalid": invalid "size" parameter "invalid": strconv.ParseUint: parsing "invalid": invalid syntax: invalid transport DSN`)
}

func TestBoltTransportWriteIsNotDispatchedUntilListen(t *testing.T) {
Expand Down
9 changes: 6 additions & 3 deletions hub/config.go
@@ -1,6 +1,7 @@
package hub

import (
"errors"
"fmt"
"os"
"strings"
Expand All @@ -10,6 +11,8 @@ import (
"github.com/spf13/viper"
)

var ErrInvalidConfig = errors.New("invalid config")

// SetConfigDefaults sets defaults on a Viper instance.
func SetConfigDefaults(v *viper.Viper) {
v.SetDefault("debug", false)
Expand All @@ -33,13 +36,13 @@ func SetConfigDefaults(v *viper.Viper) {
// ValidateConfig validates a Viper instance.
func ValidateConfig(v *viper.Viper) error {
if v.GetString("publisher_jwt_key") == "" && v.GetString("jwt_key") == "" {
return fmt.Errorf(`one of "jwt_key" or "publisher_jwt_key" configuration parameter must be defined`)
return fmt.Errorf(`%w: one of "jwt_key" or "publisher_jwt_key" configuration parameter must be defined`, ErrInvalidConfig)
}
if v.GetString("cert_file") != "" && v.GetString("key_file") == "" {
return fmt.Errorf(`if the "cert_file" configuration parameter is defined, "key_file" must be defined too`)
return fmt.Errorf(`%w: if the "cert_file" configuration parameter is defined, "key_file" must be defined too`, ErrInvalidConfig)
}
if v.GetString("key_file") != "" && v.GetString("cert_file") == "" {
return fmt.Errorf(`if the "key_file" configuration parameter is defined, "cert_file" must be defined too`)
return fmt.Errorf(`%w: if the "key_file" configuration parameter is defined, "cert_file" must be defined too`, ErrInvalidConfig)
}
return nil
}
Expand Down
6 changes: 3 additions & 3 deletions hub/config_test.go
Expand Up @@ -11,7 +11,7 @@ import (

func TestMissingConfig(t *testing.T) {
err := ValidateConfig(viper.New())
assert.EqualError(t, err, `one of "jwt_key" or "publisher_jwt_key" configuration parameter must be defined`)
assert.EqualError(t, err, `invalid config: one of "jwt_key" or "publisher_jwt_key" configuration parameter must be defined`)
}

func TestMissingKeyFile(t *testing.T) {
Expand All @@ -20,7 +20,7 @@ func TestMissingKeyFile(t *testing.T) {
v.Set("cert_file", "foo")

err := ValidateConfig(v)
assert.EqualError(t, err, `if the "cert_file" configuration parameter is defined, "key_file" must be defined too`)
assert.EqualError(t, err, `invalid config: if the "cert_file" configuration parameter is defined, "key_file" must be defined too`)
}

func TestMissingCertFile(t *testing.T) {
Expand All @@ -29,7 +29,7 @@ func TestMissingCertFile(t *testing.T) {
v.Set("key_file", "foo")

err := ValidateConfig(v)
assert.EqualError(t, err, `if the "key_file" configuration parameter is defined, "cert_file" must be defined too`)
assert.EqualError(t, err, `invalid config: if the "key_file" configuration parameter is defined, "cert_file" must be defined too`)
}

func TestSetFlags(t *testing.T) {
Expand Down
5 changes: 4 additions & 1 deletion hub/publish.go
Expand Up @@ -2,6 +2,7 @@ package hub

import (
"errors"
"fmt"
"io"
"net/http"
"strconv"
Expand All @@ -10,6 +11,8 @@ import (
log "github.com/sirupsen/logrus"
)

var ErrTargetNotAuthorized = errors.New("target not authorized")

func (h *Hub) dispatch(u *Update) error {
if u.ID == "" {
u.ID = uuid.Must(uuid.NewV4()).String()
Expand Down Expand Up @@ -84,7 +87,7 @@ func getAuthorizedTargets(claims *claims, t []string) (map[string]struct{}, erro
if !authorizedAlltargets {
_, ok := authorizedTargets[t]
if !ok {
return nil, errors.New("Target " + t + " is not authorized")
return nil, fmt.Errorf("%q: %w", t, ErrTargetNotAuthorized)
}
}
targets[t] = struct{}{}
Expand Down
3 changes: 2 additions & 1 deletion hub/server.go
Expand Up @@ -2,6 +2,7 @@ package hub

import (
"context"
"errors"
"fmt"
"net/http"
"os"
Expand Down Expand Up @@ -60,7 +61,7 @@ func (h *Hub) Serve() {
err = h.server.ListenAndServeTLS(certFile, keyFile)
}

if err != http.ErrServerClosed {
if !errors.Is(err, http.ErrServerClosed) {
log.Fatal(err)
}

Expand Down
3 changes: 2 additions & 1 deletion hub/subscribe.go
Expand Up @@ -3,6 +3,7 @@ package hub
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -55,7 +56,7 @@ func (h *Hub) SubscribeHandler(w http.ResponseWriter, r *http.Request) {
// Listen to the closing of the http connection via the Request's Context
return
case <-ctx.Done():
if ctx.Err() == context.DeadlineExceeded {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
// Send a SSE comment as a heartbeat, to prevent issues with some proxies and old browsers
fmt.Fprint(w, ":\n")
f.Flush()
Expand Down
6 changes: 4 additions & 2 deletions hub/subscribe_test.go
Expand Up @@ -2,7 +2,7 @@ package hub

import (
"context"
"fmt"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -154,6 +154,8 @@ func TestSubscribeNoTopic(t *testing.T) {
assert.Equal(t, "Missing \"topic\" parameter.\n", w.Body.String())
}

var errFailedToCreatePipe = errors.New("failed to create a pipe")

type createPipeErrorTransport struct {
}

Expand All @@ -162,7 +164,7 @@ func (*createPipeErrorTransport) Write(update *Update) error {
}

func (*createPipeErrorTransport) CreatePipe(fromID string) (*Pipe, error) {
return nil, fmt.Errorf("Failed to create a pipe")
return nil, errFailedToCreatePipe
}

func (*createPipeErrorTransport) Close() error {
Expand Down
10 changes: 7 additions & 3 deletions hub/transport.go
Expand Up @@ -22,8 +22,12 @@ type Transport interface {
Close() error
}

// ErrClosedTransport is returned by the Transport's Write and CreatePipe methods after a call to Close.
var ErrClosedTransport = errors.New("hub: read/write on closed Transport")
var (
// ErrInvalidTransportDSN is returned when the Transport's DSN is invalid
ErrInvalidTransportDSN = errors.New("invalid transport DSN")
// ErrClosedTransport is returned by the Transport's Dispatch and AddSubscriber methods after a call to Close.
ErrClosedTransport = errors.New("hub: read/write on closed Transport")
)

// NewTransport create a transport using the backend matching the given TransportURL.
func NewTransport(config *viper.Viper) (Transport, error) {
Expand All @@ -47,7 +51,7 @@ func NewTransport(config *viper.Viper) (Transport, error) {
return NewBoltTransport(u, bs, bt)
}

return nil, fmt.Errorf(`no Transport available for "%s"`, tu)
return nil, fmt.Errorf("%q: no such transport available: %w", tu, ErrInvalidTransportDSN)
}

// LocalTransport implements the TransportInterface without database and simply broadcast the live Updates.
Expand Down
2 changes: 1 addition & 1 deletion hub/transport_test.go
Expand Up @@ -175,7 +175,7 @@ func TestNewTransport(t *testing.T) {
transport, err = NewTransport(v)
assert.Nil(t, transport)
assert.NotNil(t, err)
assert.EqualError(t, err, `no Transport available for "nothing:"`)
assert.EqualError(t, err, `"nothing:": no such transport available: invalid transport DSN`)

v = viper.New()
v.Set("transport_url", "http://[::1]%23")
Expand Down

0 comments on commit 8cdb7e8

Please sign in to comment.