Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for arrays of keys to be returned by KeyFunc. #328

Closed
wants to merge 1 commit into from

Conversation

schmidtw
Copy link

Based on feedback from PR 170 I've attempted to create the functionality of supporting a set of multiple keys in an easy to use and maintain way.

@oxisto
Copy link
Collaborator

oxisto commented Aug 3, 2023

Based on feedback from PR 170 I've attempted to create the functionality of supporting a set of multiple keys in an easy to use and maintain way.

Thanks! I will allocate some time on the weekend to look at this in detail.

@schmidtw
Copy link
Author

schmidtw commented Aug 8, 2023

Any thought?

@oxisto
Copy link
Collaborator

oxisto commented Aug 8, 2023

Any thought?

Sorry for the delay, still a little bit swamped at work. I will get to it soon, I promise.

Copy link
Collaborator

@oxisto oxisto left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems fine from an implementation point of view, although it is a little bit of extra reflect stuff we wanted to avoid, but it seems neceesary.

I am still torn if this is really a good idea cramming even more magic into the Keyfunc. What's your opinion on this @mfridman?

I wonder if we could have something that is more "explicit", for example a dedicated Keyset structure like this? (patch is based on this PR):

diff --git a/parser.go b/parser.go
index 00d360a..c4bca4a 100644
--- a/parser.go
+++ b/parser.go
@@ -5,7 +5,6 @@ import (
 	"encoding/base64"
 	"encoding/json"
 	"fmt"
-	"reflect"
 	"strings"
 )
 
@@ -81,34 +80,6 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
 		return token, newError("no keyfunc was provided", ErrTokenUnverifiable)
 	}
 
-	keys := make([]interface{}, 1)
-	// Convert the key or list of keys into a list of keys.
-	{
-		got, err := keyFunc(token)
-		if err != nil {
-			return token, newError("error while executing keyfunc", ErrTokenUnverifiable, err)
-		}
-
-		switch have := got.(type) {
-		case []interface{}:
-			keys = have
-		case []byte, []int8: // HMAC is an outlier, so treat it specially.
-			keys[0] = have
-		case interface{}:
-			typ := reflect.TypeOf(have)
-			switch typ.Kind() {
-			case reflect.Array, reflect.Slice:
-				val := reflect.ValueOf(have)
-				keys = make([]interface{}, val.Len())
-				for i := 0; i < val.Len(); i++ {
-					keys[i] = val.Index(i).Interface()
-				}
-			default:
-				keys[0] = have
-			}
-		}
-	}
-
 	// Decode signature
 	token.Signature, err = p.DecodeSegment(parts[2])
 	if err != nil {
@@ -117,16 +88,30 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
 
 	text := strings.Join(parts[0:2], ".")
 
-	// Assume there is an error until proven otherwise because an empty array of
-	// keys means no checks are performed.
-	err = ErrTokenSignatureInvalid
-	for _, key := range keys {
-		// Perform signature validation, skipping the rest when a match is found.
-		err = token.Method.Verify(text, token.Signature, key)
+	got, err := keyFunc(token)
+	if err != nil {
+		return token, newError("error while executing keyfunc", ErrTokenUnverifiable, err)
+	}
+
+	switch have := got.(type) {
+	case Keyset:
+		// Assume there is an error until proven otherwise because an empty array of
+		// keys means no checks are performed.
+		err = ErrTokenSignatureInvalid
+		for _, key := range have.Keys {
+			// Perform signature validation, skipping the rest when a match is found.
+			err = token.Method.Verify(text, token.Signature, key)
+			if err == nil {
+				break
+			}
+		}
+	default:
+		err = token.Method.Verify(text, token.Signature, have)
 		if err == nil {
 			break
 		}
 	}
+
 	// If the only key or last key checked failed, then it's an error.
 	if err != nil {
 		return token, newError("", ErrTokenSignatureInvalid, err)
diff --git a/parser_test.go b/parser_test.go
index 83e4b3f..19df154 100644
--- a/parser_test.go
+++ b/parser_test.go
@@ -31,13 +31,13 @@ var (
 	multipleZeroKeyFunc    jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return []interface{}{}, nil }
 	multipleEmptyKeyFunc   jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return []interface{}{nil, nil}, nil }
 	multipleLastKeyFunc    jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) {
-		return []interface{}{jwtTestEC256PublicKey, jwtTestDefaultKey}, nil
+		return jwt.Keyset{Keys: []jwt.Key{jwtTestEC256PublicKey, jwtTestDefaultKey}}, nil
 	}
 	multipleFirstKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) {
-		return []interface{}{jwtTestDefaultKey, jwtTestEC256PublicKey}, nil
+		return jwt.Keyset{Keys: []jwt.Key{jwtTestDefaultKey, jwtTestEC256PublicKey}}, nil
 	}
 	multipleAltTypedKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) {
-		return []*rsa.PublicKey{jwtTestDefaultKey, jwtTestDefaultKey}, nil
+		return jwt.Keyset{Keys: []jwt.Key{jwtTestDefaultKey, jwtTestDefaultKey}}, nil
 	}
 )
 
diff --git a/token.go b/token.go
index bf3e4e8..1361dfa 100644
--- a/token.go
+++ b/token.go
@@ -1,6 +1,7 @@
 package jwt
 
 import (
+	"crypto"
 	"encoding/base64"
 	"encoding/json"
 )
@@ -13,6 +14,14 @@ import (
 // an array of keys is returned an []interface{} with mixed types is allowed.
 type Keyfunc func(*Token) (interface{}, error)
 
+type Key interface {
+	crypto.PublicKey | []uint8
+}
+
+type Keyset struct {
+	Keys []Key
+}
+
 // Token represents a JWT Token.  Different fields will be used depending on
 // whether you're creating or parsing/verifying a token.
 type Token struct {

Or maybe forgo the wrapper and just "react" if a keyfunc returns a []jwt.Key


keys := make([]interface{}, 1)
// Convert the key or list of keys into a list of keys.
{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any particular reason for this block here? Seems unnecessary and doesn't really fit our existing code style.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really love it either. I tried to encapsulate the key type determination so it was more clear & that it didn't leak any variables that should not be used. I can instead add a function & call to replace it if you'd rather.

@oxisto
Copy link
Collaborator

oxisto commented Aug 14, 2023

Alternatively, "force" the user to use our type with a little bit of an extra function, which I think is actually quite nice in calling, such as:

func(t *jwt.Token) (interface{}, error) {
  return jwt.Keyset(jwtTestEC256PublicKey, jwtTestDefaultKey), nil
}

Patch here:

diff --git a/parser.go b/parser.go
index 00d360a..0c94f1f 100644
--- a/parser.go
+++ b/parser.go
@@ -5,7 +5,6 @@ import (
 	"encoding/base64"
 	"encoding/json"
 	"fmt"
-	"reflect"
 	"strings"
 )
 
@@ -81,34 +80,6 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
 		return token, newError("no keyfunc was provided", ErrTokenUnverifiable)
 	}
 
-	keys := make([]interface{}, 1)
-	// Convert the key or list of keys into a list of keys.
-	{
-		got, err := keyFunc(token)
-		if err != nil {
-			return token, newError("error while executing keyfunc", ErrTokenUnverifiable, err)
-		}
-
-		switch have := got.(type) {
-		case []interface{}:
-			keys = have
-		case []byte, []int8: // HMAC is an outlier, so treat it specially.
-			keys[0] = have
-		case interface{}:
-			typ := reflect.TypeOf(have)
-			switch typ.Kind() {
-			case reflect.Array, reflect.Slice:
-				val := reflect.ValueOf(have)
-				keys = make([]interface{}, val.Len())
-				for i := 0; i < val.Len(); i++ {
-					keys[i] = val.Index(i).Interface()
-				}
-			default:
-				keys[0] = have
-			}
-		}
-	}
-
 	// Decode signature
 	token.Signature, err = p.DecodeSegment(parts[2])
 	if err != nil {
@@ -117,16 +88,30 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
 
 	text := strings.Join(parts[0:2], ".")
 
-	// Assume there is an error until proven otherwise because an empty array of
-	// keys means no checks are performed.
-	err = ErrTokenSignatureInvalid
-	for _, key := range keys {
-		// Perform signature validation, skipping the rest when a match is found.
-		err = token.Method.Verify(text, token.Signature, key)
+	got, err := keyFunc(token)
+	if err != nil {
+		return token, newError("error while executing keyfunc", ErrTokenUnverifiable, err)
+	}
+
+	switch have := got.(type) {
+	case []Key:
+		// Assume there is an error until proven otherwise because an empty array of
+		// keys means no checks are performed.
+		err = ErrTokenSignatureInvalid
+		for _, key := range have {
+			// Perform signature validation, skipping the rest when a match is found.
+			err = token.Method.Verify(text, token.Signature, key)
+			if err == nil {
+				break
+			}
+		}
+	default:
+		err = token.Method.Verify(text, token.Signature, have)
 		if err == nil {
 			break
 		}
 	}
+
 	// If the only key or last key checked failed, then it's an error.
 	if err != nil {
 		return token, newError("", ErrTokenSignatureInvalid, err)
diff --git a/parser_test.go b/parser_test.go
index 83e4b3f..b8d6ad6 100644
--- a/parser_test.go
+++ b/parser_test.go
@@ -31,13 +31,13 @@ var (
 	multipleZeroKeyFunc    jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return []interface{}{}, nil }
 	multipleEmptyKeyFunc   jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return []interface{}{nil, nil}, nil }
 	multipleLastKeyFunc    jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) {
-		return []interface{}{jwtTestEC256PublicKey, jwtTestDefaultKey}, nil
+		return jwt.Keyset(jwtTestEC256PublicKey, jwtTestDefaultKey), nil
 	}
 	multipleFirstKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) {
-		return []interface{}{jwtTestDefaultKey, jwtTestEC256PublicKey}, nil
+		return jwt.Keyset(jwtTestDefaultKey, jwtTestEC256PublicKey), nil
 	}
 	multipleAltTypedKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) {
-		return []*rsa.PublicKey{jwtTestDefaultKey, jwtTestDefaultKey}, nil
+		return jwt.Keyset(jwtTestDefaultKey, jwtTestDefaultKey), nil
 	}
 )
 
diff --git a/token.go b/token.go
index bf3e4e8..f5b24e1 100644
--- a/token.go
+++ b/token.go
@@ -1,6 +1,7 @@
 package jwt
 
 import (
+	"crypto"
 	"encoding/base64"
 	"encoding/json"
 )
@@ -13,6 +14,14 @@ import (
 // an array of keys is returned an []interface{} with mixed types is allowed.
 type Keyfunc func(*Token) (interface{}, error)
 
+type Key interface {
+	crypto.PublicKey | []uint8
+}
+
+func Keyset(keys ...Key) []Key {
+	return keys
+}
+
 // Token represents a JWT Token.  Different fields will be used depending on
 // whether you're creating or parsing/verifying a token.
 type Token struct {

@oxisto
Copy link
Collaborator

oxisto commented Aug 14, 2023

At some point we probably also want to type constraint the Keyfunc into returning Key | []Key, but this would be API breaking, but help a lot with what this function "expects".

@mfridman mfridman self-requested a review August 14, 2023 19:42
@oxisto oxisto linked an issue Aug 19, 2023 that may be closed by this pull request
@schmidtw
Copy link
Author

Closing in favor of #344 .

@schmidtw schmidtw closed this Sep 13, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

KeyFunc should be able to return a slice
2 participants