Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor JWT key resolver #832

Merged
merged 1 commit into from
Oct 17, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions middleware/security/jwt/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package jwt

import (
"golang.org/x/net/context"

jwt "github.com/dgrijalva/jwt-go"
)

type contextKey int

const (
jwtKey contextKey = iota + 1
)

// WithJWT creates a child context containing the given JWT.
func WithJWT(ctx context.Context, t *jwt.Token) context.Context {
return context.WithValue(ctx, jwtKey, t)
}

// ContextJWT retrieves the JWT token from a `context` that went through our security middleware.
func ContextJWT(ctx context.Context) *jwt.Token {
token, ok := ctx.Value(jwtKey).(*jwt.Token)
if !ok {
return nil
}
return token
}
26 changes: 26 additions & 0 deletions middleware/security/jwt/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package jwt

import (
"errors"

"github.com/goadesign/goa"
)

var (
// ErrEmptyHeaderName is returned when the header value given to the standard key resolver
// constructor is empty.
ErrEmptyHeaderName = errors.New("header name must not be empty")

// ErrInvalidKey is returned when a key is not of type string, []string, *rsa.PublicKey or
// []*rsa.PublicKey.
ErrInvalidKey = errors.New("invalid parameter, the only keys accepted " +
"are *rsa.publicKey, []*rsa.PublicKey (for RSA-based algorithms) or a " +
"signing secret string, []string (for HS algorithms)")

// ErrKeyDoesNotExist is returned when a key cannot be found by the provided key name.
ErrKeyDoesNotExist = errors.New("key does not exist")

// ErrJWTError is the error returned by this middleware when any sort of validation or
// assertion fails during processing.
ErrJWTError = goa.NewErrorClass("jwt_security_error", 401)
)
252 changes: 33 additions & 219 deletions middleware/security/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,153 +2,16 @@ package jwt

import (
"crypto/rsa"
"errors"
"fmt"
"net/http"
"sort"
"strings"
"sync"

jwt "github.com/dgrijalva/jwt-go"
"github.com/goadesign/goa"
"golang.org/x/net/context"
)

// ErrInvalidKey is returned when a key is not of type string, []string,
// *rsa.PublicKey or []*rsa.PublicKey.
var ErrInvalidKey = errors.New("invalid parameter, the only keys accepted " +
"are *rsa.publicKey, []*rsa.PublicKey (for RSA-based algorithms) or a " +
"signing secret string, []string (for HS algorithms)")

// ErrKeyDoesNotExist is returned when a key cannot be found by the provided
// key name.
var ErrKeyDoesNotExist = errors.New("key does not exist")

// KeyResolver is a struct that is passed into the New() function, which allows
// the user to add/remove keys from the jwt goa.middleware. The use of a
// resolver provides for better scalability/performance as the number of
// valid keys grows. If an incoming http.Request contains the header field
// "jwtkeyname", then the handler will only attempt to validate the incoming
// JWT against the keys stored in the resolver under that name. Otherwise,
// the handler will attempt to validate the incoming JWT against all keys
// stored in the resolver.
type KeyResolver struct {
*sync.RWMutex
jwtKeyNameField string
keyMap map[string][]interface{}
}

// AddKeys can be used to add keys to the resolver which will be referenced
// by the provided name. Acceptable types for keys include string, []string,
// *rsa.PublicKey or []*rsa.PublicKey. Multiple keys are allowed for a single
// key name to allow for key rotation.
func (kr *KeyResolver) AddKeys(name string, keys interface{}) error {
kr.Lock()
defer kr.Unlock()
switch keys := keys.(type) {
case *rsa.PublicKey:
kr.keyMap[name] = append(kr.keyMap[name], keys)
case []*rsa.PublicKey:
for _, key := range keys {
kr.keyMap[name] = append(kr.keyMap[name], key)
}
case string:
kr.keyMap[name] = append(kr.keyMap[name], keys)
case []string:
for _, key := range keys {
kr.keyMap[name] = append(kr.keyMap[name], key)
}
default:
return ErrInvalidKey
}
return nil
}

// RemoveAllKeys removes all keys from the resolver.
func (kr *KeyResolver) RemoveAllKeys() {
kr.Lock()
defer kr.Unlock()
kr.keyMap = make(map[string][]interface{})
return
}

// RemoveKeys removes all keys from the resolver stored under the provided name.
func (kr *KeyResolver) RemoveKeys(name string) {
kr.Lock()
defer kr.Unlock()
delete(kr.keyMap, name)
return
}

// RemoveKey removes only the provided key stored under the provided name from
// the resolver.
func (kr *KeyResolver) RemoveKey(name string, key interface{}) {
kr.Lock()
defer kr.Unlock()
if keys, ok := kr.keyMap[name]; ok {
for i, keyItem := range keys {
if keyItem == key {
kr.keyMap[name] = append(keys[:i], keys[i+1:]...)
}
}
}
return
}

// GetAllKeys returns a list of all the keys stored in the resolver.
func (kr *KeyResolver) GetAllKeys() []interface{} {
kr.RLock()
defer kr.RUnlock()
var keys []interface{}
for name := range kr.keyMap {
for _, key := range kr.keyMap[name] {
keys = append(keys, key)
}
}
return keys
}

// GetKeys returns a list of all the keys stored in the resolver under the
// provided name.
func (kr *KeyResolver) GetKeys(name string) ([]interface{}, error) {
kr.RLock()
defer kr.RUnlock()
if keys, ok := kr.keyMap[name]; ok {
return keys, nil
}
return nil, ErrKeyDoesNotExist
}

// NewResolver returns a KeyResolver populated with the provided map of key
// names to key lists. NewResolver will also set the HTTP header param name
// (jwtKeyNameField) to use for reading the JWT key name from HTTP requests.
func NewResolver(validationKeys map[string][]interface{},
jwtKeyNameField string) (*KeyResolver, error) {
keyMap := make(map[string][]interface{})
for name := range validationKeys {
for _, keys := range validationKeys[name] {
switch keys := keys.(type) {
case *rsa.PublicKey:
keyMap[name] = append(keyMap[name], keys)
case []*rsa.PublicKey:
for _, key := range keys {
keyMap[name] = append(keyMap[name], key)
}
case string:
keyMap[name] = append(keyMap[name], keys)
case []string:
for _, key := range keys {
keyMap[name] = append(keyMap[name], key)
}
default:
return nil, ErrInvalidKey
}
}
}
return &KeyResolver{RWMutex: &sync.RWMutex{}, keyMap: keyMap,
jwtKeyNameField: jwtKeyNameField}, nil
}

// New returns a middleware to be used with the JWTSecurity DSL definitions of goa. It supports the
// scopes claim in the JWT and ensures goa-defined Security DSLs are properly validated.
//
Expand All @@ -163,16 +26,15 @@ func NewResolver(validationKeys map[string][]interface{},
//
// validationKeys can be one of these:
//
// * a single string
// * a single []byte
// * a list of string
// * a list of []byte
// * a single rsa.PublicKey
// * a list of rsa.PublicKey
// * a single string
// * a slice of []byte
// * a slice of string
// * a single *rsa.PublicKey
// * a slice of *rsa.PublicKey
//
// The type of the keys determine the algorithms that will be used to do the check. The goal of
// having lists of keys is to allow for key rotation, still check the previous keys until rotation
// has been completed.
// Keys of type string or []byte are interepreted according to the signing method defined in the JWT
// token (HMAC, RSA, etc.).
//
// You can define an optional function to do additional validations on the token once the signature
// and the claims requirements are proven to be valid. Example:
Expand All @@ -187,10 +49,10 @@ func NewResolver(validationKeys map[string][]interface{},
// Mount the middleware with the generated UseXX function where XX is the name of the scheme as
// defined in the design, e.g.:
//
// jwtResolver, _ := jwt.NewResolver("secret")
// jwtResolver, _ := jwt.NewSimpleResolver("secret")
// app.UseJWT(jwt.New(jwtResolver, validationHandler, app.NewJWTSecurity()))
//
func New(resolver *KeyResolver, validationFunc goa.Middleware, scheme *goa.JWTSecurity) goa.Middleware {
func New(resolver KeyResolver, validationFunc goa.Middleware, scheme *goa.JWTSecurity) goa.Middleware {
return func(nextHandler goa.Handler) goa.Handler {
return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
// TODO: implement the QUERY string handler too
Expand All @@ -208,24 +70,28 @@ func New(resolver *KeyResolver, validationFunc goa.Middleware, scheme *goa.JWTSe

incomingToken := strings.Split(val, " ")[1]

var jwtKeyName string
if resolver.jwtKeyNameField != "" {
jwtKeyName = req.Header.Get(resolver.jwtKeyNameField)
}

// Make a copy of the current keys in the KeyResolver's key map
resolver.RLock()
keyMap := make(map[string][]interface{})
for name, keys := range resolver.keyMap {
keyMap[name] = keys
var (
rsaKeys []*rsa.PublicKey
hmacKeys [][]byte
keys = resolver.SelectKeys(req)
)
{
for _, key := range keys {
switch k := key.(type) {
case *rsa.PublicKey:
rsaKeys = append(rsaKeys, k)
case []byte:
hmacKeys = append(hmacKeys, k)
case string:
hmacKeys = append(hmacKeys, []byte(k))
}
}
}
resolver.RUnlock()

rsaKeys, hmacKeys := getKeys(jwtKeyName, keyMap)

var token *jwt.Token
var err error
validated := false
var (
token *jwt.Token
err error
validated = false
)

if len(rsaKeys) > 0 {
token, err = validateRSAKeys(rsaKeys, "RS", incomingToken)
Expand All @@ -242,7 +108,7 @@ func New(resolver *KeyResolver, validationFunc goa.Middleware, scheme *goa.JWTSe
}

if !validated {
return ErrJWTError(fmt.Sprint("JWT validation failed"))
return ErrJWTError("JWT validation failed")
}

scopesInClaim, scopesInClaimList, err := parseClaimScopes(token)
Expand Down Expand Up @@ -303,30 +169,6 @@ func parseClaimScopes(token *jwt.Token) (map[string]bool, []string, error) {
return scopesInClaim, scopesInClaimList, nil
}

// ErrJWTError is the error returned by this middleware when any sort of validation or assertion
// fails during processing.
var ErrJWTError = goa.NewErrorClass("jwt_security_error", 401)

type contextKey int

const (
jwtKey contextKey = iota + 1
)

// WithJWT creates a child context containing the given JWT.
func WithJWT(ctx context.Context, t *jwt.Token) context.Context {
return context.WithValue(ctx, jwtKey, t)
}

// ContextJWT retrieves the JWT token from a `context` that went through our security middleware.
func ContextJWT(ctx context.Context) *jwt.Token {
token, ok := ctx.Value(jwtKey).(*jwt.Token)
if !ok {
return nil
}
return token
}

func validateRSAKeys(rsaKeys []*rsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) {
for _, pubkey := range rsaKeys {
token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) {
Expand All @@ -342,45 +184,17 @@ func validateRSAKeys(rsaKeys []*rsa.PublicKey, algo, incomingToken string) (toke
return
}

func validateHMACKeys(hmacKeys []string, algo, incomingToken string) (token *jwt.Token, err error) {
func validateHMACKeys(hmacKeys [][]byte, algo, incomingToken string) (token *jwt.Token, err error) {
for _, key := range hmacKeys {
token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) {
if !strings.HasPrefix(token.Method.Alg(), algo) {
return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"]))
}
return []byte(key), nil
return key, nil
})
if err == nil {
return
}
}
return
}

func getKeys(jwtKeyName string, keyMap map[string][]interface{}) (
rsaKeys []*rsa.PublicKey, hmacKeys []string) {
// if jwtKeyName is a non-empty string, we will include only keys
// under that name for validation, otherwise we will try all keys.
if jwtKeyName != "" {
for _, key := range keyMap[jwtKeyName] {
switch key.(type) {
case *rsa.PublicKey:
rsaKeys = append(rsaKeys, key.(*rsa.PublicKey))
case string:
hmacKeys = append(hmacKeys, key.(string))
}
}
} else {
for _, keyList := range keyMap {
for _, key := range keyList {
switch key.(type) {
case *rsa.PublicKey:
rsaKeys = append(rsaKeys, key.(*rsa.PublicKey))
case string:
hmacKeys = append(hmacKeys, key.(string))
}
}
}
}
return
}
Loading