From e74cbe9f2eb54e71e129db4a306dcbdf82d8f7a4 Mon Sep 17 00:00:00 2001 From: Dave Grijalva Date: Tue, 21 May 2019 14:30:14 -0700 Subject: [PATCH 1/5] WIP on incorporating support for Go2 errors (and backward compatible xerrors package) --- errors.go | 51 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/errors.go b/errors.go index 94b4e041..aa5380b3 100644 --- a/errors.go +++ b/errors.go @@ -2,16 +2,51 @@ package jwt import ( "errors" + "fmt" "time" ) +// Copied from xerrors, for compatibility without requiring the xerrors package +type errorPrinter interface { + Print(args ...interface{}) + Printf(format string, args ...interface{}) + Detail() bool +} + // Error constants var ( ErrInvalidKey = errors.New("key is invalid") - ErrInvalidKeyType = errors.New("key is of invalid type") + ErrInvalidKeyType = NewInvalidKeyTypeError("", "") ErrHashUnavailable = errors.New("the requested hash function is unavailable") ) +type InvalidKeyTypeError struct { + expected, received string +} + +func (e *InvalidKeyTypeError) Error() string { + if e.expected == "" && e.received == "" { + return "key is of invalid type" + } + return fmt.Sprintf("key is of invalid type: expected %v, received %v", e.Unwrap(), e.expected, e.received) +} + +func (e *InvalidKeyTypeError) Format(f fmt.State, c rune) { + if c == '+' { + f.Write([]byte(e.Error())) + } else { + f.Write([]byte(ErrInvalidKeyType.Error())) + } +} + +func (e *InvalidKeyTypeError) Unwrap() error { + return ErrInvalidKeyType +} + +func NewInvalidKeyTypeError(expected, received string) error { + return &InvalidKeyTypeError{expected, received} +} + // The errors that might occur when parsing and validating a token const ( ValidationErrorMalformed uint32 = 1 << iota // Token is malformed @@ -21,7 +56,6 @@ const ( // Standard Claim validation errors ValidationErrorAudience // AUD validation failed ValidationErrorExpired // EXP validation failed - ValidationErrorIssuedAt // IAT validation failed ValidationErrorIssuer // ISS validation failed ValidationErrorNotValidYet // NBF validation failed ValidationErrorID // JTI validation failed @@ -38,14 +72,14 @@ func NewValidationError(errorText string, errorFlags uint32) *ValidationError { // ValidationError is the error from Parse if token is not valid type ValidationError struct { - Inner error // stores the error returned by external dependencies, i.e.: KeyFunc + inner error // stores the error returned by external dependencies, i.e.: KeyFunc Errors uint32 // bitfield. see ValidationError... constants text string // errors that do not have a valid error just have text } // Validation error is an error type func (e ValidationError) Error() string { - if e.Inner != nil { + if e.inner != nil { return e.Inner.Error() } else if e.text != "" { return e.text @@ -54,6 +88,11 @@ func (e ValidationError) Error() string { } } +// Unwrap implements xerrors.Wrapper +func (e ValidationError) Unwrap() error { + return e.inner +} + // No errors func (e *ValidationError) valid() bool { return e.Errors == 0 @@ -62,7 +101,7 @@ func (e *ValidationError) valid() bool { // ExpiredError allows the caller to know the delta between now and the expired time and the unvalidated claims. // A client system may have a bug that doesn't refresh a token in time, or there may be clock skew so this information can help you understand. type ExpiredError struct { - Now int64 + Now time.Time ExpiredBy time.Duration Claims } @@ -70,3 +109,5 @@ type ExpiredError struct { func (e *ExpiredError) Error() string { return "Token is expired" } + +type MultiError []error From 0acf824c354aed13047a96ab17c6344e563b5cd3 Mon Sep 17 00:00:00 2001 From: Dave Grijalva Date: Thu, 23 May 2019 16:06:21 -0700 Subject: [PATCH 2/5] rewrite of errors to be compatible with upcoming changes and xerrors --- claims.go | 12 +-- claims_test.go | 26 ++---- ecdsa.go | 23 ++--- errors.go | 207 ++++++++++++++++++++++++++++--------------- errors_new.go | 7 ++ example_test.go | 21 +++-- go.mod | 2 + go.sum | 2 + hmac.go | 21 ++--- map_claims.go | 12 +-- none.go | 7 +- parser.go | 44 ++++----- parser_option.go | 2 - parser_test.go | 53 ++++++----- rsa.go | 11 +-- rsa_pss.go | 11 +-- validation_helper.go | 10 +-- 17 files changed, 256 insertions(+), 215 deletions(-) create mode 100644 errors_new.go diff --git a/claims.go b/claims.go index 14656e01..4e4b0b65 100644 --- a/claims.go +++ b/claims.go @@ -33,24 +33,18 @@ type StandardClaims struct { // As well, if any of the above claims are not in the token, it will still // be considered a valid claim. func (c StandardClaims) Valid(h *ValidationHelper) error { - vErr := new(ValidationError) + var vErr error if h == nil { h = DefaultValidationHelper } if err := h.ValidateExpiresAt(c.ExpiresAt); err != nil { - vErr.Inner = err - vErr.Errors |= ValidationErrorExpired + vErr = wrap(err, vErr) } if err := h.ValidateNotBefore(c.NotBefore); err != nil { - vErr.Inner = err - vErr.Errors |= ValidationErrorNotValidYet - } - - if vErr.valid() { - return nil + vErr = wrap(err, vErr) } return vErr diff --git a/claims_test.go b/claims_test.go index 4d538539..9a689e0e 100644 --- a/claims_test.go +++ b/claims_test.go @@ -8,6 +8,7 @@ import ( "github.com/dgrijalva/jwt-go/v4" "github.com/dgrijalva/jwt-go/v4/test" + "golang.org/x/xerrors" ) const ( @@ -49,25 +50,16 @@ func TestClaimValidExpired(t *testing.T) { if err == nil { t.Errorf("[%v] Expecting error. Didn't get one.", name) } else { - ve := err.(*jwt.ValidationError) - // compare the bitfield part of the error - if e := ve.Errors; e != jwt.ValidationErrorExpired { - t.Errorf("[%v] Errors don't match expectation. %v != %v", name, e, jwt.ValidationErrorExpired) + var expErr *jwt.TokenExpiredError + + if !xerrors.As(err, &expErr) { + t.Errorf("[%v] Expected error to unwrap as *jwt.TokenExpiredError but it didn't", name) + return } - switch vi := ve.Inner.(type) { - default: - expectedErrorStr := "token is expired by 1m40s" - if fmt.Sprint(ve.Inner.Error()) != expectedErrorStr { - t.Errorf("[%v] Errors inner text is not as expected. \"%v\" is not \"%v\"", name, ve.Inner, expectedErrorStr) - } - case *jwt.ExpiredError: - if vi.ExpiredBy != 100*time.Second { - t.Errorf("[%v] ExpiredError.ExpiredBy %v is not %v\n", name, vi.ExpiredBy, 100*time.Second) - } - if vi.Error() != "Token is expired" { - t.Errorf("[%v] Error message is not as expected \"%v\"\n", name, vi.Error()) - } + expectedErrorStr := "token is expired by 1m40s" + if expErr.Error() != expectedErrorStr { + t.Errorf("[%v] Error message is not as expected \"%v\" != \"%v\"", name, expErr.Error(), expectedErrorStr) } } }) diff --git a/ecdsa.go b/ecdsa.go index 79a03112..9a0d6083 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -5,15 +5,10 @@ import ( "crypto/ecdsa" "crypto/rand" "encoding/asn1" - "errors" + "fmt" "math/big" ) -// Errors returned by ecdsa signing method -var ( - ErrECDSAVerification = errors.New("crypto/ecdsa: verification error") -) - // SigningMethodECDSA implements the ECDSA family of signing methods signing methods // Expects *ecdsa.PrivateKey for signing and *ecdsa.PublicKey for verification type SigningMethodECDSA struct { @@ -81,14 +76,14 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa case crypto.Signer: pub := k.Public() if ecdsaKey, ok = pub.(*ecdsa.PublicKey); !ok { - return ErrInvalidKeyType + return &InvalidKeyError{Message: fmt.Sprintf("crypto.Signer returned an unexpected public key type: %T", pub)} } default: - return ErrInvalidKeyType + return NewInvalidKeyTypeError("*ecdsa.PublicKey or crypto.Signer", key) } if len(sig) != 2*m.KeySize { - return ErrECDSAVerification + return &UnverfiableTokenError{Message: "signature length is invalid"} } r := big.NewInt(0).SetBytes(sig[:m.KeySize]) @@ -105,7 +100,7 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true { return nil } - return ErrECDSAVerification + return new(InvalidSignatureError) } // Sign implements the Sign method from SigningMethod @@ -116,12 +111,12 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string var ok bool if signer, ok = key.(crypto.Signer); !ok { - return "", ErrInvalidKey + return "", NewInvalidKeyTypeError("*ecdsa.PrivateKey or crypto.Signer", key) } //sanity check that the signer is an ecdsa signer if pub, ok = signer.Public().(*ecdsa.PublicKey); !ok { - return "", ErrInvalidKeyType + return "", &InvalidKeyError{Message: fmt.Sprintf("signer returned unexpected public key type: %T", pub)} } // Create the hasher @@ -147,13 +142,13 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string } if len(rest) != 0 { - return "", ErrECDSASignatureUnmarshal + return "", &UnverfiableTokenError{Message: "unexpected extra bytes in ecda signature"} } curveBits := pub.Curve.Params().BitSize if m.CurveBits != curveBits { - return "", ErrInvalidKey + return "", &InvalidKeyError{Message: "CurveBits in public key don't match those in signing method"} } keyBytes := curveBits / 8 diff --git a/errors.go b/errors.go index b04a6f43..aae5518c 100644 --- a/errors.go +++ b/errors.go @@ -1,113 +1,182 @@ package jwt import ( - "errors" "fmt" "time" ) -// Copied from xerrors, for compatibility without requiring the xerrors package -type errorPrinter interface { - Print(args ...interface{}) - Printf(format string, args ...interface{}) - Detail() bool -} - // Error constants var ( - ErrInvalidKey = errors.New("key is invalid") - ErrInvalidKeyType = NewInvalidKeyTypeError("", "") - ErrHashUnavailable = errors.New("the requested hash function is unavailable") - ErrECDSASignatureUnmarshal = errors.New("unexpected extra bytes in ecda signature") + ErrHashUnavailable = new(HashUnavailableError) ) +// Embeds b within a, if a is a valid wrapper. returns a +// If a is not a valid wrapper, b is dropped +func wrap(a, b error) error { + type iErrorWrapper interface { + Wrap(error) + Unwrap() error + } + if w, ok := a.(iErrorWrapper); ok { + w.Wrap(b) + } + return a +} + +// ErrorWrapper provides a simple, concrete helper for implementing nestable errors +type ErrorWrapper struct{ err error } + +// Unwrap implements xerrors.Wrapper +func (w ErrorWrapper) Unwrap() error { + return w.err +} + +// Wrap stores the provided error value and returns it when Unwrap is called +func (w ErrorWrapper) Wrap(err error) { + w.err = err +} + +// InvalidKeyError is returned if the key is unusable for some reason other than type +type InvalidKeyError struct { + Message string + ErrorWrapper +} + +func (e *InvalidKeyError) Error() string { + return fmt.Sprintf("key is invalid: %v", e.Message) +} + +// InvalidKeyTypeError is returned if the key is unusable because it is of an incompatible type type InvalidKeyTypeError struct { - expected, received string + Expected, Received string // String descriptions of expected and received types + ErrorWrapper } func (e *InvalidKeyTypeError) Error() string { - if e.expected == "" && e.received == "" { + if e.Expected == "" && e.Received == "" { return "key is of invalid type" } - return fmt.Sprintf("key is of invalid type: expected %v, received %v", e.Unwrap(), e.expected, e.received) + return fmt.Sprintf("key is of invalid type: expected %v, received %v", e.Expected, e.Received) } -func (e *InvalidKeyTypeError) Format(f fmt.State, c rune) { - if c == '+' { - f.Write([]byte(e.Error())) - } else { - f.Write([]byte(ErrInvalidKeyType.Error())) +// NewInvalidKeyTypeError creates an InvalidKeyTypeError, automatically capturing the type +// of received +func NewInvalidKeyTypeError(expected string, received interface{}) error { + return &InvalidKeyTypeError{Expected: expected, Received: fmt.Sprintf("%T", received)} +} + +type MalformedTokenError struct { + Message string + ErrorWrapper +} + +func (e *MalformedTokenError) Error() string { + if e.Message == "" { + return "token is malformed" } + return fmt.Sprintf("token is malformed: %v", e.Message) } -func (e *InvalidKeyTypeError) Unwrap() error { - return ErrInvalidKeyType +type UnverfiableTokenError struct { + Message string + ErrorWrapper } -func NewInvalidKeyTypeError(expected, received string) error { - return &InvalidKeyTypeError{expected, received} +func (e *UnverfiableTokenError) Error() string { + if e.Message == "" { + return "token is unverifiable" + } + return fmt.Sprintf("token is unverifiable: %v", e.Message) } -// The errors that might occur when parsing and validating a token -const ( - ValidationErrorMalformed uint32 = 1 << iota // Token is malformed - ValidationErrorUnverifiable // Token could not be verified because of signing problems - ValidationErrorSignatureInvalid // Signature validation failed +type InvalidSignatureError struct { + Message string + ErrorWrapper +} - // Standard Claim validation errors - ValidationErrorAudience // AUD validation failed - ValidationErrorExpired // EXP validation failed - ValidationErrorIssuer // ISS validation failed - ValidationErrorNotValidYet // NBF validation failed - ValidationErrorID // JTI validation failed - ValidationErrorClaimsInvalid // Generic claims validation error -) +func (e *InvalidSignatureError) Error() string { + if e.Message == "" { + return "token signature is invalid" + } + return fmt.Sprintf("token signature is invalid: %v", e.Message) +} + +// TokenExpiredError allows the caller to know the delta between now and the expired time and the unvalidated claims. +// A client system may have a bug that doesn't refresh a token in time, or there may be clock skew so this information can help you understand. +type TokenExpiredError struct { + At time.Time // The time at which the exp was evaluated. Includes leeway. + ExpiredBy time.Duration // How long the token had been expired at time of evaluation + ErrorWrapper // Value for unwrapping +} -// NewValidationError is a helper for constructing a ValidationError with a string error message -func NewValidationError(errorText string, errorFlags uint32) *ValidationError { - return &ValidationError{ - text: errorText, - Errors: errorFlags, +func (e *TokenExpiredError) Error() string { + return fmt.Sprintf("token is expired by %v", e.ExpiredBy) +} + +type TokenNotValidYetError struct { + At time.Time // The time at which the exp was evaluated. Includes leeway. + EarlyBy time.Duration // How long the token had been expired at time of evaluation + ErrorWrapper // Value for unwrapping +} + +func (e *TokenNotValidYetError) Error() string { + return fmt.Sprintf("token is not valid yet; wait %v", e.EarlyBy) +} + +type InvalidAudienceError struct { + Message string + ErrorWrapper +} + +func (e *InvalidAudienceError) Error() string { + if e.Message == "" { + return "token audience is invalid" } + return fmt.Sprintf("token audience is invalid: %v", e.Message) } -// ValidationError is the error from Parse if token is not valid -type ValidationError struct { - inner error // stores the error returned by external dependencies, i.e.: KeyFunc - Errors uint32 // bitfield. see ValidationError... constants - text string // errors that do not have a valid error just have text +type InvalidIssuerError struct { + Message string + ErrorWrapper } -// Validation error is an error type -func (e ValidationError) Error() string { - if e.inner != nil { - return e.Inner.Error() - } else if e.text != "" { - return e.text - } else { - return "token is invalid" +func (e *InvalidIssuerError) Error() string { + if e.Message == "" { + return "token issuer is invalid" } + return fmt.Sprintf("token issuer is invalid: %v", e.Message) } -// Unwrap implements xerrors.Wrapper -func (e ValidationError) Unwrap() error { - return e.inner +// InvalidClaimsError is a catchall type for claims errors that don't have their own type +type InvalidClaimsError struct { + Message string + ErrorWrapper +} + +func (e *InvalidClaimsError) Error() string { + if e.Message == "" { + return "token claim is invalid" + } + return fmt.Sprintf("token claim is invalid: %v", e.Message) } -// No errors -func (e *ValidationError) valid() bool { - return e.Errors == 0 +// SigningError is a catchall type for signing errors +type SigningError struct { + Message string + ErrorWrapper } -// ExpiredError allows the caller to know the delta between now and the expired time and the unvalidated claims. -// A client system may have a bug that doesn't refresh a token in time, or there may be clock skew so this information can help you understand. -type ExpiredError struct { - Now time.Time - ExpiredBy time.Duration +func (e *SigningError) Error() string { + if e.Message == "" { + return "error encountered during signing" + } + return fmt.Sprintf("error encountered during signing: %v", e.Message) } -func (e *ExpiredError) Error() string { - return "Token is expired" +type HashUnavailableError struct { + ErrorWrapper } -type MultiError []error +func (e *HashUnavailableError) Error() string { + return "the requested hash function is unavailable" +} diff --git a/errors_new.go b/errors_new.go new file mode 100644 index 00000000..edddc9bd --- /dev/null +++ b/errors_new.go @@ -0,0 +1,7 @@ +// Conditionally adds support for new errors behavior, only where it's the default +// +build go1.13 + +package jwt + +// TODO: add Format and FormatError methods to all error types +// per: https://go.googlesource.com/proposal/+/master/design/29934-error-values.md diff --git a/example_test.go b/example_test.go index 8d270e9d..13d528e3 100644 --- a/example_test.go +++ b/example_test.go @@ -6,6 +6,7 @@ import ( "github.com/dgrijalva/jwt-go/v4" "github.com/dgrijalva/jwt-go/v4/test" + "golang.org/x/xerrors" ) // Example (atypical) using the StandardClaims type by itself to parse a token. @@ -88,17 +89,19 @@ func ExampleParse_errorChecking() { return []byte("AllYourBase"), nil }) + var uErr *jwt.UnverfiableTokenError + var expErr *jwt.TokenExpiredError + var nbfErr *jwt.TokenNotValidYetError + + // Use xerrors.Is to see what kind of error(s) occurred if token.Valid { fmt.Println("You look nice today") - } else if ve, ok := err.(*jwt.ValidationError); ok { - if ve.Errors&jwt.ValidationErrorMalformed != 0 { - fmt.Println("That's not even a token") - } else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 { - // Token is either expired or not active yet - fmt.Println("Timing is everything") - } else { - fmt.Println("Couldn't handle this token:", err) - } + } else if xerrors.As(err, &uErr) { + fmt.Println("That's not even a token") + } else if xerrors.As(err, &expErr) { + fmt.Println("Timing is everything") + } else if xerrors.As(err, &nbfErr) { + fmt.Println("Timing is everything") } else { fmt.Println("Couldn't handle this token:", err) } diff --git a/go.mod b/go.mod index 99fa0edb..0248f634 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/dgrijalva/jwt-go/v4 go 1.12 + +require golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 diff --git a/go.sum b/go.sum index e69de29b..40960221 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/hmac.go b/hmac.go index 0dc9f85e..2c910125 100644 --- a/hmac.go +++ b/hmac.go @@ -51,7 +51,7 @@ func (m *SigningMethodHMAC) Verify(signingString, signature string, key interfac // Verify the key is the right type keyBytes, ok := key.([]byte) if !ok { - return ErrInvalidKeyType + return NewInvalidKeyTypeError("[]byte", key) } // Decode signature, for comparison @@ -81,16 +81,17 @@ func (m *SigningMethodHMAC) Verify(signingString, signature string, key interfac // Sign implements the Sign method from SigningMethod // Key must be []byte func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, error) { - if keyBytes, ok := key.([]byte); ok { - if !m.Hash.Available() { - return "", ErrHashUnavailable - } - - hasher := hmac.New(m.Hash.New, keyBytes) - hasher.Write([]byte(signingString)) + keyBytes, ok := key.([]byte) + if !ok { + return "", NewInvalidKeyTypeError("[]byte", key) + } - return EncodeSegment(hasher.Sum(nil)), nil + if !m.Hash.Available() { + return "", ErrHashUnavailable } - return "", ErrInvalidKeyType + hasher := hmac.New(m.Hash.New, keyBytes) + hasher.Write([]byte(signingString)) + + return EncodeSegment(hasher.Sum(nil)), nil } diff --git a/map_claims.go b/map_claims.go index f464def7..eda2b7fc 100644 --- a/map_claims.go +++ b/map_claims.go @@ -32,7 +32,7 @@ func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { // As well, if any of the above claims are not in the token, it will still // be considered a valid claim. func (m MapClaims) Valid(h *ValidationHelper) error { - vErr := new(ValidationError) + var vErr error if h == nil { h = DefaultValidationHelper @@ -44,8 +44,7 @@ func (m MapClaims) Valid(h *ValidationHelper) error { } if err = h.ValidateExpiresAt(exp); err != nil { - vErr.Inner = err - vErr.Errors |= ValidationErrorExpired + vErr = wrap(err, vErr) } nbf, err := m.LoadTimeValue("nbf") @@ -54,12 +53,7 @@ func (m MapClaims) Valid(h *ValidationHelper) error { } if err = h.ValidateNotBefore(nbf); err != nil { - vErr.Inner = err - vErr.Errors |= ValidationErrorNotValidYet - } - - if vErr.valid() { - return nil + vErr = wrap(err, vErr) } return vErr diff --git a/none.go b/none.go index 6c1ccba6..a5caed37 100644 --- a/none.go +++ b/none.go @@ -18,7 +18,7 @@ type unsafeNoneMagicConstant string func init() { SigningMethodNone = &signingMethodNone{} - NoneSignatureTypeDisallowedError = NewValidationError("'none' signature type is not allowed", ValidationErrorSignatureInvalid) + NoneSignatureTypeDisallowedError = &InvalidSignatureError{Message: "'none' signature type is not allowed"} RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod { return SigningMethodNone @@ -38,10 +38,7 @@ func (m *signingMethodNone) Verify(signingString, signature string, key interfac } // If signing method is none, signature must be an empty string if signature != "" { - return NewValidationError( - "'none' signing method with non-empty signature", - ValidationErrorSignatureInvalid, - ) + return &InvalidSignatureError{Message: "'none' signing method with non-empty signature"} } // Accept 'none' signing method. diff --git a/parser.go b/parser.go index 539e43cb..9830f652 100644 --- a/parser.go +++ b/parser.go @@ -52,7 +52,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf } if !signingMethodValid { // signing method is not in the listed set - return token, NewValidationError(fmt.Sprintf("signing method %v is invalid", alg), ValidationErrorSignatureInvalid) + return token, &UnverfiableTokenError{Message: fmt.Sprintf("signing method %v is invalid", alg)} } } @@ -60,42 +60,30 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf var key interface{} if keyFunc == nil { // keyFunc was not provided. short circuiting validation - return token, NewValidationError("no Keyfunc was provided.", ValidationErrorUnverifiable) + return token, &UnverfiableTokenError{Message: "no Keyfunc was provided."} } if key, err = keyFunc(token); err != nil { // keyFunc returned an error - if ve, ok := err.(*ValidationError); ok { - return token, ve - } - return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable} + return token, wrap(&UnverfiableTokenError{Message: "Keyfunc returned an error"}, err) } - vErr := &ValidationError{} + var vErr error // Perform validation token.Signature = parts[2] if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil { - vErr.Inner = err - vErr.Errors |= ValidationErrorSignatureInvalid + vErr = wrap(&InvalidSignatureError{}, err) } // Validate Claims - if !p.skipClaimsValidation && vErr.valid() { + if !p.skipClaimsValidation && vErr == nil { if err := token.Claims.Valid(p.ValidationHelper); err != nil { - - // If the Claims Valid returned an error, check if it is a validation error, - // If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set - if e, ok := err.(*ValidationError); !ok { - vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid} - } else { - vErr = e - } + vErr = wrap(err, vErr) } } - if vErr.valid() { + if vErr == nil { token.Valid = true - return token, nil } return token, vErr @@ -111,7 +99,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) { parts = strings.Split(tokenString, ".") if len(parts) != 3 { - return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed) + return nil, parts, &MalformedTokenError{Message: "token contains an invalid number of segments"} } token = &Token{Raw: tokenString} @@ -120,12 +108,12 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke var headerBytes []byte if headerBytes, err = DecodeSegment(parts[0]); err != nil { if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") { - return token, parts, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed) + return token, parts, &MalformedTokenError{Message: "tokenstring should not contain 'bearer '"} } - return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + return token, parts, wrap(&MalformedTokenError{Message: "failed to decode token header"}, err) } if err = json.Unmarshal(headerBytes, &token.Header); err != nil { - return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + return token, parts, wrap(&MalformedTokenError{Message: "failed to unmarshal token header"}, err) } // parse Claims @@ -133,7 +121,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke token.Claims = claims if claimBytes, err = DecodeSegment(parts[1]); err != nil { - return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + return token, parts, wrap(&MalformedTokenError{Message: "failed to decode token claims"}, err) } dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) if p.useJSONNumber { @@ -147,16 +135,16 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke } // Handle decode error if err != nil { - return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + return token, parts, wrap(&MalformedTokenError{Message: "failed to unmarshal token claims"}, err) } // Lookup signature method if method, ok := token.Header["alg"].(string); ok { if token.Method = GetSigningMethod(method); token.Method == nil { - return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable) + return token, parts, &UnverfiableTokenError{Message: "signing method (alg) is unavailable."} } } else { - return token, parts, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable) + return token, parts, &UnverfiableTokenError{Message: "signing method (alg) is unspecified."} } return token, parts, nil diff --git a/parser_option.go b/parser_option.go index 5fbe4633..6e6a66db 100644 --- a/parser_option.go +++ b/parser_option.go @@ -36,6 +36,4 @@ func WithLeeway(d time.Duration) ParserOption { return func(p *Parser) { p.ValidationHelper.leeway = d } -} - } } diff --git a/parser_test.go b/parser_test.go index a5337cc1..7a066bd1 100644 --- a/parser_test.go +++ b/parser_test.go @@ -10,6 +10,7 @@ import ( "github.com/dgrijalva/jwt-go/v4" "github.com/dgrijalva/jwt-go/v4/test" + "golang.org/x/xerrors" ) var keyFuncError error = fmt.Errorf("error loading key") @@ -32,7 +33,7 @@ var jwtTestData = []struct { keyfunc jwt.Keyfunc claims jwt.Claims valid bool - errors uint32 + errors []error parser *jwt.Parser }{ { @@ -41,7 +42,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar"}, true, - 0, + nil, nil, }, { @@ -50,7 +51,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, false, - jwt.ValidationErrorExpired, + []error{&jwt.TokenExpiredError{}}, nil, }, { @@ -59,7 +60,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, false, - jwt.ValidationErrorNotValidYet, + []error{&jwt.TokenNotValidYetError{}}, nil, }, { @@ -68,7 +69,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, false, - jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, + []error{&jwt.TokenExpiredError{}, &jwt.TokenNotValidYetError{}}, nil, }, { @@ -77,7 +78,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 50), "exp": float64(time.Now().Unix() - 50)}, true, - 0, + nil, jwt.NewParser(jwt.WithLeeway(100 * time.Second)), }, { @@ -86,7 +87,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorSignatureInvalid, + []error{&jwt.InvalidSignatureError{}}, nil, }, { @@ -95,7 +96,7 @@ var jwtTestData = []struct { nilKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorUnverifiable, + []error{&jwt.UnverfiableTokenError{}}, nil, }, { @@ -104,7 +105,7 @@ var jwtTestData = []struct { emptyKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorSignatureInvalid, + []error{&jwt.InvalidSignatureError{}}, nil, }, { @@ -113,7 +114,7 @@ var jwtTestData = []struct { errorKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorUnverifiable, + []error{&jwt.UnverfiableTokenError{}}, nil, }, { @@ -122,7 +123,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorSignatureInvalid, + []error{&jwt.InvalidSignatureError{}}, jwt.NewParser(jwt.WithValidMethods([]string{"HS256"})), }, { @@ -131,7 +132,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar"}, true, - 0, + nil, jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), }, { @@ -140,7 +141,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": json.Number("123.4")}, true, - 0, + nil, jwt.NewParser(jwt.WithJSONNumber()), }, { @@ -151,7 +152,7 @@ var jwtTestData = []struct { ExpiresAt: jwt.At(time.Now().Add(time.Second * 10).Truncate(time.Second)), }, true, - 0, + nil, jwt.NewParser(jwt.WithJSONNumber()), }, { @@ -160,7 +161,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))}, false, - jwt.ValidationErrorExpired, + []error{&jwt.TokenExpiredError{}}, jwt.NewParser(jwt.WithJSONNumber()), }, { @@ -169,7 +170,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))}, false, - jwt.ValidationErrorNotValidYet, + []error{&jwt.TokenNotValidYetError{}}, jwt.NewParser(jwt.WithJSONNumber()), }, { @@ -178,7 +179,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100)), "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))}, false, - jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, + []error{&jwt.TokenExpiredError{}, &jwt.TokenNotValidYetError{}}, jwt.NewParser(jwt.WithJSONNumber()), }, { @@ -187,7 +188,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))}, true, - 0, + nil, jwt.NewParser(jwt.WithJSONNumber(), jwt.WithoutClaimsValidation()), }, } @@ -234,19 +235,15 @@ func TestParser_Parse(t *testing.T) { t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name) } - if data.errors != 0 { + if data.errors != nil { if err == nil { t.Errorf("[%v] Expecting error. Didn't get one.", data.name) } else { - - ve := err.(*jwt.ValidationError) - // compare the bitfield part of the error - if e := ve.Errors; e != data.errors { - t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors) - } - - if err.Error() == keyFuncError.Error() && ve.Inner != keyFuncError { - t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, keyFuncError) + for _, expected := range data.errors { + var xxx error = expected + if !xerrors.As(err, &xxx) { + t.Errorf("[%v] Error is expected to match type %T but doesn't", data.name, expected) + } } } } diff --git a/rsa.go b/rsa.go index d1d8870e..72ba2d66 100644 --- a/rsa.go +++ b/rsa.go @@ -4,6 +4,7 @@ import ( "crypto" "crypto/rand" "crypto/rsa" + "fmt" ) // SigningMethodRSA implements the RSA family of signing methods signing methods @@ -65,10 +66,10 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface case crypto.Signer: pub := k.Public() if rsaKey, ok = pub.(*rsa.PublicKey); !ok { - return ErrInvalidKeyType + return &InvalidKeyError{Message: fmt.Sprintf("signer returned unexpected public key type: %T", pub)} } default: - return ErrInvalidKeyType + return NewInvalidKeyTypeError("*rsa.PublicKey or crypto.Signer", key) } // Create hasher @@ -89,12 +90,12 @@ func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, var ok bool if signer, ok = key.(crypto.Signer); !ok { - return "", ErrInvalidKey + return "", NewInvalidKeyTypeError("*rsa.PublicKey or crypto.Signer", key) } //sanity check that the signer is an rsa signer - if _, ok := signer.Public().(*rsa.PublicKey); !ok { - return "", ErrInvalidKeyType + if pub, ok := signer.Public().(*rsa.PublicKey); !ok { + return "", &InvalidKeyError{Message: fmt.Sprintf("signer returned unexpected public key type: %T", pub)} } // Create the hasher diff --git a/rsa_pss.go b/rsa_pss.go index 812b3ff1..aa3ba399 100644 --- a/rsa_pss.go +++ b/rsa_pss.go @@ -6,6 +6,7 @@ import ( "crypto" "crypto/rand" "crypto/rsa" + "fmt" ) // SigningMethodRSAPSS implements the RSAPSS family of signing methods @@ -88,10 +89,10 @@ func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interf case crypto.Signer: pub := k.Public() if rsaKey, ok = pub.(*rsa.PublicKey); !ok { - return ErrInvalidKeyType + return &InvalidKeyError{Message: fmt.Sprintf("signer returned unexpected public key type: %T", pub)} } default: - return ErrInvalidKeyType + return NewInvalidKeyTypeError("*rsa.PublicKey or crypto.Signer", key) } // Create hasher @@ -111,12 +112,12 @@ func (m *SigningMethodRSAPSS) Sign(signingString string, key interface{}) (strin var ok bool if signer, ok = key.(crypto.Signer); !ok { - return "", ErrInvalidKey + return "", NewInvalidKeyTypeError("*rsa.PrivateKey or crypto.Signer", key) } //sanity check that the signer is an rsa signer - if _, ok := signer.Public().(*rsa.PublicKey); !ok { - return "", ErrInvalidKeyType + if pub, ok := signer.Public().(*rsa.PublicKey); !ok { + return "", &InvalidKeyError{Message: fmt.Sprintf("signer returned unexpected public key type: %T", pub)} } // Create the hasher diff --git a/validation_helper.go b/validation_helper.go index a5363f7e..d5bb9d54 100644 --- a/validation_helper.go +++ b/validation_helper.go @@ -1,7 +1,6 @@ package jwt import ( - "fmt" "time" ) @@ -11,8 +10,8 @@ var DefaultValidationHelper = &ValidationHelper{} // ValidationHelper is built by the parser and passed // to Claims.Value to carry parse/validation options type ValidationHelper struct { - nowFunc func() time.Time // Override for time.Now. Mostly used for testing - leeway time.Duration // Leeway to provide when validating time values + nowFunc func() time.Time // Override for time.Now. Mostly used for testing + leeway time.Duration // Leeway to provide when validating time values } // NewValidationHelper creates a validation helper from a list of parser options @@ -54,7 +53,7 @@ func (h *ValidationHelper) ValidateExpiresAt(exp *Time) error { // Expiration has passed if h.After(exp.Time) { delta := h.now().Sub(exp.Time) - return &ExpiredError{h.now().Unix(), delta} + return &TokenExpiredError{At: h.now(), ExpiredBy: delta} } // Expiration has not passed @@ -71,7 +70,8 @@ func (h *ValidationHelper) ValidateNotBefore(nbf *Time) error { // Nbf hasn't been reached if h.Before(nbf.Time) { - return fmt.Errorf("token is not valid yet") + delta := nbf.Time.Sub(h.now()) + return &TokenNotValidYetError{At: h.now(), EarlyBy: delta} } // Nbf has been reached. valid. return nil From ed8d20b2df0aa489498387ea500e5f040bd81d92 Mon Sep 17 00:00:00 2001 From: Dave Grijalva Date: Mon, 6 Jan 2020 16:09:44 -0800 Subject: [PATCH 3/5] remove file no longer needed based on xerrors/go1.13 development --- errors_new.go | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 errors_new.go diff --git a/errors_new.go b/errors_new.go deleted file mode 100644 index edddc9bd..00000000 --- a/errors_new.go +++ /dev/null @@ -1,7 +0,0 @@ -// Conditionally adds support for new errors behavior, only where it's the default -// +build go1.13 - -package jwt - -// TODO: add Format and FormatError methods to all error types -// per: https://go.googlesource.com/proposal/+/master/design/29934-error-values.md From 122d9923da71a38f0ffc375ebf1bf0f31382b428 Mon Sep 17 00:00:00 2001 From: Dave Grijalva Date: Mon, 6 Jan 2020 16:15:34 -0800 Subject: [PATCH 4/5] cleaup error wrapper behavior and naming --- claims.go | 8 ++++---- errors.go | 10 +++++++++- map_claims.go | 8 ++++---- parser.go | 14 +++++++------- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/claims.go b/claims.go index 1b175fb8..a065328b 100644 --- a/claims.go +++ b/claims.go @@ -35,19 +35,19 @@ func (c StandardClaims) Valid(h *ValidationHelper) error { } if err := h.ValidateExpiresAt(c.ExpiresAt); err != nil { - vErr = wrap(err, vErr) + vErr = wrapError(err, vErr) } if err := h.ValidateNotBefore(c.NotBefore); err != nil { - vErr = wrap(err, vErr) + vErr = wrapError(err, vErr) } if err := h.ValidateAudience(c.Audience); err != nil { - vErr = wrap(err, vErr) + vErr = wrapError(err, vErr) } if err := h.ValidateIssuer(c.Issuer); err != nil { - vErr = wrap(err, vErr) + vErr = wrapError(err, vErr) } return vErr diff --git a/errors.go b/errors.go index aae5518c..4851f0a7 100644 --- a/errors.go +++ b/errors.go @@ -12,7 +12,15 @@ var ( // Embeds b within a, if a is a valid wrapper. returns a // If a is not a valid wrapper, b is dropped -func wrap(a, b error) error { +// If one of the errors is nil, the other is returned +func wrapError(a, b error) error { + if b == nil { + return a + } + if a == nil { + return b + } + type iErrorWrapper interface { Wrap(error) Unwrap() error diff --git a/map_claims.go b/map_claims.go index 340a5feb..d721c338 100644 --- a/map_claims.go +++ b/map_claims.go @@ -40,7 +40,7 @@ func (m MapClaims) Valid(h *ValidationHelper) error { } if err = h.ValidateExpiresAt(exp); err != nil { - vErr = wrap(err, vErr) + vErr = wrapError(err, vErr) } nbf, err := m.LoadTimeValue("nbf") @@ -49,14 +49,14 @@ func (m MapClaims) Valid(h *ValidationHelper) error { } if err = h.ValidateNotBefore(nbf); err != nil { - vErr = wrap(err, vErr) + vErr = wrapError(err, vErr) } // Try to parse the 'aud' claim if aud, err := ParseClaimStrings(m["aud"]); err == nil && aud != nil { // If it's present and well formed, validate if err = h.ValidateAudience(aud); err != nil { - vErr = wrap(err, vErr) + vErr = wrapError(err, vErr) } } else if err != nil { // If it's present and not well formed, return an error @@ -65,7 +65,7 @@ func (m MapClaims) Valid(h *ValidationHelper) error { iss, _ := m["iss"].(string) if err = h.ValidateIssuer(iss); err != nil { - vErr = wrap(err, vErr) + vErr = wrapError(err, vErr) } return vErr diff --git a/parser.go b/parser.go index 9018336d..1f1a9c09 100644 --- a/parser.go +++ b/parser.go @@ -65,7 +65,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf } if key, err = keyFunc(token); err != nil { // keyFunc returned an error - return token, wrap(&UnverfiableTokenError{Message: "Keyfunc returned an error"}, err) + return token, wrapError(&UnverfiableTokenError{Message: "Keyfunc returned an error"}, err) } var vErr error @@ -73,13 +73,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf // Perform validation token.Signature = parts[2] if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil { - vErr = wrap(&InvalidSignatureError{}, err) + vErr = wrapError(&InvalidSignatureError{}, err) } // Validate Claims if !p.skipClaimsValidation && vErr == nil { if err := token.Claims.Valid(p.ValidationHelper); err != nil { - vErr = wrap(err, vErr) + vErr = wrapError(err, vErr) } } @@ -117,10 +117,10 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") { return token, parts, &MalformedTokenError{Message: "tokenstring should not contain 'bearer '"} } - return token, parts, wrap(&MalformedTokenError{Message: "failed to decode token header"}, err) + return token, parts, wrapError(&MalformedTokenError{Message: "failed to decode token header"}, err) } if err = unmarshaller(CodingContext{HeaderFieldDescriptor, nil}, headerBytes, &token.Header); err != nil { - return token, parts, wrap(&MalformedTokenError{Message: "failed to unmarshal token header"}, err) + return token, parts, wrapError(&MalformedTokenError{Message: "failed to unmarshal token header"}, err) } // parse Claims @@ -128,7 +128,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke token.Claims = claims if claimBytes, err = DecodeSegment(parts[1]); err != nil { - return token, parts, wrap(&MalformedTokenError{Message: "failed to decode token claims"}, err) + return token, parts, wrapError(&MalformedTokenError{Message: "failed to decode token claims"}, err) } // JSON Decode. Special case for map type to avoid weird pointer behavior ctx := CodingContext{ClaimsFieldDescriptor, token.Header} @@ -139,7 +139,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke } // Handle decode error if err != nil { - return token, parts, wrap(&MalformedTokenError{Message: "failed to unmarshal token claims"}, err) + return token, parts, wrapError(&MalformedTokenError{Message: "failed to unmarshal token claims"}, err) } // Lookup signature method From d9519f14cd751661a90027a314301d64e6de1081 Mon Sep 17 00:00:00 2001 From: Dave Grijalva Date: Mon, 6 Jan 2020 16:17:35 -0800 Subject: [PATCH 5/5] we don't need both methods to wrap --- errors.go | 1 - 1 file changed, 1 deletion(-) diff --git a/errors.go b/errors.go index 4851f0a7..3a925657 100644 --- a/errors.go +++ b/errors.go @@ -23,7 +23,6 @@ func wrapError(a, b error) error { type iErrorWrapper interface { Wrap(error) - Unwrap() error } if w, ok := a.(iErrorWrapper); ok { w.Wrap(b)