Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moving DecodeSegement to Parser #278

Merged
merged 2 commits into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 7 additions & 15 deletions ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,7 @@ func (m *SigningMethodECDSA) Alg() string {

// Verify implements token verification for the SigningMethod.
// For this verify method, key must be an ecdsa.PublicKey struct
func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error {
var err error

// Decode the signature
var sig []byte
if sig, err = DecodeSegment(signature); err != nil {
return err
}

func (m *SigningMethodECDSA) Verify(signingString string, sig []byte, key interface{}) error {
// Get the key
var ecdsaKey *ecdsa.PublicKey
switch k := key.(type) {
Expand Down Expand Up @@ -97,19 +89,19 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa

// Sign implements token signing for the SigningMethod.
// For this signing method, key must be an ecdsa.PrivateKey struct
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string, error) {
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) ([]byte, error) {
// Get the key
var ecdsaKey *ecdsa.PrivateKey
switch k := key.(type) {
case *ecdsa.PrivateKey:
ecdsaKey = k
default:
return "", ErrInvalidKeyType
return nil, ErrInvalidKeyType
}

// Create the hasher
if !m.Hash.Available() {
return "", ErrHashUnavailable
return nil, ErrHashUnavailable
}

hasher := m.Hash.New()
Expand All @@ -120,7 +112,7 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string
curveBits := ecdsaKey.Curve.Params().BitSize

if m.CurveBits != curveBits {
return "", ErrInvalidKey
return nil, ErrInvalidKey
}

keyBytes := curveBits / 8
Expand All @@ -135,8 +127,8 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string
r.FillBytes(out[0:keyBytes]) // r is assigned to the first half of output.
s.FillBytes(out[keyBytes:]) // s is assigned to the second half of output.

return EncodeSegment(out), nil
return out, nil
} else {
return "", err
return nil, err
}
}
26 changes: 21 additions & 5 deletions ecdsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package jwt_test
import (
"crypto/ecdsa"
"os"
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -65,7 +66,7 @@ func TestECDSAVerify(t *testing.T) {
parts := strings.Split(data.tokenString, ".")

method := jwt.GetSigningMethod(data.alg)
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], ecdsaKey)
err = method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), ecdsaKey)
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
}
Expand All @@ -90,12 +91,13 @@ func TestECDSASign(t *testing.T) {
toSign := strings.Join(parts[0:2], ".")
method := jwt.GetSigningMethod(data.alg)
sig, err := method.Sign(toSign, ecdsaKey)

if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig == parts[2] {
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)

ssig := encodeSegment(sig)
if ssig == parts[2] {
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], ssig)
}

err = method.Verify(toSign, sig, ecdsaKey.Public())
Expand Down Expand Up @@ -155,10 +157,24 @@ func BenchmarkECDSASigning(b *testing.B) {
if err != nil {
b.Fatalf("[%v] Error signing token: %v", data.name, err)
}
if sig == parts[2] {
if reflect.DeepEqual(sig, decodeSegment(b, parts[2])) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels more like a test, do we need to assert here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not quite sure what you mean

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, because it’s a benchmark you mean? Good point

b.Fatalf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)
}
}
})
}
}

func decodeSegment(t interface{ Fatalf(string, ...any) }, signature string) (sig []byte) {
var err error
sig, err = jwt.NewParser().DecodeSegment(signature)
if err != nil {
t.Fatalf("could not decode segment: %v", err)
}

return
}

func encodeSegment(sig []byte) string {
return (&jwt.Token{}).EncodeSegment(sig)
}
25 changes: 10 additions & 15 deletions ed25519.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ func (m *SigningMethodEd25519) Alg() string {

// Verify implements token verification for the SigningMethod.
// For this verify method, key must be an ed25519.PublicKey
func (m *SigningMethodEd25519) Verify(signingString, signature string, key interface{}) error {
var err error
func (m *SigningMethodEd25519) Verify(signingString string, sig []byte, key interface{}) error {
var ed25519Key ed25519.PublicKey
var ok bool

Expand All @@ -47,12 +46,6 @@ func (m *SigningMethodEd25519) Verify(signingString, signature string, key inter
return ErrInvalidKey
}

// Decode the signature
var sig []byte
if sig, err = DecodeSegment(signature); err != nil {
return err
}

// Verify the signature
if !ed25519.Verify(ed25519Key, []byte(signingString), sig) {
return ErrEd25519Verification
Expand All @@ -63,23 +56,25 @@ func (m *SigningMethodEd25519) Verify(signingString, signature string, key inter

// Sign implements token signing for the SigningMethod.
// For this signing method, key must be an ed25519.PrivateKey
func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) (string, error) {
func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) ([]byte, error) {
var ed25519Key crypto.Signer
var ok bool

if ed25519Key, ok = key.(crypto.Signer); !ok {
return "", ErrInvalidKeyType
return nil, ErrInvalidKeyType
}

if _, ok := ed25519Key.Public().(ed25519.PublicKey); !ok {
return "", ErrInvalidKey
return nil, ErrInvalidKey
}

// Sign the string and return the encoded result
// ed25519 performs a two-pass hash as part of its algorithm. Therefore, we need to pass a non-prehashed message into the Sign function, as indicated by crypto.Hash(0)
// Sign the string and return the result. ed25519 performs a two-pass hash
// as part of its algorithm. Therefore, we need to pass a non-prehashed
// message into the Sign function, as indicated by crypto.Hash(0)
sig, err := ed25519Key.Sign(rand.Reader, []byte(signingString), crypto.Hash(0))
if err != nil {
return "", err
return nil, err
}
return EncodeSegment(sig), nil

return sig, nil
}
8 changes: 5 additions & 3 deletions ed25519_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestEd25519Verify(t *testing.T) {

method := jwt.GetSigningMethod(data.alg)

err = method.Verify(strings.Join(parts[0:2], "."), parts[2], ed25519Key)
err = method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), ed25519Key)
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
}
Expand Down Expand Up @@ -77,8 +77,10 @@ func TestEd25519Sign(t *testing.T) {
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig == parts[2] && !data.valid {
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)

ssig := encodeSegment(sig)
if ssig == parts[2] && !data.valid {
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], ssig)
}
}
}
16 changes: 5 additions & 11 deletions hmac.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,13 @@ func (m *SigningMethodHMAC) Alg() string {
}

// Verify implements token verification for the SigningMethod. Returns nil if the signature is valid.
func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error {
func (m *SigningMethodHMAC) Verify(signingString string, sig []byte, key interface{}) error {
// Verify the key is the right type
keyBytes, ok := key.([]byte)
if !ok {
return ErrInvalidKeyType
}

// Decode signature, for comparison
sig, err := DecodeSegment(signature)
if err != nil {
return err
}

// Can we use the specified hashing method?
if !m.Hash.Available() {
return ErrHashUnavailable
Expand All @@ -79,17 +73,17 @@ func (m *SigningMethodHMAC) Verify(signingString, signature string, key interfac

// Sign implements token signing for the SigningMethod.
// Key must be []byte
func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, error) {
func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) ([]byte, error) {
if keyBytes, ok := key.([]byte); ok {
if !m.Hash.Available() {
return "", ErrHashUnavailable
return nil, ErrHashUnavailable
}

hasher := hmac.New(m.Hash.New, keyBytes)
hasher.Write([]byte(signingString))

return EncodeSegment(hasher.Sum(nil)), nil
return hasher.Sum(nil), nil
}

return "", ErrInvalidKeyType
return nil, ErrInvalidKeyType
}
5 changes: 3 additions & 2 deletions hmac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jwt_test

import (
"os"
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -53,7 +54,7 @@ func TestHMACVerify(t *testing.T) {
parts := strings.Split(data.tokenString, ".")

method := jwt.GetSigningMethod(data.alg)
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], hmacTestKey)
err := method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), hmacTestKey)
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
}
Expand All @@ -72,7 +73,7 @@ func TestHMACSign(t *testing.T) {
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig != parts[2] {
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
}
}
Expand Down
11 changes: 6 additions & 5 deletions none.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ func (m *signingMethodNone) Alg() string {
}

// Only allow 'none' alg type if UnsafeAllowNoneSignatureType is specified as the key
func (m *signingMethodNone) Verify(signingString, signature string, key interface{}) (err error) {
func (m *signingMethodNone) Verify(signingString string, sig []byte, key interface{}) (err error) {
// Key must be UnsafeAllowNoneSignatureType to prevent accidentally
// accepting 'none' signing method
if _, ok := key.(unsafeNoneMagicConstant); !ok {
return NoneSignatureTypeDisallowedError
}
// If signing method is none, signature must be an empty string
if signature != "" {
if string(sig) != "" {
return newError("'none' signing method with non-empty signature", ErrTokenUnverifiable)
}

Expand All @@ -41,9 +41,10 @@ func (m *signingMethodNone) Verify(signingString, signature string, key interfac
}

// Only allow 'none' signing if UnsafeAllowNoneSignatureType is specified as the key
func (m *signingMethodNone) Sign(signingString string, key interface{}) (string, error) {
func (m *signingMethodNone) Sign(signingString string, key interface{}) ([]byte, error) {
if _, ok := key.(unsafeNoneMagicConstant); ok {
return "", nil
return []byte{}, nil
}
return "", NoneSignatureTypeDisallowedError

return nil, NoneSignatureTypeDisallowedError
}
5 changes: 3 additions & 2 deletions none_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwt_test

import (
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -46,7 +47,7 @@ func TestNoneVerify(t *testing.T) {
parts := strings.Split(data.tokenString, ".")

method := jwt.GetSigningMethod(data.alg)
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], data.key)
err := method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), data.key)
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
}
Expand All @@ -65,7 +66,7 @@ func TestNoneSign(t *testing.T) {
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig != parts[2] {
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
}
}
Expand Down
Loading