Skip to content

Commit

Permalink
add optional key cache to JWT middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
m90 committed Aug 16, 2019
1 parent 30a7d5d commit c6830ec
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 9 deletions.
3 changes: 2 additions & 1 deletion kms/router/router.go
Expand Up @@ -71,9 +71,10 @@ func New(opts ...Config) http.Handler {

decrypt := m.PathPrefix("/decrypt").Subrouter()
if rt.jwtPublicKey != "" {
keyCache := httputil.NewDefaultKeyCache(httputil.DefaultCacheExpiry)
auth := httputil.JWTProtect(rt.jwtPublicKey, "auth", "", func(*http.Request, map[string]interface{}) error {
return nil
})
}, keyCache)
decrypt.Use(auth)
}

Expand Down
5 changes: 3 additions & 2 deletions server/router/router.go
Expand Up @@ -168,8 +168,9 @@ func New(opts ...Config) http.Handler {
exchange.HandleFunc("", rt.getPublicKey).Methods(http.MethodGet)
exchange.HandleFunc("", rt.postUserSecret).Methods(http.MethodPost)

getAuth := httputil.JWTProtect(rt.jwtPublicKey, authKey, authHeader, getAuthorizer)
postAuth := httputil.JWTProtect(rt.jwtPublicKey, authKey, authHeader, postAuthorizer)
keyCache := httputil.NewDefaultKeyCache(httputil.DefaultCacheExpiry)
getAuth := httputil.JWTProtect(rt.jwtPublicKey, authKey, authHeader, getAuthorizer, keyCache)
postAuth := httputil.JWTProtect(rt.jwtPublicKey, authKey, authHeader, postAuthorizer, keyCache)
accounts := m.PathPrefix("/accounts").Subrouter()
accounts.Handle("", getAuth(http.HandlerFunc(rt.getAccount))).Methods(http.MethodGet)
accounts.Handle("", postAuth(http.HandlerFunc(rt.postAccount))).Methods(http.MethodPost)
Expand Down
65 changes: 60 additions & 5 deletions shared/http/jwt.go
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"net/http"
"strings"
"time"

"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
Expand All @@ -22,7 +23,7 @@ const ClaimsContextKey contextKey = "claims"
// JWTProtect uses the public key located at the given URL to check if the
// cookie value is signed properly. In case yes, the JWT claims will be added
// to the request context
func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error) func(http.Handler) http.Handler {
func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error, cache Cache) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var jwtValue string
Expand All @@ -38,10 +39,23 @@ func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Req
return
}

keys, keysErr := fetchKeys(keyURL)
if keysErr != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error fetching keys: %v", keysErr), http.StatusInternalServerError)
return
var keys [][]byte
if cache != nil {
lookup, lookupErr := cache.Get()
if lookupErr == nil {
keys = lookup
}
}
if keys == nil {
var keysErr error
keys, keysErr = fetchKeys(keyURL)
if keysErr != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error fetching keys: %v", keysErr), http.StatusInternalServerError)
return
}
if cache != nil {
cache.Set(keys)
}
}

var token *jwt.Token
Expand Down Expand Up @@ -117,3 +131,44 @@ func fetchKeys(keyURL string) ([][]byte, error) {
}
return asBytes, nil
}

// Cache can be implemented by consumers in order to define how requests
// for public keys are being cached. For most use cases, the default cache
// supplied by this package will suffice.
type Cache interface {
Get() ([][]byte, error)
Set([][]byte)
}

type defaultCache struct {
value *[][]byte
expires time.Duration
deadline time.Time
}

// DefaultCacheExpiry should be used by cache instantiations without
// any particular requirements.
const DefaultCacheExpiry = time.Minute * 15

// ErrNoCache is returned on a cache lookup that did not yield a result
var ErrNoCache = errors.New("nothing found in cache")

func (c *defaultCache) Get() ([][]byte, error) {
if c.value != nil && time.Now().Before(c.deadline) {
return *c.value, nil
}
return nil, ErrNoCache
}

func (c *defaultCache) Set(value [][]byte) {
c.deadline = time.Now().Add(c.expires)
c.value = &value
}

// NewDefaultKeyCache creates a simple cache that will hold a single
// value for the given expiration time
func NewDefaultKeyCache(expires time.Duration) Cache {
return &defaultCache{
expires: expires,
}
}
2 changes: 1 addition & 1 deletion shared/http/jwt_test.go
Expand Up @@ -257,7 +257,7 @@ func TestJWTProtect(t *testing.T) {
if test.server != nil {
url = test.server.URL
}
wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
w := httptest.NewRecorder()
Expand Down

0 comments on commit c6830ec

Please sign in to comment.