From c6830ec8192961e49d07b1294f8ca350e0678e73 Mon Sep 17 00:00:00 2001 From: Frederik Ring Date: Fri, 16 Aug 2019 09:28:14 +0200 Subject: [PATCH] add optional key cache to JWT middleware --- kms/router/router.go | 3 +- server/router/router.go | 5 ++-- shared/http/jwt.go | 65 +++++++++++++++++++++++++++++++++++++---- shared/http/jwt_test.go | 2 +- 4 files changed, 66 insertions(+), 9 deletions(-) diff --git a/kms/router/router.go b/kms/router/router.go index 57a107f15..a8f933eb6 100644 --- a/kms/router/router.go +++ b/kms/router/router.go @@ -71,9 +71,10 @@ func New(opts ...Config) http.Handler { decrypt := m.PathPrefix("/decrypt").Subrouter() if rt.jwtPublicKey != "" { + keyCache := httputil.NewDefaultKeyCache(httputil.DefaultCacheExpiry) auth := httputil.JWTProtect(rt.jwtPublicKey, "auth", "", func(*http.Request, map[string]interface{}) error { return nil - }) + }, keyCache) decrypt.Use(auth) } diff --git a/server/router/router.go b/server/router/router.go index 35a1cd8db..904e0bfa9 100644 --- a/server/router/router.go +++ b/server/router/router.go @@ -168,8 +168,9 @@ func New(opts ...Config) http.Handler { exchange.HandleFunc("", rt.getPublicKey).Methods(http.MethodGet) exchange.HandleFunc("", rt.postUserSecret).Methods(http.MethodPost) - getAuth := httputil.JWTProtect(rt.jwtPublicKey, authKey, authHeader, getAuthorizer) - postAuth := httputil.JWTProtect(rt.jwtPublicKey, authKey, authHeader, postAuthorizer) + keyCache := httputil.NewDefaultKeyCache(httputil.DefaultCacheExpiry) + getAuth := httputil.JWTProtect(rt.jwtPublicKey, authKey, authHeader, getAuthorizer, keyCache) + postAuth := httputil.JWTProtect(rt.jwtPublicKey, authKey, authHeader, postAuthorizer, keyCache) accounts := m.PathPrefix("/accounts").Subrouter() accounts.Handle("", getAuth(http.HandlerFunc(rt.getAccount))).Methods(http.MethodGet) accounts.Handle("", postAuth(http.HandlerFunc(rt.postAccount))).Methods(http.MethodPost) diff --git a/shared/http/jwt.go b/shared/http/jwt.go index 986fb6ccd..bb2dfddb2 100644 --- a/shared/http/jwt.go +++ b/shared/http/jwt.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" "strings" + "time" "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwt" @@ -22,7 +23,7 @@ const ClaimsContextKey contextKey = "claims" // JWTProtect uses the public key located at the given URL to check if the // cookie value is signed properly. In case yes, the JWT claims will be added // to the request context -func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error) func(http.Handler) http.Handler { +func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error, cache Cache) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var jwtValue string @@ -38,10 +39,23 @@ func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Req return } - keys, keysErr := fetchKeys(keyURL) - if keysErr != nil { - RespondWithJSONError(w, fmt.Errorf("jwt: error fetching keys: %v", keysErr), http.StatusInternalServerError) - return + var keys [][]byte + if cache != nil { + lookup, lookupErr := cache.Get() + if lookupErr == nil { + keys = lookup + } + } + if keys == nil { + var keysErr error + keys, keysErr = fetchKeys(keyURL) + if keysErr != nil { + RespondWithJSONError(w, fmt.Errorf("jwt: error fetching keys: %v", keysErr), http.StatusInternalServerError) + return + } + if cache != nil { + cache.Set(keys) + } } var token *jwt.Token @@ -117,3 +131,44 @@ func fetchKeys(keyURL string) ([][]byte, error) { } return asBytes, nil } + +// Cache can be implemented by consumers in order to define how requests +// for public keys are being cached. For most use cases, the default cache +// supplied by this package will suffice. +type Cache interface { + Get() ([][]byte, error) + Set([][]byte) +} + +type defaultCache struct { + value *[][]byte + expires time.Duration + deadline time.Time +} + +// DefaultCacheExpiry should be used by cache instantiations without +// any particular requirements. +const DefaultCacheExpiry = time.Minute * 15 + +// ErrNoCache is returned on a cache lookup that did not yield a result +var ErrNoCache = errors.New("nothing found in cache") + +func (c *defaultCache) Get() ([][]byte, error) { + if c.value != nil && time.Now().Before(c.deadline) { + return *c.value, nil + } + return nil, ErrNoCache +} + +func (c *defaultCache) Set(value [][]byte) { + c.deadline = time.Now().Add(c.expires) + c.value = &value +} + +// NewDefaultKeyCache creates a simple cache that will hold a single +// value for the given expiration time +func NewDefaultKeyCache(expires time.Duration) Cache { + return &defaultCache{ + expires: expires, + } +} diff --git a/shared/http/jwt_test.go b/shared/http/jwt_test.go index fe62b4e67..4826efddf 100644 --- a/shared/http/jwt_test.go +++ b/shared/http/jwt_test.go @@ -257,7 +257,7 @@ func TestJWTProtect(t *testing.T) { if test.server != nil { url = test.server.URL } - wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) })) w := httptest.NewRecorder()