Skip to content

Commit

Permalink
Add more context to errors #3
Browse files Browse the repository at this point in the history
  • Loading branch information
o1egl committed Nov 8, 2018
1 parent 0c84ee6 commit ca68a37
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 59 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ require (
github.com/aead/chacha20poly1305 v0.0.0-20170617001512-233f39982aeb
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pkg/errors v0.8.0
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/testify v1.2.2
golang.org/x/crypto v0.0.0-20181025213731-e84da0312774
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyY
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635/go.mod h1:lmLxL+FV291OopO93Bwf9fQLQeLyt33VJRUg5VJ30us=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
Expand Down
6 changes: 4 additions & 2 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package paseto
import (
"crypto"
"strings"

"github.com/pkg/errors"
)

// Version defines the token version.
Expand Down Expand Up @@ -71,9 +73,9 @@ func ParseFooter(token string, footer interface{}) error {
if len(parts) == 4 {
b, err := tokenEncoder.DecodeString(parts[3])
if err != nil {
return err
return errors.Wrap(err, "failed to decode token")
}
return fillValue(b, footer)
return errors.Wrap(fillValue(b, footer), "failed to decode footer")
}
if len(parts) < 3 {
return ErrIncorrectTokenFormat
Expand Down
2 changes: 2 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ var (
ErrInvalidSignature = errors.New("invalid signature")
// ErrDataUnmarshal can't unmarshal token data to the given type of value
ErrDataUnmarshal = errors.New("can't unmarshal token data to the given type of value")
// ErrTokenValidationError invalid token data
ErrTokenValidationError = errors.New("token validation error")
)

// Protocol defines the PASETO token protocol interface.
Expand Down
18 changes: 9 additions & 9 deletions token_validator.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package paseto

import (
"errors"
"fmt"
"time"

"github.com/pkg/errors"
)

// Validator defines a JSONToken validator function.
Expand All @@ -13,7 +13,7 @@ type Validator func(token *JSONToken) error
func ForAudience(audience string) Validator {
return func(token *JSONToken) error {
if token.Audience != audience {
return fmt.Errorf(`token was not intended for "%s" audience`, audience)
return errors.Wrapf(ErrTokenValidationError, `token was not intended for "%s" audience`, audience)
}
return nil
}
Expand All @@ -23,7 +23,7 @@ func ForAudience(audience string) Validator {
func IdentifiedBy(jti string) Validator {
return func(token *JSONToken) error {
if token.Jti != jti {
return fmt.Errorf(`token was expected to be identified by "%s"`, jti)
return errors.Wrapf(ErrTokenValidationError, `token was expected to be identified by "%s"`, jti)
}
return nil
}
Expand All @@ -33,7 +33,7 @@ func IdentifiedBy(jti string) Validator {
func IssuedBy(issuer string) Validator {
return func(token *JSONToken) error {
if token.Issuer != issuer {
return fmt.Errorf(`token was not issued by "%s"`, issuer)
return errors.Wrapf(ErrTokenValidationError, `token was not issued by "%s"`, issuer)
}
return nil
}
Expand All @@ -43,7 +43,7 @@ func IssuedBy(issuer string) Validator {
func Subject(subject string) Validator {
return func(token *JSONToken) error {
if token.Subject != subject {
return fmt.Errorf(`token was not related to subject "%s"`, subject)
return errors.Wrapf(ErrTokenValidationError, `token was not related to subject "%s"`, subject)
}
return nil
}
Expand All @@ -54,13 +54,13 @@ func Subject(subject string) Validator {
func ValidAt(t time.Time) Validator {
return func(token *JSONToken) error {
if !token.IssuedAt.IsZero() && t.Before(token.IssuedAt) {
return errors.New("token was issued in the future")
return errors.Wrapf(ErrTokenValidationError, "token was issued in the future")
}
if !token.NotBefore.IsZero() && t.Before(token.NotBefore) {
return errors.New("token cannot be used yet")
return errors.Wrapf(ErrTokenValidationError, "token cannot be used yet")
}
if !token.Expiration.IsZero() && t.After(token.Expiration) {
return errors.New("token has expired")
return errors.Wrapf(ErrTokenValidationError, "token has expired")
}
return nil
}
Expand Down
5 changes: 3 additions & 2 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"io"

"github.com/pkg/errors"
"golang.org/x/crypto/hkdf"
)

Expand Down Expand Up @@ -66,13 +67,13 @@ func splitToken(token []byte, header []byte) (payload []byte, footer []byte, err

payload = make([]byte, tokenEncoder.DecodedLen(len(encodedPayload)))
if _, err = tokenEncoder.Decode(payload, encodedPayload); err != nil {
return nil, nil, err
return nil, nil, errors.Wrap(err, "failed to decode payload")
}

if encodedFooter != nil {
footer = make([]byte, tokenEncoder.DecodedLen(len(encodedFooter)))
if _, err = tokenEncoder.Decode(footer, encodedFooter); err != nil {
return nil, nil, err
return nil, nil, errors.Wrap(err, "failed to decode footer")
}
}

Expand Down
52 changes: 29 additions & 23 deletions v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"crypto/sha512"
"encoding/base64"
"io"

"github.com/pkg/errors"
)

const (
Expand Down Expand Up @@ -39,12 +41,12 @@ func NewV1() *PasetoV1 {
func (p *PasetoV1) Encrypt(key []byte, payload interface{}, footer interface{}) (string, error) {
payloadBytes, err := infToByteArr(payload)
if err != nil {
return "", err
return "", errors.Wrap(err, "failed to encode payload to []byte")
}

footerBytes, err := infToByteArr(footer)
if err != nil {
return "", err
return "", errors.Wrap(err, "failed to encode footer to []byte")
}

var rndBytes []byte
Expand All @@ -54,32 +56,32 @@ func (p *PasetoV1) Encrypt(key []byte, payload interface{}, footer interface{})
} else {
rndBytes = make([]byte, nonceSize)
if _, err := io.ReadFull(rand.Reader, rndBytes); err != nil {
return "", err
return "", errors.Wrap(err, "failed to read from rand.Reader")
}
}

macN := hmac.New(sha512.New384, rndBytes)
if _, err := macN.Write(payloadBytes); err != nil {
return "", err
return "", errors.Wrap(err, "failed to hash payload")
}
nonce := macN.Sum(nil)[:32]

encKey, authKey, err := splitKey(key, nonce[:16])
if err != nil {
return "", err
return "", errors.Wrap(err, "failed to create enc and auth keys")
}

block, err := aes.NewCipher(encKey)
if err != nil {
return "", err
return "", errors.Wrap(err, "failed to create aes cipher")
}

encryptedPayload := make([]byte, len(payloadBytes))
cipher.NewCTR(block, nonce[16:]).XORKeyStream(encryptedPayload, payloadBytes)

h := hmac.New(sha512.New384, authKey)
if _, err := h.Write(preAuthEncode(headerV1, nonce, encryptedPayload, footerBytes)); err != nil {
return "", err
return "", errors.Wrap(err, "failed to create a signature")
}

mac := h.Sum(nil)
Expand All @@ -96,7 +98,7 @@ func (p *PasetoV1) Encrypt(key []byte, payload interface{}, footer interface{})
func (p *PasetoV1) Decrypt(token string, key []byte, payload interface{}, footer interface{}) error {
data, footerBytes, err := splitToken([]byte(token), headerV1)
if err != nil {
return err
return errors.Wrap(err, "failed to decode token")
}

if len(data) < nonceSize+macSize {
Expand All @@ -109,34 +111,34 @@ func (p *PasetoV1) Decrypt(token string, key []byte, payload interface{}, footer

encKey, authKey, err := splitKey(key, nonce[:16])
if err != nil {
return err
return errors.Wrap(err, "failed to create enc and auth keys")
}

h := hmac.New(sha512.New384, authKey)
if _, err := h.Write(preAuthEncode(headerV1, nonce, encryptedPayload, footerBytes)); err != nil {
return err
return errors.Wrap(err, "failed to create a signature")
}

if !hmac.Equal(h.Sum(nil), mac) {
return ErrInvalidTokenAuth
return errors.Wrap(ErrInvalidTokenAuth, "failed to check token signature")
}

block, err := aes.NewCipher(encKey)
if err != nil {
return err
return errors.Wrap(err, "failed to create aes cipher")
}
decryptedPayload := make([]byte, len(encryptedPayload))
cipher.NewCTR(block, nonce[16:]).XORKeyStream(decryptedPayload, encryptedPayload)

if payload != nil {
if err := fillValue(decryptedPayload, payload); err != nil {
return err
return errors.Wrap(err, "failed to decode payload")
}
}

if footer != nil {
if err := fillValue(footerBytes, footer); err != nil {
return err
return errors.Wrap(err, "failed to decode footer")
}
}

Expand All @@ -152,25 +154,27 @@ func (p *PasetoV1) Sign(privateKey crypto.PrivateKey, payload interface{}, foote

payloadBytes, err := infToByteArr(payload)
if err != nil {
return "", err
return "", errors.Wrap(err, "failed to encode payload to []byte")
}

footerBytes, err := infToByteArr(footer)
if err != nil {
return "", err
return "", errors.Wrap(err, "failed to encode footer to []byte")
}

var opts rsa.PSSOptions
opts.SaltLength = rsa.PSSSaltLengthEqualsHash
PSSMessage := preAuthEncode(headerV1Public, payloadBytes, footerBytes)
sha384 := crypto.SHA384
pssHash := sha384.New()
pssHash.Write(PSSMessage)
if _, err := pssHash.Write(PSSMessage); err != nil {
return "", errors.Wrap(err, "failed to create pss hash")
}
hashed := pssHash.Sum(nil)

signature, err := rsa.SignPSS(rand.Reader, rsaPrivateKey, sha384, hashed, &opts)
if err != nil {
panic(err)
return "", errors.Wrap(err, "failed to sign token")
}

body := append(payloadBytes, signature...)
Expand All @@ -187,11 +191,11 @@ func (p *PasetoV1) Verify(token string, publicKey crypto.PublicKey, payload inte

data, footerBytes, err := splitToken([]byte(token), headerV1Public)
if err != nil {
return err
return errors.Wrap(err, "failed to decode token")
}

if len(data) < v1SignSize {
return ErrIncorrectTokenFormat
return errors.Wrap(ErrIncorrectTokenFormat, "incorrect signature size")
}

payloadBytes := data[:len(data)-v1SignSize]
Expand All @@ -202,7 +206,9 @@ func (p *PasetoV1) Verify(token string, publicKey crypto.PublicKey, payload inte
PSSMessage := preAuthEncode(headerV1Public, payloadBytes, footerBytes)
sha384 := crypto.SHA384
pssHash := sha384.New()
pssHash.Write(PSSMessage)
if _, err := pssHash.Write(PSSMessage); err != nil {
return errors.Wrap(err, "failed to create pss hash")
}
hashed := pssHash.Sum(nil)

if err = rsa.VerifyPSS(rsaPublicKey, sha384, hashed, signature, &opts); err != nil {
Expand All @@ -211,13 +217,13 @@ func (p *PasetoV1) Verify(token string, publicKey crypto.PublicKey, payload inte

if payload != nil {
if err := fillValue(payloadBytes, payload); err != nil {
return err
return errors.Wrap(err, "failed to decode payload")
}
}

if footer != nil {
if err := fillValue(footerBytes, footer); err != nil {
return err
return errors.Wrap(err, "failed to decode footer")
}
}

Expand Down
7 changes: 5 additions & 2 deletions v1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"encoding/pem"
"testing"

"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -308,7 +309,8 @@ func TestPasetoV1_Decrypt_Error(t *testing.T) {

for name, test := range cases {
t.Run(name, func(t *testing.T) {
assert.EqualError(t, v1.Decrypt(test.token, symmetricKey, test.payload, test.footer), test.error.Error())
err := v1.Decrypt(test.token, symmetricKey, test.payload, test.footer)
assert.Equal(t, test.error, errors.Cause(err))
})
}
}
Expand Down Expand Up @@ -374,7 +376,8 @@ func TestPasetoV1_Verify_Error(t *testing.T) {

for name, test := range cases {
t.Run(name, func(t *testing.T) {
assert.EqualError(t, v1.Verify(test.token, rsaPublicKey, test.payload, test.footer), test.error.Error())
err := v1.Verify(test.token, rsaPublicKey, test.payload, test.footer)
assert.Equal(t, test.error, errors.Cause(err))
})
}
}

0 comments on commit ca68a37

Please sign in to comment.