Skip to content

Commit

Permalink
Sign/Verify does take the decoded form now
Browse files Browse the repository at this point in the history
  • Loading branch information
oxisto committed Mar 14, 2023
1 parent 352f411 commit 321782a
Show file tree
Hide file tree
Showing 18 changed files with 153 additions and 115 deletions.
12 changes: 6 additions & 6 deletions ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,19 @@ func (m *SigningMethodECDSA) Verify(signingString string, sig []byte, key interf

// Sign implements token signing for the SigningMethod.
// For this signing method, key must be an ecdsa.PrivateKey struct
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string, error) {
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) ([]byte, error) {
// Get the key
var ecdsaKey *ecdsa.PrivateKey
switch k := key.(type) {
case *ecdsa.PrivateKey:
ecdsaKey = k
default:
return "", ErrInvalidKeyType
return nil, ErrInvalidKeyType
}

// Create the hasher
if !m.Hash.Available() {
return "", ErrHashUnavailable
return nil, ErrHashUnavailable
}

hasher := m.Hash.New()
Expand All @@ -112,7 +112,7 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string
curveBits := ecdsaKey.Curve.Params().BitSize

if m.CurveBits != curveBits {
return "", ErrInvalidKey
return nil, ErrInvalidKey
}

keyBytes := curveBits / 8
Expand All @@ -127,8 +127,8 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string
r.FillBytes(out[0:keyBytes]) // r is assigned to the first half of output.
s.FillBytes(out[keyBytes:]) // s is assigned to the second half of output.

return EncodeSegment(out), nil
return out, nil
} else {
return "", err
return nil, err
}
}
18 changes: 12 additions & 6 deletions ecdsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package jwt_test
import (
"crypto/ecdsa"
"os"
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -90,15 +91,16 @@ func TestECDSASign(t *testing.T) {
toSign := strings.Join(parts[0:2], ".")
method := jwt.GetSigningMethod(data.alg)
sig, err := method.Sign(toSign, ecdsaKey)

if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig == parts[2] {
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)

ssig := encodeSegment(sig)
if ssig == parts[2] {
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], ssig)
}

err = method.Verify(toSign, decodeSegment(t, sig), ecdsaKey.Public())
err = method.Verify(toSign, sig, ecdsaKey.Public())
if err != nil {
t.Errorf("[%v] Sign produced an invalid signature: %v", data.name, err)
}
Expand Down Expand Up @@ -155,15 +157,15 @@ func BenchmarkECDSASigning(b *testing.B) {
if err != nil {
b.Fatalf("[%v] Error signing token: %v", data.name, err)
}
if sig == parts[2] {
if reflect.DeepEqual(sig, decodeSegment(b, parts[2])) {
b.Fatalf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)
}
}
})
}
}

func decodeSegment(t *testing.T, signature string) (sig []byte) {
func decodeSegment(t interface{ Fatalf(string, ...any) }, signature string) (sig []byte) {
var err error
sig, err = jwt.NewParser().DecodeSegment(signature)
if err != nil {
Expand All @@ -172,3 +174,7 @@ func decodeSegment(t *testing.T, signature string) (sig []byte) {

return
}

func encodeSegment(sig []byte) string {
return (&jwt.Token{}).EncodeSegment(sig)
}
16 changes: 9 additions & 7 deletions ed25519.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,25 @@ func (m *SigningMethodEd25519) Verify(signingString string, sig []byte, key inte

// Sign implements token signing for the SigningMethod.
// For this signing method, key must be an ed25519.PrivateKey
func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) (string, error) {
func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) ([]byte, error) {
var ed25519Key crypto.Signer
var ok bool

if ed25519Key, ok = key.(crypto.Signer); !ok {
return "", ErrInvalidKeyType
return nil, ErrInvalidKeyType
}

if _, ok := ed25519Key.Public().(ed25519.PublicKey); !ok {
return "", ErrInvalidKey
return nil, ErrInvalidKey
}

// Sign the string and return the encoded result
// ed25519 performs a two-pass hash as part of its algorithm. Therefore, we need to pass a non-prehashed message into the Sign function, as indicated by crypto.Hash(0)
// Sign the string and return the result. ed25519 performs a two-pass hash
// as part of its algorithm. Therefore, we need to pass a non-prehashed
// message into the Sign function, as indicated by crypto.Hash(0)
sig, err := ed25519Key.Sign(rand.Reader, []byte(signingString), crypto.Hash(0))
if err != nil {
return "", err
return nil, err
}
return EncodeSegment(sig), nil

return sig, nil
}
6 changes: 4 additions & 2 deletions ed25519_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ func TestEd25519Sign(t *testing.T) {
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig == parts[2] && !data.valid {
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)

ssig := encodeSegment(sig)
if ssig == parts[2] && !data.valid {
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], ssig)
}
}
}
8 changes: 4 additions & 4 deletions hmac.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,17 @@ func (m *SigningMethodHMAC) Verify(signingString string, sig []byte, key interfa

// Sign implements token signing for the SigningMethod.
// Key must be []byte
func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, error) {
func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) ([]byte, error) {
if keyBytes, ok := key.([]byte); ok {
if !m.Hash.Available() {
return "", ErrHashUnavailable
return nil, ErrHashUnavailable
}

hasher := hmac.New(m.Hash.New, keyBytes)
hasher.Write([]byte(signingString))

return EncodeSegment(hasher.Sum(nil)), nil
return hasher.Sum(nil), nil
}

return "", ErrInvalidKeyType
return nil, ErrInvalidKeyType
}
3 changes: 2 additions & 1 deletion hmac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jwt_test

import (
"os"
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -72,7 +73,7 @@ func TestHMACSign(t *testing.T) {
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig != parts[2] {
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
}
}
Expand Down
7 changes: 4 additions & 3 deletions none.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ func (m *signingMethodNone) Verify(signingString string, sig []byte, key interfa
}

// Only allow 'none' signing if UnsafeAllowNoneSignatureType is specified as the key
func (m *signingMethodNone) Sign(signingString string, key interface{}) (string, error) {
func (m *signingMethodNone) Sign(signingString string, key interface{}) ([]byte, error) {
if _, ok := key.(unsafeNoneMagicConstant); ok {
return "", nil
return []byte{}, nil
}
return "", NoneSignatureTypeDisallowedError

return nil, NoneSignatureTypeDisallowedError
}
3 changes: 2 additions & 1 deletion none_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwt_test

import (
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -65,7 +66,7 @@ func TestNoneSign(t *testing.T) {
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig != parts[2] {
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
}
}
Expand Down
37 changes: 31 additions & 6 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ type Parser struct {
skipClaimsValidation bool

validator *validator

decodeStrict bool

decodePaddingAllowed bool
}

// NewParser creates a new Parser with the specified options
Expand Down Expand Up @@ -169,22 +173,43 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
return token, parts, nil
}

// DecodeSegment decodes a JWT specific base64url encoding with padding stripped
//
// Deprecated: In a future release, we will demote this function to a
// non-exported function, since it should only be used internally
// DecodeSegment decodes a JWT specific base64url encoding. This function will
// take into account whether the [Parser] is configured with additional options,
// such as [WithStrictDecoding] or [WithPaddingAllowed].
func (p *Parser) DecodeSegment(seg string) ([]byte, error) {
encoding := base64.RawURLEncoding

if DecodePaddingAllowed {
if p.decodePaddingAllowed {
if l := len(seg) % 4; l > 0 {
seg += strings.Repeat("=", 4-l)
}
encoding = base64.URLEncoding
}

if DecodeStrict {
if p.decodeStrict {
encoding = encoding.Strict()
}
return encoding.DecodeString(seg)
}

// Parse parses, validates, verifies the signature and returns the parsed token.
// keyFunc will receive the parsed token and should return the cryptographic key
// for verifying the signature. The caller is strongly encouraged to set the
// WithValidMethods option to validate the 'alg' claim in the token matches the
// expected algorithm. For more details about the importance of validating the
// 'alg' claim, see
// https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
func Parse(tokenString string, keyFunc Keyfunc, options ...ParserOption) (*Token, error) {
return NewParser(options...).Parse(tokenString, keyFunc)
}

// ParseWithClaims is a shortcut for NewParser().ParseWithClaims().
//
// Note: If you provide a custom claim implementation that embeds one of the
// standard claims (such as RegisteredClaims), make sure that a) you either
// embed a non-pointer version of the claims or b) if you are using a pointer,
// allocate the proper memory for it before passing in the overall claims,
// otherwise you might run into a panic.
func ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc, options ...ParserOption) (*Token, error) {
return NewParser(options...).ParseWithClaims(tokenString, claims, keyFunc)
}
26 changes: 26 additions & 0 deletions parser_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,29 @@ func WithSubject(sub string) ParserOption {
p.validator.expectedSub = sub
}
}

// WithPaddingAllowed will enable the codec used for decoding JWTs to allow
// padding. Note that the JWS RFC7515 states that the tokens will utilize a
// Base64url encoding with no padding. Unfortunately, some implementations of
// JWT are producing non-standard tokens, and thus require support for decoding.
// Note that this is a global variable, and updating it will change the behavior
// on a package level, and is also NOT go-routine safe. To use the
// non-recommended decoding, set this boolean to `true` prior to using this
// package.
func WithPaddingAllowed() ParserOption {
return func(p *Parser) {
p.decodePaddingAllowed = true
}
}

// WithStrictDecoding will switch the codec used for decoding JWTs into strict
// mode. In this mode, the decoder requires that trailing padding bits are zero,
// as described in RFC 4648 section 3.5. Note that this is a global variable,
// and updating it will change the behavior on a package level, and is also NOT
// go-routine safe. To use strict decoding, set this boolean to `true` prior to
// using this package.
func WithStrictDecoding() ParserOption {
return func(p *Parser) {
p.decodeStrict = true
}
}
16 changes: 10 additions & 6 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -641,9 +641,6 @@ var setPaddingTestData = []struct {
func TestSetPadding(t *testing.T) {
for _, data := range setPaddingTestData {
t.Run(data.name, func(t *testing.T) {
jwt.DecodePaddingAllowed = data.paddedDecode
jwt.DecodeStrict = data.strictDecode

// If the token string is blank, use helper function to generate string
if data.tokenString == "" {
data.tokenString = signToken(data.claims, data.signingMethod)
Expand All @@ -652,7 +649,16 @@ func TestSetPadding(t *testing.T) {
// Parse the token
var token *jwt.Token
var err error
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
var opts []jwt.ParserOption = []jwt.ParserOption{jwt.WithoutClaimsValidation()}

if data.paddedDecode {
opts = append(opts, jwt.WithPaddingAllowed())
}
if data.strictDecode {
opts = append(opts, jwt.WithStrictDecoding())
}

parser := jwt.NewParser(opts...)

// Figure out correct claims type
token, err = parser.ParseWithClaims(data.tokenString, jwt.MapClaims{}, data.keyfunc)
Expand All @@ -666,8 +672,6 @@ func TestSetPadding(t *testing.T) {
}

})
jwt.DecodePaddingAllowed = false
jwt.DecodeStrict = false
}
}

Expand Down
10 changes: 5 additions & 5 deletions rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,27 +67,27 @@ func (m *SigningMethodRSA) Verify(signingString string, sig []byte, key interfac

// Sign implements token signing for the SigningMethod
// For this signing method, must be an *rsa.PrivateKey structure.
func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) {
func (m *SigningMethodRSA) Sign(signingString string, key interface{}) ([]byte, error) {
var rsaKey *rsa.PrivateKey
var ok bool

// Validate type of key
if rsaKey, ok = key.(*rsa.PrivateKey); !ok {
return "", ErrInvalidKey
return nil, ErrInvalidKey
}

// Create the hasher
if !m.Hash.Available() {
return "", ErrHashUnavailable
return nil, ErrHashUnavailable
}

hasher := m.Hash.New()
hasher.Write([]byte(signingString))

// Sign the string and return the encoded bytes
if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil)); err == nil {
return EncodeSegment(sigBytes), nil
return sigBytes, nil
} else {
return "", err
return nil, err
}
}
Loading

0 comments on commit 321782a

Please sign in to comment.