From 0c7de7548daa3dacb801dfdc8fea5a63e823220c Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Sun, 19 Feb 2023 17:45:30 +0100 Subject: [PATCH 1/9] More consistent way of handling composite errors --- errors.go | 9 ++++++--- go.mod | 2 +- none.go | 9 ++++----- parser.go | 55 ++++++++++++++++++-------------------------------- parser_test.go | 6 +++--- validator.go | 32 +++++++++++++---------------- 6 files changed, 48 insertions(+), 65 deletions(-) diff --git a/errors.go b/errors.go index a63827fa..3293a0cf 100644 --- a/errors.go +++ b/errors.go @@ -10,9 +10,10 @@ var ( ErrInvalidKeyType = errors.New("key is of invalid type") ErrHashUnavailable = errors.New("the requested hash function is unavailable") - ErrTokenMalformed = errors.New("token is malformed") - ErrTokenUnverifiable = errors.New("token is unverifiable") - ErrTokenSignatureInvalid = errors.New("token signature is invalid") + ErrTokenMalformed = errors.New("token is malformed") + ErrTokenUnverifiable = errors.New("token is unverifiable") + ErrTokenRequiredClaimMissing = errors.New("a required claim is missing") + ErrTokenSignatureInvalid = errors.New("token signature is invalid") ErrTokenInvalidAudience = errors.New("token has invalid audience") ErrTokenExpired = errors.New("token is expired") @@ -43,6 +44,7 @@ const ( ValidationErrorClaimsInvalid // Generic claims validation error ) +/* // NewValidationError is a helper for constructing a ValidationError with a string error message func NewValidationError(errorText string, errorFlags uint32) *ValidationError { return &ValidationError{ @@ -119,3 +121,4 @@ func (e *ValidationError) Is(err error) bool { return false } +*/ diff --git a/go.mod b/go.mod index 3b8690b0..aa4a2c34 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/golang-jwt/jwt/v5 -go 1.16 +go 1.20 diff --git a/none.go b/none.go index f19835d2..75d2d7e3 100644 --- a/none.go +++ b/none.go @@ -1,5 +1,7 @@ package jwt +import "fmt" + // SigningMethodNone implements the none signing method. This is required by the spec // but you probably should never use it. var SigningMethodNone *signingMethodNone @@ -13,7 +15,7 @@ type unsafeNoneMagicConstant string func init() { SigningMethodNone = &signingMethodNone{} - NoneSignatureTypeDisallowedError = NewValidationError("'none' signature type is not allowed", ValidationErrorSignatureInvalid) + NoneSignatureTypeDisallowedError = fmt.Errorf("%w: 'none' signature type is not allowed", ErrTokenUnverifiable) RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod { return SigningMethodNone @@ -33,10 +35,7 @@ func (m *signingMethodNone) Verify(signingString, signature string, key interfac } // If signing method is none, signature must be an empty string if signature != "" { - return NewValidationError( - "'none' signing method with non-empty signature", - ValidationErrorSignatureInvalid, - ) + return fmt.Errorf("%w: 'none' signing method with non-empty signature", ErrTokenUnverifiable) } // Accept 'none' signing method. diff --git a/parser.go b/parser.go index b9c3ffb5..a6aeefa4 100644 --- a/parser.go +++ b/parser.go @@ -65,7 +65,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf } if !signingMethodValid { // signing method is not in the listed set - return token, NewValidationError(fmt.Sprintf("signing method %v is invalid", alg), ValidationErrorSignatureInvalid) + return token, fmt.Errorf("%w: signing method %v is invalid", ErrTokenSignatureInvalid, err) } } @@ -73,17 +73,17 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf var key interface{} if keyFunc == nil { // keyFunc was not provided. short circuiting validation - return token, NewValidationError("no Keyfunc was provided.", ValidationErrorUnverifiable) + return token, fmt.Errorf("%w: no keyfunc was provided", ErrTokenUnverifiable) } if key, err = keyFunc(token); err != nil { - // keyFunc returned an error - if ve, ok := err.(*ValidationError); ok { - return token, ve - } - return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable} + return token, fmt.Errorf("%w: error while executing keyfunc: %w", ErrTokenUnverifiable, err) } - vErr := &ValidationError{} + // Perform signature validation + token.Signature = parts[2] + if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil { + return token, fmt.Errorf("%w: %w", ErrTokenSignatureInvalid, err) + } // Validate Claims if !p.skipClaimsValidation { @@ -93,29 +93,14 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf } if err := p.validator.Validate(claims); err != nil { - // If the Claims Valid returned an error, check if it is a validation error, - // If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set - if e, ok := err.(*ValidationError); !ok { - vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid} - } else { - vErr = e - } + return token, err } } - // Perform validation - token.Signature = parts[2] - if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil { - vErr.Inner = err - vErr.Errors |= ValidationErrorSignatureInvalid - } - - if vErr.valid() { - token.Valid = true - return token, nil - } + // No errors so far, token is valid. + token.Valid = true - return token, vErr + return token, nil } // ParseUnverified parses the token but doesn't validate the signature. @@ -127,7 +112,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) { parts = strings.Split(tokenString, ".") if len(parts) != 3 { - return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed) + return nil, parts, fmt.Errorf("%w: token contains an invalid number of segments", ErrTokenMalformed) } token = &Token{Raw: tokenString} @@ -136,12 +121,12 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke var headerBytes []byte if headerBytes, err = DecodeSegment(parts[0]); err != nil { if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") { - return token, parts, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed) + return token, parts, fmt.Errorf("%w: tokenstring should not contain 'bearer '", ErrTokenMalformed) } - return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + return token, parts, fmt.Errorf("%w: %w", ErrTokenMalformed, err) } if err = json.Unmarshal(headerBytes, &token.Header); err != nil { - return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + return token, parts, fmt.Errorf("%w: %w", ErrTokenMalformed, err) } // parse Claims @@ -149,7 +134,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke token.Claims = claims if claimBytes, err = DecodeSegment(parts[1]); err != nil { - return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + return token, parts, fmt.Errorf("%w: %w", ErrTokenMalformed, err) } dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) if p.useJSONNumber { @@ -163,16 +148,16 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke } // Handle decode error if err != nil { - return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + return token, parts, fmt.Errorf("%w: %w", ErrTokenMalformed, err) } // Lookup signature method if method, ok := token.Header["alg"].(string); ok { if token.Method = GetSigningMethod(method); token.Method == nil { - return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable) + return token, parts, fmt.Errorf("%w: signing method (alg) is unavailable", ErrTokenUnverifiable) } } else { - return token, parts, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable) + return token, parts, fmt.Errorf("%w: signing method (alg) is unspecified", ErrTokenUnverifiable) } return token, parts, nil diff --git a/parser_test.go b/parser_test.go index 306f8b50..bf1c03a8 100644 --- a/parser_test.go +++ b/parser_test.go @@ -381,7 +381,7 @@ func TestParser_Parse(t *testing.T) { // Parse the token var token *jwt.Token - var ve *jwt.ValidationError + //var ve *jwt.ValidationError var err error var parser = data.parser if parser == nil { @@ -417,7 +417,7 @@ func TestParser_Parse(t *testing.T) { t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name) } - if data.errors != 0 { + /*if data.errors != 0 { if err == nil { t.Errorf("[%v] Expecting error. Didn't get one.", data.name) } else { @@ -432,7 +432,7 @@ func TestParser_Parse(t *testing.T) { } } } - } + }*/ if data.err != nil { if err == nil { diff --git a/validator.go b/validator.go index 51469659..aed5526b 100644 --- a/validator.go +++ b/validator.go @@ -2,6 +2,7 @@ package jwt import ( "crypto/subtle" + "errors" "time" ) @@ -48,8 +49,10 @@ func newValidator(opts ...ParserOption) *validator { // Validate validates the given claims. It will also perform any custom // validation if claims implements the CustomValidator interface. func (v *validator) Validate(claims Claims) error { - var now time.Time - vErr := new(ValidationError) + var ( + now time.Time + errs []error = make([]error, 0) + ) // Check, if we have a time func if v.timeFunc != nil { @@ -61,39 +64,33 @@ func (v *validator) Validate(claims Claims) error { // We always need to check the expiration time, but usage of the claim // itself is OPTIONAL if !v.VerifyExpiresAt(claims, now, false) { - vErr.Inner = ErrTokenExpired - vErr.Errors |= ValidationErrorExpired + errs = append(errs, ErrTokenExpired) } // We always need to check not-before, but usage of the claim itself is // OPTIONAL if !v.VerifyNotBefore(claims, now, false) { - vErr.Inner = ErrTokenNotValidYet - vErr.Errors |= ValidationErrorNotValidYet + errs = append(errs, ErrTokenNotValidYet) } // Check issued-at if the option is enabled if v.verifyIat && !v.VerifyIssuedAt(claims, now, false) { - vErr.Inner = ErrTokenUsedBeforeIssued - vErr.Errors |= ValidationErrorIssuedAt + errs = append(errs, ErrTokenUsedBeforeIssued) } // If we have an expected audience, we also require the audience claim if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, true) { - vErr.Inner = ErrTokenInvalidAudience - vErr.Errors |= ValidationErrorAudience + errs = append(errs, ErrTokenInvalidAudience) } // If we have an expected issuer, we also require the issuer claim if v.expectedIss != "" && !v.VerifyIssuer(claims, v.expectedIss, true) { - vErr.Inner = ErrTokenInvalidIssuer - vErr.Errors |= ValidationErrorIssuer + errs = append(errs, ErrTokenInvalidIssuer) } // If we have an expected subject, we also require the subject claim if v.expectedSub != "" && !v.VerifySubject(claims, v.expectedSub, true) { - vErr.Inner = ErrTokenInvalidSubject - vErr.Errors |= ValidationErrorSubject + errs = append(errs, ErrTokenInvalidSubject) } // Finally, we want to give the claim itself some possibility to do some @@ -103,16 +100,15 @@ func (v *validator) Validate(claims Claims) error { }) if ok { if err := cvt.Validate(); err != nil { - vErr.Inner = err - vErr.Errors |= ValidationErrorClaimsInvalid + errs = append(errs, err) } } - if vErr.valid() { + if len(errs) == 0 { return nil } - return vErr + return errors.Join(errs...) } // VerifyExpiresAt compares the exp claim in claims against cmp. This function From 88bd06420ada1442fbbdc2139ec2980ae460dd72 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Sun, 19 Feb 2023 17:48:28 +0100 Subject: [PATCH 2/9] Running in Go 1.20 only --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index cea5f110..f7d85f4b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,7 +25,7 @@ jobs: strategy: fail-fast: false matrix: - go: [1.17, 1.18, 1.19] + go: ["1.20"] steps: - name: Checkout uses: actions/checkout@v3 From a9047d57ba5d43760cbda3794ad87ca9723ced1a Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Mon, 20 Feb 2023 16:50:42 +0100 Subject: [PATCH 3/9] All the error goodness even without Go 1.20 --- errors.go | 25 ++++++++++++++ errors_go1_20.go | 40 ++++++++++++++++++++++ errors_go_other.go | 71 +++++++++++++++++++++++++++++++++++++++ errors_test.go | 83 ++++++++++++++++++++++++++++++++++++++++++++++ parser.go | 24 +++++++------- validator.go | 3 +- 6 files changed, 232 insertions(+), 14 deletions(-) create mode 100644 errors_go1_20.go create mode 100644 errors_go_other.go create mode 100644 errors_test.go diff --git a/errors.go b/errors.go index 3293a0cf..64577387 100644 --- a/errors.go +++ b/errors.go @@ -2,6 +2,7 @@ package jwt import ( "errors" + "strings" ) // Error constants @@ -44,6 +45,30 @@ const ( ValidationErrorClaimsInvalid // Generic claims validation error ) +// joinedError is an error type that works similar to what [errors.Join] +// produces, with the exception that it has a nice error string; mainly its +// error messages are concatenated using a comma, rather than a newline. +type joinedError struct { + errs []error +} + +func (je joinedError) Error() string { + msg := []string{} + for _, err := range je.errs { + msg = append(msg, err.Error()) + } + + return strings.Join(msg, ", ") +} + +// joinErrors joins together multiple errors. Useful for scenarios where +// multiple errors next to each other occur, e.g., in claims validation. +func joinErrors(errs []error) error { + return &joinedError{ + errs: errs, + } +} + /* // NewValidationError is a helper for constructing a ValidationError with a string error message func NewValidationError(errorText string, errorFlags uint32) *ValidationError { diff --git a/errors_go1_20.go b/errors_go1_20.go new file mode 100644 index 00000000..ba542181 --- /dev/null +++ b/errors_go1_20.go @@ -0,0 +1,40 @@ +//go:build go1.20 +// +build go1.20 + +package jwt + +import ( + "fmt" +) + +// Unwrap implements the multiple error unwrapping for this error type, which is +// possible in Go 1.20. +func (je joinedError) Unwrap() []error { + return je.errs +} + +// newError creates a new error message with a detailed error message. The +// message will be prefixed with the contents of the supplied error type. +// Additionally, more errors, that provide more context can be supplied which +// will be appended to the message. This makes use of Go 1.20's possibility to +// include more than one %w formatting directive in [fmt.Errorf]. +// +// For example, +// +// newError("no keyfunc was provided", ErrTokenUnverifiable) +// +// will produce the error string +// +// "token is unverifiable: no keyfunc was provided" +func newError(message string, err error, more ...error) error { + format := "%w: %s" + args := []any{err, message} + + for _, e := range more { + format += ": %w" + args = append(args, e) + } + + err = fmt.Errorf(format, args...) + return err +} diff --git a/errors_go_other.go b/errors_go_other.go new file mode 100644 index 00000000..d37c3de8 --- /dev/null +++ b/errors_go_other.go @@ -0,0 +1,71 @@ +//go:build !go1.20 +// +build !go1.20 + +package jwt + +import ( + "errors" + "fmt" +) + +// Is implements checking for multiple errors using [errors.Is], since multiple +// error unwrapping is not possible in versions less than Go 1.20. +func (je joinedError) Is(err error) bool { + for _, e := range je.errs { + if errors.Is(e, err) { + return true + } + } + + return false +} + +// wrappedErrors is a workaround for wrapping multiple errors in environments +// where Go 1.20 is not available. It basically uses the already implemented +// functionatlity of joinedError to handle multiple errors with supplies a +// custom error message that is identical to the one we produce in Go 1.20 using +// multiple %w directives. +type wrappedErrors struct { + msg string + joinedError +} + +// Error returns the stored error string +func (we wrappedErrors) Error() string { + return we.msg +} + +// newError creates a new error message with a detailed error message. The +// message will be prefixed with the contents of the supplied error type. +// Additionally, more errors, that provide more context can be supplied which +// will be appended to the message. Since we cannot use of Go 1.20's possibility +// to include more than one %w formatting directive in [fmt.Errorf], we have to +// emulate that. +// +// For example, +// +// newError("no keyfunc was provided", ErrTokenUnverifiable) +// +// will produce the error string +// +// "token is unverifiable: no keyfunc was provided" +func newError(message string, err error, more ...error) error { + // We cannot wrap multiple errors here with %w, so we have to be a little + // bit creative. Basically, we are using %s instead of %w to produce the + // same error message and then throw the result into a custom error struct. + format := "%s: %s" + args := []any{err, message} + errs := []error{err} + + for _, e := range more { + format += ": %s" + args = append(args, e) + errs = append(errs, e) + } + + err = &wrappedErrors{ + msg: fmt.Sprintf(format, args...), + joinedError: joinedError{errs: errs}, + } + return err +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 00000000..b534ff09 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,83 @@ +package jwt + +import ( + "errors" + "io" + "testing" +) + +func Test_joinErrors(t *testing.T) { + type args struct { + errs []error + } + tests := []struct { + name string + args args + wantErrors []error + wantMessage string + }{ + { + name: "multiple errors", + args: args{ + errs: []error{ErrTokenNotValidYet, ErrTokenExpired}, + }, + wantErrors: []error{ErrTokenNotValidYet, ErrTokenExpired}, + wantMessage: "token is not valid yet, token is expired", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := joinErrors(tt.args.errs) + for _, wantErr := range tt.wantErrors { + if !errors.Is(err, wantErr) { + t.Errorf("joinErrors() error = %v, does not contain %v", err, wantErr) + } + } + + if err.Error() != tt.wantMessage { + t.Errorf("joinErrors() error.Error() = %v, wantMessage %v", err, tt.wantMessage) + } + }) + } +} + +func Test_newError(t *testing.T) { + type args struct { + message string + err error + more []error + } + tests := []struct { + name string + args args + wantErrors []error + wantMessage string + }{ + { + name: "single error", + args: args{message: "something is wrong", err: ErrTokenMalformed}, + wantMessage: "token is malformed: something is wrong", + wantErrors: []error{ErrTokenMalformed}, + }, + { + name: "two errors", + args: args{message: "something is wrong", err: ErrTokenMalformed, more: []error{io.ErrUnexpectedEOF}}, + wantMessage: "token is malformed: something is wrong: unexpected EOF", + wantErrors: []error{ErrTokenMalformed}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := newError(tt.args.message, tt.args.err, tt.args.more...) + for _, wantErr := range tt.wantErrors { + if !errors.Is(err, wantErr) { + t.Errorf("newError() error = %v, does not contain %v", err, wantErr) + } + } + + if err.Error() != tt.wantMessage { + t.Errorf("newError() error.Error() = %v, wantMessage %v", err, tt.wantMessage) + } + }) + } +} diff --git a/parser.go b/parser.go index a6aeefa4..f02fbc90 100644 --- a/parser.go +++ b/parser.go @@ -65,7 +65,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf } if !signingMethodValid { // signing method is not in the listed set - return token, fmt.Errorf("%w: signing method %v is invalid", ErrTokenSignatureInvalid, err) + return token, newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid) } } @@ -73,16 +73,16 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf var key interface{} if keyFunc == nil { // keyFunc was not provided. short circuiting validation - return token, fmt.Errorf("%w: no keyfunc was provided", ErrTokenUnverifiable) + return token, newError("no keyfunc was provided", ErrTokenUnverifiable) } if key, err = keyFunc(token); err != nil { - return token, fmt.Errorf("%w: error while executing keyfunc: %w", ErrTokenUnverifiable, err) + return token, newError("error while executing keyfunc", ErrTokenUnverifiable, err) } // Perform signature validation token.Signature = parts[2] if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil { - return token, fmt.Errorf("%w: %w", ErrTokenSignatureInvalid, err) + return token, newError("could not verify", ErrTokenSignatureInvalid, err) } // Validate Claims @@ -112,7 +112,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) { parts = strings.Split(tokenString, ".") if len(parts) != 3 { - return nil, parts, fmt.Errorf("%w: token contains an invalid number of segments", ErrTokenMalformed) + return nil, parts, newError("token contains an invalid number of segments", ErrTokenMalformed) } token = &Token{Raw: tokenString} @@ -121,12 +121,12 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke var headerBytes []byte if headerBytes, err = DecodeSegment(parts[0]); err != nil { if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") { - return token, parts, fmt.Errorf("%w: tokenstring should not contain 'bearer '", ErrTokenMalformed) + return token, parts, newError("tokenstring should not contain 'bearer '", ErrTokenMalformed) } - return token, parts, fmt.Errorf("%w: %w", ErrTokenMalformed, err) + return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err) } if err = json.Unmarshal(headerBytes, &token.Header); err != nil { - return token, parts, fmt.Errorf("%w: %w", ErrTokenMalformed, err) + return token, parts, newError("could not unmarshal header", ErrTokenMalformed, err) } // parse Claims @@ -134,7 +134,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke token.Claims = claims if claimBytes, err = DecodeSegment(parts[1]); err != nil { - return token, parts, fmt.Errorf("%w: %w", ErrTokenMalformed, err) + return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err) } dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) if p.useJSONNumber { @@ -148,16 +148,16 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke } // Handle decode error if err != nil { - return token, parts, fmt.Errorf("%w: %w", ErrTokenMalformed, err) + return token, parts, newError("could JSON decode claim", ErrTokenMalformed, err) } // Lookup signature method if method, ok := token.Header["alg"].(string); ok { if token.Method = GetSigningMethod(method); token.Method == nil { - return token, parts, fmt.Errorf("%w: signing method (alg) is unavailable", ErrTokenUnverifiable) + return token, parts, newError("signing method (alg) is unavailable", ErrTokenUnverifiable) } } else { - return token, parts, fmt.Errorf("%w: signing method (alg) is unspecified", ErrTokenUnverifiable) + return token, parts, newError("signing method (alg) is unspecified", ErrTokenUnverifiable) } return token, parts, nil diff --git a/validator.go b/validator.go index aed5526b..8aca744d 100644 --- a/validator.go +++ b/validator.go @@ -2,7 +2,6 @@ package jwt import ( "crypto/subtle" - "errors" "time" ) @@ -108,7 +107,7 @@ func (v *validator) Validate(claims Claims) error { return nil } - return errors.Join(errs...) + return joinErrors(errs) } // VerifyExpiresAt compares the exp claim in claims against cmp. This function From 3efe89df6aa1ed4055a82291dc984436f3d08c0a Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Mon, 20 Feb 2023 16:52:11 +0100 Subject: [PATCH 4/9] Building again on Go 1.28 --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f7d85f4b..9840fffa 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,7 +25,7 @@ jobs: strategy: fail-fast: false matrix: - go: ["1.20"] + go: ["1.18", "1.19", "1.20"] steps: - name: Checkout uses: actions/checkout@v3 From e2373049fca589e782a2fcee3e72f700ad931092 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Mon, 20 Feb 2023 16:56:21 +0100 Subject: [PATCH 5/9] Removed old error bitfields --- errors.go | 96 -------------------------------------------------- parser_test.go | 48 ++----------------------- 2 files changed, 3 insertions(+), 141 deletions(-) diff --git a/errors.go b/errors.go index 64577387..35f66ae9 100644 --- a/errors.go +++ b/errors.go @@ -28,23 +28,6 @@ var ( ErrInvalidType = errors.New("invalid type for claim") ) -// The errors that might occur when parsing and validating a token -const ( - ValidationErrorMalformed uint32 = 1 << iota // Token is malformed - ValidationErrorUnverifiable // Token could not be verified because of signing problems - ValidationErrorSignatureInvalid // Signature validation failed - - // Registered Claim validation errors - ValidationErrorAudience // AUD validation failed - ValidationErrorExpired // EXP validation failed - ValidationErrorIssuedAt // IAT validation failed - ValidationErrorIssuer // ISS validation failed - ValidationErrorSubject // SUB validation failed - ValidationErrorNotValidYet // NBF validation failed - ValidationErrorId // JTI validation failed - ValidationErrorClaimsInvalid // Generic claims validation error -) - // joinedError is an error type that works similar to what [errors.Join] // produces, with the exception that it has a nice error string; mainly its // error messages are concatenated using a comma, rather than a newline. @@ -68,82 +51,3 @@ func joinErrors(errs []error) error { errs: errs, } } - -/* -// NewValidationError is a helper for constructing a ValidationError with a string error message -func NewValidationError(errorText string, errorFlags uint32) *ValidationError { - return &ValidationError{ - text: errorText, - Errors: errorFlags, - } -} - -// ValidationError represents an error from Parse if token is not valid -type ValidationError struct { - // Inner stores the error returned by external dependencies, e.g.: KeyFunc - Inner error - // Errors is a bit-field. See ValidationError... constants - Errors uint32 - // Text can be used for errors that do not have a valid error just have text - text string -} - -// Error is the implementation of the err interface. -func (e ValidationError) Error() string { - if e.Inner != nil { - return e.Inner.Error() - } else if e.text != "" { - return e.text - } else { - return "token is invalid" - } -} - -// Unwrap gives errors.Is and errors.As access to the inner error. -func (e *ValidationError) Unwrap() error { - return e.Inner -} - -// No errors -func (e *ValidationError) valid() bool { - return e.Errors == 0 -} - -// Is checks if this ValidationError is of the supplied error. We are first -// checking for the exact error message by comparing the inner error message. If -// that fails, we compare using the error flags. This way we can use custom -// error messages (mainly for backwards compatibility) and still leverage -// errors.Is using the global error variables. -func (e *ValidationError) Is(err error) bool { - // Check, if our inner error is a direct match - if errors.Is(errors.Unwrap(e), err) { - return true - } - - // Otherwise, we need to match using our error flags - switch err { - case ErrTokenMalformed: - return e.Errors&ValidationErrorMalformed != 0 - case ErrTokenUnverifiable: - return e.Errors&ValidationErrorUnverifiable != 0 - case ErrTokenSignatureInvalid: - return e.Errors&ValidationErrorSignatureInvalid != 0 - case ErrTokenInvalidAudience: - return e.Errors&ValidationErrorAudience != 0 - case ErrTokenExpired: - return e.Errors&ValidationErrorExpired != 0 - case ErrTokenUsedBeforeIssued: - return e.Errors&ValidationErrorIssuedAt != 0 - case ErrTokenInvalidIssuer: - return e.Errors&ValidationErrorIssuer != 0 - case ErrTokenNotValidYet: - return e.Errors&ValidationErrorNotValidYet != 0 - case ErrTokenInvalidId: - return e.Errors&ValidationErrorId != 0 - case ErrTokenInvalidClaims: - return e.Errors&ValidationErrorClaimsInvalid != 0 - } - - return false -} -*/ diff --git a/parser_test.go b/parser_test.go index bf1c03a8..78b924f0 100644 --- a/parser_test.go +++ b/parser_test.go @@ -51,7 +51,6 @@ var jwtTestData = []struct { keyfunc jwt.Keyfunc claims jwt.Claims valid bool - errors uint32 err []error parser *jwt.Parser signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose @@ -84,7 +83,6 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar"}, true, - 0, nil, nil, jwt.SigningMethodRS256, @@ -95,7 +93,6 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, false, - jwt.ValidationErrorExpired, []error{jwt.ErrTokenExpired}, nil, jwt.SigningMethodRS256, @@ -106,7 +103,6 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, false, - jwt.ValidationErrorNotValidYet, []error{jwt.ErrTokenNotValidYet}, nil, jwt.SigningMethodRS256, @@ -117,8 +113,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, false, - jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, - []error{jwt.ErrTokenNotValidYet}, + []error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired}, nil, jwt.SigningMethodRS256, }, @@ -128,7 +123,6 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorSignatureInvalid, []error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification}, nil, jwt.SigningMethodRS256, @@ -139,7 +133,6 @@ var jwtTestData = []struct { nilKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorUnverifiable, []error{jwt.ErrTokenUnverifiable}, nil, jwt.SigningMethodRS256, @@ -150,7 +143,6 @@ var jwtTestData = []struct { emptyKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorSignatureInvalid, []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, @@ -161,7 +153,6 @@ var jwtTestData = []struct { errorKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorUnverifiable, []error{jwt.ErrTokenUnverifiable, errKeyFuncError}, nil, jwt.SigningMethodRS256, @@ -172,7 +163,6 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorSignatureInvalid, []error{jwt.ErrTokenSignatureInvalid}, jwt.NewParser(jwt.WithValidMethods([]string{"HS256"})), jwt.SigningMethodRS256, @@ -183,7 +173,6 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar"}, true, - 0, nil, jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), jwt.SigningMethodRS256, @@ -194,7 +183,6 @@ var jwtTestData = []struct { ecdsaKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorSignatureInvalid, []error{jwt.ErrTokenSignatureInvalid}, jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), jwt.SigningMethodES256, @@ -205,7 +193,6 @@ var jwtTestData = []struct { ecdsaKeyFunc, jwt.MapClaims{"foo": "bar"}, true, - 0, nil, jwt.NewParser(jwt.WithValidMethods([]string{"HS256", "ES256"})), jwt.SigningMethodES256, @@ -216,7 +203,6 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": json.Number("123.4")}, true, - 0, nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, @@ -227,7 +213,6 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))}, false, - jwt.ValidationErrorExpired, []error{jwt.ErrTokenExpired}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, @@ -238,7 +223,6 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))}, false, - jwt.ValidationErrorNotValidYet, []error{jwt.ErrTokenNotValidYet}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, @@ -249,8 +233,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100)), "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))}, false, - jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, - []error{jwt.ErrTokenNotValidYet}, + []error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, }, @@ -260,7 +243,6 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))}, true, - 0, nil, jwt.NewParser(jwt.WithJSONNumber(), jwt.WithoutClaimsValidation()), jwt.SigningMethodRS256, @@ -273,7 +255,6 @@ var jwtTestData = []struct { ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Second * 10)), }, true, - 0, nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, @@ -286,7 +267,6 @@ var jwtTestData = []struct { Audience: jwt.ClaimStrings{"test"}, }, true, - 0, nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, @@ -299,7 +279,6 @@ var jwtTestData = []struct { Audience: jwt.ClaimStrings{"test", "test"}, }, true, - 0, nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, @@ -312,7 +291,6 @@ var jwtTestData = []struct { Audience: nil, // because of the unmarshal error, this will be empty }, false, - jwt.ValidationErrorMalformed, []error{jwt.ErrTokenMalformed}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, @@ -325,7 +303,6 @@ var jwtTestData = []struct { Audience: nil, // because of the unmarshal error, this will be empty }, false, - jwt.ValidationErrorMalformed, []error{jwt.ErrTokenMalformed}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, @@ -336,7 +313,6 @@ var jwtTestData = []struct { defaultKeyFunc, &jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))}, false, - jwt.ValidationErrorNotValidYet, []error{jwt.ErrTokenNotValidYet}, jwt.NewParser(jwt.WithLeeway(time.Minute)), jwt.SigningMethodRS256, @@ -347,7 +323,6 @@ var jwtTestData = []struct { defaultKeyFunc, &jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))}, true, - 0, nil, jwt.NewParser(jwt.WithLeeway(2 * time.Minute)), jwt.SigningMethodRS256, @@ -417,23 +392,6 @@ func TestParser_Parse(t *testing.T) { t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name) } - /*if data.errors != 0 { - if err == nil { - t.Errorf("[%v] Expecting error. Didn't get one.", data.name) - } else { - if errors.As(err, &ve) { - // compare the bitfield part of the error - if e := ve.Errors; e != data.errors { - t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors) - } - - if err.Error() == errKeyFuncError.Error() && ve.Inner != errKeyFuncError { - t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, errKeyFuncError) - } - } - } - }*/ - if data.err != nil { if err == nil { t.Errorf("[%v] Expecting error(s). Didn't get one.", data.name) @@ -467,7 +425,7 @@ func TestParser_ParseUnverified(t *testing.T) { // Iterate over test data set and run tests for _, data := range jwtTestData { // Skip test data, that intentionally contains malformed tokens, as they would lead to an error - if data.errors&jwt.ValidationErrorMalformed != 0 { + if len(data.err) == 1 && errors.Is(data.err[0], jwt.ErrTokenMalformed) { continue } From 84a369b282223a5486d8938dfd58fb21d3f6a80e Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Mon, 20 Feb 2023 22:22:19 +0100 Subject: [PATCH 6/9] Verify functions now return errors instead of bool --- map_claims.go | 7 +- none.go | 6 +- parser_test.go | 1 - validator.go | 202 ++++++++++++++++++++++++++++++------------------- 4 files changed, 132 insertions(+), 84 deletions(-) diff --git a/map_claims.go b/map_claims.go index 014acb94..26ac3fbe 100644 --- a/map_claims.go +++ b/map_claims.go @@ -2,6 +2,7 @@ package jwt import ( "encoding/json" + "fmt" ) // MapClaims is a claims type that uses the map[string]interface{} for JSON decoding. @@ -60,7 +61,7 @@ func (m MapClaims) parseNumericDate(key string) (*NumericDate, error) { return newNumericDateFromSeconds(v), nil } - return nil, ErrInvalidType + return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType) } // parseClaimsString tries to parse a key in the map claims type as a @@ -76,7 +77,7 @@ func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) { for _, a := range v { vs, ok := a.(string) if !ok { - return nil, ErrInvalidType + return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType) } cs = append(cs, vs) } @@ -101,7 +102,7 @@ func (m MapClaims) parseString(key string) (string, error) { iss, ok = raw.(string) if !ok { - return "", ErrInvalidType + return "", newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType) } return iss, nil diff --git a/none.go b/none.go index 75d2d7e3..a16495ac 100644 --- a/none.go +++ b/none.go @@ -1,7 +1,5 @@ package jwt -import "fmt" - // SigningMethodNone implements the none signing method. This is required by the spec // but you probably should never use it. var SigningMethodNone *signingMethodNone @@ -15,7 +13,7 @@ type unsafeNoneMagicConstant string func init() { SigningMethodNone = &signingMethodNone{} - NoneSignatureTypeDisallowedError = fmt.Errorf("%w: 'none' signature type is not allowed", ErrTokenUnverifiable) + NoneSignatureTypeDisallowedError = newError("'none' signature type is not allowed", ErrTokenUnverifiable) RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod { return SigningMethodNone @@ -35,7 +33,7 @@ func (m *signingMethodNone) Verify(signingString, signature string, key interfac } // If signing method is none, signature must be an empty string if signature != "" { - return fmt.Errorf("%w: 'none' signing method with non-empty signature", ErrTokenUnverifiable) + return newError("'none' signing method with non-empty signature", ErrTokenUnverifiable) } // Accept 'none' signing method. diff --git a/parser_test.go b/parser_test.go index 78b924f0..90c271fc 100644 --- a/parser_test.go +++ b/parser_test.go @@ -356,7 +356,6 @@ func TestParser_Parse(t *testing.T) { // Parse the token var token *jwt.Token - //var ve *jwt.ValidationError var err error var parser = data.parser if parser == nil { diff --git a/validator.go b/validator.go index 8aca744d..912050ad 100644 --- a/validator.go +++ b/validator.go @@ -2,9 +2,32 @@ package jwt import ( "crypto/subtle" + "fmt" "time" ) +// ClaimsValidator is an interface that can be implemented by custom claims who +// wish to execute any additional claims validation based on +// application-specific logic. The Validate function is then executed in +// addition to the regular claims validation and any error returned is appended +// to the final validation result. +// +// type MyCustomClaims struct { +// Foo string `json:"foo"` +// jwt.RegisteredClaims +// } +// +// func (m MyCustomClaims) Validate() error { +// if m.Foo != "bar" { +// return errors.New("must be foobar") +// } +// return nil +// } +type ClaimsValidator interface { + Claims + Validate() error +} + // validator is the core of the new Validation API. It is automatically used by // a [Parser] during parsing and can be modified with various parser options. // @@ -46,11 +69,12 @@ func newValidator(opts ...ParserOption) *validator { } // Validate validates the given claims. It will also perform any custom -// validation if claims implements the CustomValidator interface. +// validation if claims implements the [ClaimsValidator] interface. func (v *validator) Validate(claims Claims) error { var ( now time.Time - errs []error = make([]error, 0) + errs []error = make([]error, 0, 6) + err error ) // Check, if we have a time func @@ -61,42 +85,48 @@ func (v *validator) Validate(claims Claims) error { } // We always need to check the expiration time, but usage of the claim - // itself is OPTIONAL - if !v.VerifyExpiresAt(claims, now, false) { - errs = append(errs, ErrTokenExpired) + // itself is OPTIONAL. + if err = v.verifyExpiresAt(claims, now, false); err != nil { + errs = append(errs, err) } // We always need to check not-before, but usage of the claim itself is - // OPTIONAL - if !v.VerifyNotBefore(claims, now, false) { - errs = append(errs, ErrTokenNotValidYet) + // OPTIONAL. + if err = v.verifyNotBefore(claims, now, false); err != nil { + errs = append(errs, err) } // Check issued-at if the option is enabled - if v.verifyIat && !v.VerifyIssuedAt(claims, now, false) { - errs = append(errs, ErrTokenUsedBeforeIssued) + if v.verifyIat { + if err = v.verifyIssuedAt(claims, now, false); err != nil { + errs = append(errs, err) + } } // If we have an expected audience, we also require the audience claim - if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, true) { - errs = append(errs, ErrTokenInvalidAudience) + if v.expectedAud != "" { + if err = v.verifyAudience(claims, v.expectedAud, true); err != nil { + errs = append(errs, err) + } } // If we have an expected issuer, we also require the issuer claim - if v.expectedIss != "" && !v.VerifyIssuer(claims, v.expectedIss, true) { - errs = append(errs, ErrTokenInvalidIssuer) + if v.expectedIss != "" { + if err = v.verifyIssuer(claims, v.expectedIss, true); err != nil { + errs = append(errs, err) + } } // If we have an expected subject, we also require the subject claim - if v.expectedSub != "" && !v.VerifySubject(claims, v.expectedSub, true) { - errs = append(errs, ErrTokenInvalidSubject) + if v.expectedSub != "" { + if err = v.verifySubject(claims, v.expectedSub, true); err != nil { + errs = append(errs, ErrTokenInvalidSubject) + } } // Finally, we want to give the claim itself some possibility to do some // additional custom validation based on a custom Validate function. - cvt, ok := claims.(interface { - Validate() error - }) + cvt, ok := claims.(ClaimsValidator) if ok { if err := cvt.Validate(); err != nil { errs = append(errs, err) @@ -110,84 +140,84 @@ func (v *validator) Validate(claims Claims) error { return joinErrors(errs) } -// VerifyExpiresAt compares the exp claim in claims against cmp. This function -// will return true if cmp < exp. Additional leeway is taken into account. +// verifyExpiresAt compares the exp claim in claims against cmp. This function +// will succeed if cmp < exp. Additional leeway is taken into account. // -// If exp is not set, it will return true if the claim is not required, -// otherwise false will be returned. +// If exp is not set, it will succeed if the claim is not required, +// otherwise ErrTokenRequiredClaimMissing will be returned. // // Additionally, if any error occurs while retrieving the claim, e.g., when its -// the wrong type, false will be returned. -func (v *validator) VerifyExpiresAt(claims Claims, cmp time.Time, required bool) bool { +// the wrong type, an ErrTokenUnverifiable error will be returned. +func (v *validator) verifyExpiresAt(claims Claims, cmp time.Time, required bool) error { exp, err := claims.GetExpirationTime() if err != nil { - return false + return err } - if exp != nil { - return cmp.Before((exp.Time).Add(+v.leeway)) - } else { - return !required + if exp == nil { + return errorIfRequired(required, "exp") } + + return errorIfFalse(cmp.Before((exp.Time).Add(+v.leeway)), ErrTokenExpired) } -// VerifyIssuedAt compares the iat claim in claims against cmp. This function -// will return true if cmp >= iat. Additional leeway is taken into account. +// verifyIssuedAt compares the iat claim in claims against cmp. This function +// will succeed if cmp >= iat. Additional leeway is taken into account. // -// If iat is not set, it will return true if the claim is not required, -// otherwise false will be returned. +// If iat is not set, it will succeed if the claim is not required, +// otherwise ErrTokenRequiredClaimMissing will be returned. // // Additionally, if any error occurs while retrieving the claim, e.g., when its -// the wrong type, false will be returned. -func (v *validator) VerifyIssuedAt(claims Claims, cmp time.Time, required bool) bool { +// the wrong type, an ErrTokenUnverifiable error will be returned. +func (v *validator) verifyIssuedAt(claims Claims, cmp time.Time, required bool) error { iat, err := claims.GetIssuedAt() if err != nil { - return false + return err } - if iat != nil { - return !cmp.Before(iat.Add(-v.leeway)) - } else { - return !required + if iat == nil { + return errorIfRequired(required, "iat") } + + return errorIfFalse(!cmp.Before(iat.Add(-v.leeway)), ErrTokenUsedBeforeIssued) } -// VerifyNotBefore compares the nbf claim in claims against cmp. This function +// verifyNotBefore compares the nbf claim in claims against cmp. This function // will return true if cmp >= nbf. Additional leeway is taken into account. // -// If nbf is not set, it will return true if the claim is not required, -// otherwise false will be returned. +// If nbf is not set, it will succeed if the claim is not required, +// otherwise ErrTokenRequiredClaimMissing will be returned. // // Additionally, if any error occurs while retrieving the claim, e.g., when its -// the wrong type, false will be returned. -func (v *validator) VerifyNotBefore(claims Claims, cmp time.Time, required bool) bool { +// the wrong type, an ErrTokenUnverifiable error will be returned. +func (v *validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) error { nbf, err := claims.GetNotBefore() if err != nil { - return false + return err } - if nbf != nil { - return !cmp.Before(nbf.Add(-v.leeway)) - } else { - return !required + if nbf == nil { + return errorIfRequired(required, "nbf") } + + return errorIfFalse(!cmp.Before(nbf.Add(-v.leeway)), ErrTokenNotValidYet) } -// VerifyAudience compares the aud claim against cmp. +// verifyAudience compares the aud claim against cmp. // -// If aud is not set or an empty list, it will return true if the claim is not -// required, otherwise false will be returned. +// If aud is not set or an empty list, it will succeed if the claim is not required, +// otherwise ErrTokenRequiredClaimMissing will be returned. // // Additionally, if any error occurs while retrieving the claim, e.g., when its -// the wrong type, false will be returned. -func (v *validator) VerifyAudience(claims Claims, cmp string, required bool) bool { +// the wrong type, an ErrTokenUnverifiable error will be returned. +func (v *validator) verifyAudience(claims Claims, cmp string, required bool) error { aud, err := claims.GetAudience() if err != nil { - return false + return err } if len(aud) == 0 { - return !required + return errorIfRequired(required, "aud") } // use a var here to keep constant time compare when looping over a number of claims @@ -203,48 +233,68 @@ func (v *validator) VerifyAudience(claims Claims, cmp string, required bool) boo // case where "" is sent in one or many aud claims if stringClaims == "" { - return !required + return errorIfRequired(required, "aud") } - return result + return errorIfFalse(result, ErrTokenInvalidAudience) } -// VerifyIssuer compares the iss claim in claims against cmp. +// verifyIssuer compares the iss claim in claims against cmp. // -// If iss is not set, it will return true if the claim is not required, -// otherwise false will be returned. +// If iss is not set, it will succeed if the claim is not required, +// otherwise ErrTokenRequiredClaimMissing will be returned. // // Additionally, if any error occurs while retrieving the claim, e.g., when its -// the wrong type, false will be returned. -func (v *validator) VerifyIssuer(claims Claims, cmp string, required bool) bool { +// the wrong type, an ErrTokenUnverifiable error will be returned. +func (v *validator) verifyIssuer(claims Claims, cmp string, required bool) error { iss, err := claims.GetIssuer() if err != nil { - return false + return err } if iss == "" { - return !required + return errorIfRequired(required, "iss") } - return iss == cmp + return errorIfFalse(iss == cmp, ErrTokenInvalidIssuer) } -// VerifySubject compares the sub claim against cmp. +// verifySubject compares the sub claim against cmp. // -// If sub is not set, it will return true if the claim is not required, -// otherwise false will be returned. +// If sub is not set, it will succeed if the claim is not required, +// otherwise ErrTokenRequiredClaimMissing will be returned. // // Additionally, if any error occurs while retrieving the claim, e.g., when its -// the wrong type, false will be returned. -func (v *validator) VerifySubject(claims Claims, cmp string, required bool) bool { +// the wrong type, an ErrTokenUnverifiable error will be returned. +func (v *validator) verifySubject(claims Claims, cmp string, required bool) error { sub, err := claims.GetSubject() if err != nil { - return false + return err } if sub == "" { - return !required + return errorIfRequired(required, "sub") } - return sub == cmp + return errorIfFalse(sub == cmp, ErrTokenInvalidIssuer) +} + +// errorIfFalse returns the error specified in err, if the value is true. +// Otherwise, nil is returned. +func errorIfFalse(value bool, err error) error { + if value { + return nil + } else { + return err + } +} + +// errorIfRequired returns an ErrTokenRequiredClaimMissing error if required is +// true. Otherwise, nil is returned. +func errorIfRequired(required bool, claim string) error { + if required { + return newError(fmt.Sprintf("%s claim is required", claim), ErrTokenRequiredClaimMissing) + } else { + return nil + } } From 1cf3c7151200409cd2a04d693fef769e3b2ea59a Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Mon, 20 Feb 2023 22:59:06 +0100 Subject: [PATCH 7/9] More validation testing --- errors.go | 32 +++--- errors_go1_20.go | 11 +- errors_go_other.go | 11 +- errors_test.go | 14 ++- map_claims_test.go | 8 +- parser.go | 4 +- parser_test.go | 2 - validator.go | 6 +- validator_test.go | 261 +++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 315 insertions(+), 34 deletions(-) create mode 100644 validator_test.go diff --git a/errors.go b/errors.go index 35f66ae9..23bb616d 100644 --- a/errors.go +++ b/errors.go @@ -5,27 +5,23 @@ import ( "strings" ) -// Error constants var ( - ErrInvalidKey = errors.New("key is invalid") - ErrInvalidKeyType = errors.New("key is of invalid type") - ErrHashUnavailable = errors.New("the requested hash function is unavailable") - + ErrInvalidKey = errors.New("key is invalid") + ErrInvalidKeyType = errors.New("key is of invalid type") + ErrHashUnavailable = errors.New("the requested hash function is unavailable") ErrTokenMalformed = errors.New("token is malformed") ErrTokenUnverifiable = errors.New("token is unverifiable") - ErrTokenRequiredClaimMissing = errors.New("a required claim is missing") ErrTokenSignatureInvalid = errors.New("token signature is invalid") - - ErrTokenInvalidAudience = errors.New("token has invalid audience") - ErrTokenExpired = errors.New("token is expired") - ErrTokenUsedBeforeIssued = errors.New("token used before issued") - ErrTokenInvalidIssuer = errors.New("token has invalid issuer") - ErrTokenInvalidSubject = errors.New("token has invalid subject") - ErrTokenNotValidYet = errors.New("token is not valid yet") - ErrTokenInvalidId = errors.New("token has invalid id") - ErrTokenInvalidClaims = errors.New("token has invalid claims") - - ErrInvalidType = errors.New("invalid type for claim") + ErrTokenRequiredClaimMissing = errors.New("token is missing required claim") + ErrTokenInvalidAudience = errors.New("token has invalid audience") + ErrTokenExpired = errors.New("token is expired") + ErrTokenUsedBeforeIssued = errors.New("token used before issued") + ErrTokenInvalidIssuer = errors.New("token has invalid issuer") + ErrTokenInvalidSubject = errors.New("token has invalid subject") + ErrTokenNotValidYet = errors.New("token is not valid yet") + ErrTokenInvalidId = errors.New("token has invalid id") + ErrTokenInvalidClaims = errors.New("token has invalid claims") + ErrInvalidType = errors.New("invalid type for claim") ) // joinedError is an error type that works similar to what [errors.Join] @@ -46,7 +42,7 @@ func (je joinedError) Error() string { // joinErrors joins together multiple errors. Useful for scenarios where // multiple errors next to each other occur, e.g., in claims validation. -func joinErrors(errs []error) error { +func joinErrors(errs ...error) error { return &joinedError{ errs: errs, } diff --git a/errors_go1_20.go b/errors_go1_20.go index ba542181..a893d355 100644 --- a/errors_go1_20.go +++ b/errors_go1_20.go @@ -27,8 +27,15 @@ func (je joinedError) Unwrap() []error { // // "token is unverifiable: no keyfunc was provided" func newError(message string, err error, more ...error) error { - format := "%w: %s" - args := []any{err, message} + var format string + var args []any + if message != "" { + format = "%w: %s" + args = []any{err, message} + } else { + format = "%w" + args = []any{err} + } for _, e := range more { format += ": %w" diff --git a/errors_go_other.go b/errors_go_other.go index d37c3de8..3afb04e6 100644 --- a/errors_go_other.go +++ b/errors_go_other.go @@ -53,8 +53,15 @@ func newError(message string, err error, more ...error) error { // We cannot wrap multiple errors here with %w, so we have to be a little // bit creative. Basically, we are using %s instead of %w to produce the // same error message and then throw the result into a custom error struct. - format := "%s: %s" - args := []any{err, message} + var format string + var args []any + if message != "" { + format = "%s: %s" + args = []any{err, message} + } else { + format = "%s" + args = []any{err} + } errs := []error{err} for _, e := range more { diff --git a/errors_test.go b/errors_test.go index b534ff09..fd4004b3 100644 --- a/errors_test.go +++ b/errors_test.go @@ -27,7 +27,7 @@ func Test_joinErrors(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := joinErrors(tt.args.errs) + err := joinErrors(tt.args.errs...) for _, wantErr := range tt.wantErrors { if !errors.Is(err, wantErr) { t.Errorf("joinErrors() error = %v, does not contain %v", err, wantErr) @@ -65,6 +65,18 @@ func Test_newError(t *testing.T) { wantMessage: "token is malformed: something is wrong: unexpected EOF", wantErrors: []error{ErrTokenMalformed}, }, + { + name: "two errors, no detail", + args: args{message: "", err: ErrTokenInvalidClaims, more: []error{ErrTokenExpired}}, + wantMessage: "token has invalid claims: token is expired", + wantErrors: []error{ErrTokenInvalidClaims, ErrTokenExpired}, + }, + { + name: "two errors, no detail and join error", + args: args{message: "", err: ErrTokenInvalidClaims, more: []error{joinErrors(ErrTokenExpired, ErrTokenNotValidYet)}}, + wantMessage: "token has invalid claims: token is expired, token is not valid yet", + wantErrors: []error{ErrTokenInvalidClaims, ErrTokenExpired, ErrTokenNotValidYet}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/map_claims_test.go b/map_claims_test.go index b4af574b..83065d5b 100644 --- a/map_claims_test.go +++ b/map_claims_test.go @@ -135,7 +135,7 @@ func TestMapClaimsVerifyExpiresAtExpire(t *testing.T) { } } -func TestMapClaims_ParseString(t *testing.T) { +func TestMapClaims_parseString(t *testing.T) { type args struct { key string } @@ -176,13 +176,13 @@ func TestMapClaims_ParseString(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.m.ParseString(tt.args.key) + got, err := tt.m.parseString(tt.args.key) if (err != nil) != tt.wantErr { - t.Errorf("MapClaims.ParseString() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("MapClaims.parseString() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { - t.Errorf("MapClaims.ParseString() = %v, want %v", got, tt.want) + t.Errorf("MapClaims.parseString() = %v, want %v", got, tt.want) } }) } diff --git a/parser.go b/parser.go index f02fbc90..1a4ddb92 100644 --- a/parser.go +++ b/parser.go @@ -82,7 +82,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf // Perform signature validation token.Signature = parts[2] if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil { - return token, newError("could not verify", ErrTokenSignatureInvalid, err) + return token, newError("", ErrTokenSignatureInvalid, err) } // Validate Claims @@ -93,7 +93,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf } if err := p.validator.Validate(claims); err != nil { - return token, err + return token, newError("", ErrTokenInvalidClaims, err) } } diff --git a/parser_test.go b/parser_test.go index 90c271fc..a6af8df3 100644 --- a/parser_test.go +++ b/parser_test.go @@ -61,7 +61,6 @@ var jwtTestData = []struct { defaultKeyFunc, nil, false, - jwt.ValidationErrorMalformed, []error{jwt.ErrTokenMalformed}, nil, jwt.SigningMethodRS256, @@ -72,7 +71,6 @@ var jwtTestData = []struct { defaultKeyFunc, nil, false, - jwt.ValidationErrorMalformed, []error{jwt.ErrTokenMalformed}, nil, jwt.SigningMethodRS256, diff --git a/validator.go b/validator.go index 912050ad..e9cd4d2e 100644 --- a/validator.go +++ b/validator.go @@ -120,7 +120,7 @@ func (v *validator) Validate(claims Claims) error { // If we have an expected subject, we also require the subject claim if v.expectedSub != "" { if err = v.verifySubject(claims, v.expectedSub, true); err != nil { - errs = append(errs, ErrTokenInvalidSubject) + errs = append(errs, err) } } @@ -137,7 +137,7 @@ func (v *validator) Validate(claims Claims) error { return nil } - return joinErrors(errs) + return joinErrors(errs...) } // verifyExpiresAt compares the exp claim in claims against cmp. This function @@ -276,7 +276,7 @@ func (v *validator) verifySubject(claims Claims, cmp string, required bool) erro return errorIfRequired(required, "sub") } - return errorIfFalse(sub == cmp, ErrTokenInvalidIssuer) + return errorIfFalse(sub == cmp, ErrTokenInvalidSubject) } // errorIfFalse returns the error specified in err, if the value is true. diff --git a/validator_test.go b/validator_test.go new file mode 100644 index 00000000..869b0507 --- /dev/null +++ b/validator_test.go @@ -0,0 +1,261 @@ +package jwt + +import ( + "errors" + "testing" + "time" +) + +var ErrFooBar = errors.New("must be foobar") + +type MyCustomClaims struct { + Foo string `json:"foo"` + RegisteredClaims +} + +func (m MyCustomClaims) Validate() error { + if m.Foo != "bar" { + return ErrFooBar + } + return nil +} + +func Test_validator_Validate(t *testing.T) { + type fields struct { + leeway time.Duration + timeFunc func() time.Time + verifyIat bool + expectedAud string + expectedIss string + expectedSub string + } + type args struct { + claims Claims + } + tests := []struct { + name string + fields fields + args args + wantErr error + }{ + { + name: "expected iss mismatch", + fields: fields{expectedIss: "me"}, + args: args{RegisteredClaims{Issuer: "not_me"}}, + wantErr: ErrTokenInvalidIssuer, + }, + { + name: "expected iss is missing", + fields: fields{expectedIss: "me"}, + args: args{RegisteredClaims{}}, + wantErr: ErrTokenRequiredClaimMissing, + }, + { + name: "expected sub mismatch", + fields: fields{expectedSub: "me"}, + args: args{RegisteredClaims{Subject: "not-me"}}, + wantErr: ErrTokenInvalidSubject, + }, + { + name: "expected sub is missing", + fields: fields{expectedSub: "me"}, + args: args{RegisteredClaims{}}, + wantErr: ErrTokenRequiredClaimMissing, + }, + { + name: "custom validator", + fields: fields{}, + args: args{MyCustomClaims{Foo: "not-bar"}}, + wantErr: ErrFooBar, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := &validator{ + leeway: tt.fields.leeway, + timeFunc: tt.fields.timeFunc, + verifyIat: tt.fields.verifyIat, + expectedAud: tt.fields.expectedAud, + expectedIss: tt.fields.expectedIss, + expectedSub: tt.fields.expectedSub, + } + if err := v.Validate(tt.args.claims); (err != nil) && !errors.Is(err, tt.wantErr) { + t.Errorf("validator.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_validator_verifyExpiresAt(t *testing.T) { + type fields struct { + leeway time.Duration + timeFunc func() time.Time + } + type args struct { + claims Claims + cmp time.Time + required bool + } + tests := []struct { + name string + fields fields + args args + wantErr error + }{ + { + name: "good claim", + fields: fields{timeFunc: time.Now}, + args: args{claims: RegisteredClaims{ExpiresAt: NewNumericDate(time.Now().Add(10 * time.Minute))}}, + wantErr: nil, + }, + { + name: "claims with invalid type", + fields: fields{}, + args: args{claims: MapClaims{"exp": "string"}}, + wantErr: ErrInvalidType, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := &validator{ + leeway: tt.fields.leeway, + timeFunc: tt.fields.timeFunc, + } + + err := v.verifyExpiresAt(tt.args.claims, tt.args.cmp, tt.args.required) + if (err != nil) && !errors.Is(err, tt.wantErr) { + t.Errorf("validator.verifyExpiresAt() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_validator_verifyIssuer(t *testing.T) { + type fields struct { + expectedIss string + } + type args struct { + claims Claims + cmp string + required bool + } + tests := []struct { + name string + fields fields + args args + wantErr error + }{ + { + name: "good claim", + fields: fields{expectedIss: "me"}, + args: args{claims: MapClaims{"iss": "me"}, cmp: "me"}, + wantErr: nil, + }, + { + name: "claims with invalid type", + fields: fields{expectedIss: "me"}, + args: args{claims: MapClaims{"iss": 1}, cmp: "me"}, + wantErr: ErrInvalidType, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := &validator{ + expectedIss: tt.fields.expectedIss, + } + err := v.verifyIssuer(tt.args.claims, tt.args.cmp, tt.args.required) + if (err != nil) && !errors.Is(err, tt.wantErr) { + t.Errorf("validator.verifyIssuer() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_validator_verifySubject(t *testing.T) { + type fields struct { + expectedSub string + } + type args struct { + claims Claims + cmp string + required bool + } + tests := []struct { + name string + fields fields + args args + wantErr error + }{ + { + name: "good claim", + fields: fields{expectedSub: "me"}, + args: args{claims: MapClaims{"sub": "me"}, cmp: "me"}, + wantErr: nil, + }, + { + name: "claims with invalid type", + fields: fields{expectedSub: "me"}, + args: args{claims: MapClaims{"sub": 1}, cmp: "me"}, + wantErr: ErrInvalidType, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := &validator{ + expectedSub: tt.fields.expectedSub, + } + err := v.verifySubject(tt.args.claims, tt.args.cmp, tt.args.required) + if (err != nil) && !errors.Is(err, tt.wantErr) { + t.Errorf("validator.verifySubject() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_validator_verifyIssuedAt(t *testing.T) { + type fields struct { + leeway time.Duration + timeFunc func() time.Time + verifyIat bool + } + type args struct { + claims Claims + cmp time.Time + required bool + } + tests := []struct { + name string + fields fields + args args + wantErr error + }{ + { + name: "good claim without iat", + fields: fields{verifyIat: true}, + args: args{claims: MapClaims{}, required: false}, + wantErr: nil, + }, + { + name: "good claim with iat", + fields: fields{verifyIat: true}, + args: args{ + claims: RegisteredClaims{IssuedAt: NewNumericDate(time.Now())}, + cmp: time.Now().Add(10 * time.Minute), + required: false, + }, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := &validator{ + leeway: tt.fields.leeway, + timeFunc: tt.fields.timeFunc, + verifyIat: tt.fields.verifyIat, + } + if err := v.verifyIssuedAt(tt.args.claims, tt.args.cmp, tt.args.required); (err != nil) && !errors.Is(err, tt.wantErr) { + t.Errorf("validator.verifyIssuedAt() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} From 4b346caa33be303f5bab2333605376e5982c0ef1 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Mon, 20 Feb 2023 23:24:50 +0100 Subject: [PATCH 8/9] Added malformed claims JSON --- parser.go | 4 ++-- parser_test.go | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/parser.go b/parser.go index 1a4ddb92..46b67931 100644 --- a/parser.go +++ b/parser.go @@ -126,7 +126,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err) } if err = json.Unmarshal(headerBytes, &token.Header); err != nil { - return token, parts, newError("could not unmarshal header", ErrTokenMalformed, err) + return token, parts, newError("could not JSON decode header", ErrTokenMalformed, err) } // parse Claims @@ -148,7 +148,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke } // Handle decode error if err != nil { - return token, parts, newError("could JSON decode claim", ErrTokenMalformed, err) + return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err) } // Lookup signature method diff --git a/parser_test.go b/parser_test.go index a6af8df3..fdb5eef3 100644 --- a/parser_test.go +++ b/parser_test.go @@ -65,6 +65,16 @@ var jwtTestData = []struct { nil, jwt.SigningMethodRS256, }, + { + "invalid JSON claim", + "eyJhbGciOiJSUzI1NiIsInppcCI6IkRFRiJ9.eNqqVkqtKFCyMjQ1s7Q0sbA0MtFRyk3NTUot8kxRslIKLbZQggn4JeamAoUcfRz99HxcXRWeze172tr4bFq7Ui0AAAD__w.jBXD4LT4aq4oXTgDoPkiV6n4QdSZPZI1Z4J8MWQC42aHK0oXwcovEU06dVbtB81TF-2byuu0-qi8J0GUttODT67k6gCl6DV_iuCOV7gczwTcvKslotUvXzoJ2wa0QuujnjxLEE50r0p6k0tsv_9OIFSUZzDksJFYNPlJH2eFG55DROx4TsOz98az37SujZi9GGbTc9SLgzFHPrHMrovRZ5qLC_w4JrdtsLzBBI11OQJgRYwV8fQf4O8IsMkHtetjkN7dKgUkJtRarNWOk76rpTPppLypiLU4_J0-wrElLMh1TzUVZW6Fz2cDHDDBACJgMmKQ2pOFEDK_vYZN74dLCF5GiTZV6DbXhNxO7lqT7JUN4a3p2z96G7WNRjblf2qZeuYdQvkIsiK-rCbSIE836XeY5gaBgkOzuEvzl_tMrpRmb5Oox1ibOfVT2KBh9Lvqsb1XbQjCio2CLE2ViCLqoe0AaRqlUyrk3n8BIG-r0IW4dcw96CEryEMIjsjVp9mtPXamJzf391kt8Rf3iRBqwv3zP7Plg1ResXbmsFUgOflAUPcYmfLug4W3W52ntcUlTHAKXrNfaJL9QQiYAaDukG-ZHDytsOWTuuXw7lVxjt-XYi1VbRAIjh1aIYSELEmEpE4Ny74htQtywYXMQNfJpB0nNn8IiWakgcYYMJ0TmKM", + defaultKeyFunc, + nil, + false, + []error{jwt.ErrTokenMalformed}, + nil, + jwt.SigningMethodRS256, + }, { "bearer in JWT", "bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", From 04260dc03570558cd72894c854b5a0f4836b38d6 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Tue, 21 Feb 2023 08:53:01 +0100 Subject: [PATCH 9/9] Reverted go.mod --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index aa4a2c34..7fbfcedd 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/golang-jwt/jwt/v5 -go 1.20 +go 1.18