From db246de5230e4646281dd1a13c62dc663820c0ad Mon Sep 17 00:00:00 2001 From: Dmitriy Kinoshenko Date: Fri, 14 Feb 2020 14:17:59 +0200 Subject: [PATCH] feat: JWT library with flexible signers closes #1264 Signed-off-by: Dmitriy Kinoshenko --- pkg/doc/jwt/claims.go | 123 ++++++ pkg/doc/jwt/claims_test.go | 88 ++++ pkg/doc/jwt/jwt.go | 448 +++++++++++++++++++++ pkg/doc/jwt/jwt_test.go | 749 +++++++++++++++++++++++++++++++++++ pkg/doc/jwt/support_test.go | 122 ++++++ pkg/doc/jwt/verifier.go | 116 ++++++ pkg/doc/jwt/verifier_test.go | 171 ++++++++ 7 files changed, 1817 insertions(+) create mode 100644 pkg/doc/jwt/claims.go create mode 100644 pkg/doc/jwt/claims_test.go create mode 100644 pkg/doc/jwt/jwt.go create mode 100644 pkg/doc/jwt/jwt_test.go create mode 100644 pkg/doc/jwt/support_test.go create mode 100644 pkg/doc/jwt/verifier.go create mode 100644 pkg/doc/jwt/verifier_test.go diff --git a/pkg/doc/jwt/claims.go b/pkg/doc/jwt/claims.go new file mode 100644 index 0000000000..a79bc5b492 --- /dev/null +++ b/pkg/doc/jwt/claims.go @@ -0,0 +1,123 @@ +/* + * + * * Copyright SecureKey Technologies Inc. All Rights Reserved. + * * + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package jwt + +import ( + "encoding/json" + "errors" + "strconv" + "time" +) + +// Claims defines JWT Claims Set (https://tools.ietf.org/html/rfc7519#section-4) +type Claims struct { + + // Issuer defines iss claim. + Issuer string `json:"iss,omitempty"` + + // Subject defines sub claim. + Subject string `json:"sub,omitempty"` + + // Audience defines aud claim. + Audience Audience `json:"aud,omitempty"` + + // Expiry defines exp claim. + Expiry *NumericDate `json:"exp,omitempty"` + + // NotBefore defines nbf claim. + NotBefore *NumericDate `json:"nbf,omitempty"` + + // IssuedAt defines iat claim. + IssuedAt *NumericDate `json:"iat,omitempty"` + + // ID defines jti claim. + ID string `json:"jti,omitempty"` +} + +// NumericDate is a JSON numeric value representing the number of seconds from +// 1970-01-01T00:00:00Z UTC until the specified UTC date/time, ignoring leap seconds. +type NumericDate int64 + +// NewNumericDate creates a new NumericDate from time.Time +func NewNumericDate(t time.Time) *NumericDate { + if t.IsZero() { + return nil + } + + nd := NumericDate(t.Unix()) + + return &nd +} + +// MarshalJSON converts NumericDate to JSON bytes. +func (n NumericDate) MarshalJSON() ([]byte, error) { + return []byte(strconv.FormatInt(int64(n), 10)), nil +} + +// UnmarshalJSON defines custom unmarshalling of NumericDate from JSON bytes. +func (n *NumericDate) UnmarshalJSON(data []byte) error { + s := string(data) + + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return errors.New("not a number value") + } + + *n = NumericDate(f) + + return nil +} + +// Time provides time.Time value of NumericDate. +func (n NumericDate) Time() time.Time { + return time.Unix(int64(n), 0) +} + +// Audience identifies the recipients that the JWT is intended for. +type Audience []string + +// MarshalJSON converts Audience to JSON bytes. +func (s Audience) MarshalJSON() ([]byte, error) { + if len(s) == 1 { + return json.Marshal(s[0]) + } + + return json.Marshal([]string(s)) +} + +// UnmarshalJSON defines custom unmarshalling of Audience from JSON bytes. +func (s *Audience) UnmarshalJSON(b []byte) error { + var i interface{} + + if err := json.Unmarshal(b, &i); err != nil { + return err + } + + switch v := i.(type) { + case string: + *s = []string{v} + case []interface{}: + a := make([]string, len(v)) + + for i, e := range v { + s, ok := e.(string) + if !ok { + return errors.New("expecting string Audience item") + } + + a[i] = s + } + + *s = a + default: + return errors.New("expecting string or []interface{} Audience") + } + + return nil +} diff --git a/pkg/doc/jwt/claims_test.go b/pkg/doc/jwt/claims_test.go new file mode 100644 index 0000000000..3a09b019f0 --- /dev/null +++ b/pkg/doc/jwt/claims_test.go @@ -0,0 +1,88 @@ +/* + * Copyright SecureKey Technologies Inc. All Rights Reserved. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package jwt + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewNumericDate(t *testing.T) { + require.NotNil(t, NewNumericDate(time.Now())) + require.Nil(t, NewNumericDate(time.Time{})) +} + +func TestNumericDate_MarshalJSON(t *testing.T) { + date := NewNumericDate(time.Now()) + bytes, err := json.Marshal(date) + require.NoError(t, err) + require.NotEmpty(t, bytes) +} + +func TestNumericDate_UnmarshalJSON(t *testing.T) { + date := NewNumericDate(time.Now()) + bytes, err := json.Marshal(date) + require.NoError(t, err) + require.NotEmpty(t, bytes) + + var parsedDate NumericDate + err = json.Unmarshal(bytes, &parsedDate) + require.NoError(t, err) + + err = parsedDate.UnmarshalJSON([]byte("not a number")) + require.Error(t, err) + require.EqualError(t, err, "not a number value") +} + +func TestNumericDate_Time(t *testing.T) { + now := time.Unix(0, 0) + date := NewNumericDate(now) + require.True(t, now.Equal(date.Time())) +} + +func TestAudience_MarshalJSON(t *testing.T) { + single := Audience{"aud"} + bytes, err := json.Marshal(single) + require.NoError(t, err) + require.Equal(t, "\"aud\"", string(bytes)) + + many := Audience{"aud1", "aud2"} + bytes, err = json.Marshal(many) + require.NoError(t, err) + require.Equal(t, "[\"aud1\",\"aud2\"]", string(bytes)) +} + +func TestAudience_UnmarshalJSON(t *testing.T) { + var aud Audience + + // single + err := json.Unmarshal([]byte("\"aud\""), &aud) + require.NoError(t, err) + require.Equal(t, Audience{"aud"}, aud) + + // many + err = json.Unmarshal([]byte("[\"aud1\",\"aud2\"]"), &aud) + require.NoError(t, err) + require.Equal(t, Audience{"aud1", "aud2"}, aud) + + // invalid aud in many + err = json.Unmarshal([]byte("[\"aud1\",7]"), &aud) + require.Error(t, err) + require.EqualError(t, err, "expecting string Audience item") + + // invalid aud + err = json.Unmarshal([]byte("7"), &aud) + require.Error(t, err) + require.EqualError(t, err, "expecting string or []interface{} Audience") + + // invalid JSON + err = aud.UnmarshalJSON([]byte("invalid JSON")) + require.Error(t, err) +} diff --git a/pkg/doc/jwt/jwt.go b/pkg/doc/jwt/jwt.go new file mode 100644 index 0000000000..c7d5a22246 --- /dev/null +++ b/pkg/doc/jwt/jwt.go @@ -0,0 +1,448 @@ +/* + * Copyright SecureKey Technologies Inc. All Rights Reserved. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package jwt + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" +) + +const ( + // HeaderB64 defines b64 header of boolean type + HeaderB64 = "b64" // bool + + // HeaderType defines typ header of string type + HeaderType = "typ" // string + + // HeaderAlgorithm defines alg header of string type + HeaderAlgorithm = "alg" // string + + // HeaderKeyID defines kid header of string type + HeaderKeyID = "kid" // string + + // TypeJWT defines JWT type + TypeJWT = "JWT" + + // AlgorithmNone used to indicate unsecured JWT + AlgorithmNone = "none" + + // SignatureEdDSA defines EdDSA alg + SignatureEdDSA = "EdDSA" + + // SignatureRS256 defines RS256 alg + SignatureRS256 = "RS256" +) + +const ( + jwtPartsNumber = 3 + jwtHeaderPart = 0 + jwtPayloadPart = 1 + jwtSignaturePart = 2 +) + +// JSONWebToken defines JWT basic operations with JWT headers, payload and signature. +type JSONWebToken interface { + + // Claims fills input c with claims. It could be structure (e.g. Claims or its extension) or map[string]interface{}. + Claims(c interface{}) error + + // VerificationData build verification data which can be used for signing or verification. + VerificationData() ([]byte, error) + + // Header makes look up of particular header. + Header(name string) (interface{}, bool) + + // StringHeader makes look up of particular header with string value. + StringHeader(name string) string + + // Signature returns a signature if defined. + Signature() []byte + + // CompactSerialize makes compact serialization of token. + CompactSerialize(detached bool) (string, error) +} + +type jsonWebToken struct { + headers map[string]interface{} + + payload map[string]interface{} + + signature []byte +} + +// CompactSerialize implements compact serialization of token. +func (j *jsonWebToken) CompactSerialize(detached bool) (string, error) { + headers, err := toBase64(j.headers) + if err != nil { + return "", fmt.Errorf("base64 headers: %w", err) + } + + payload := "" + if !detached { + payload, err = toBase64(j.payload) + if err != nil { + return "", fmt.Errorf("base64 payload: %w", err) + } + } + + signature := base64.RawURLEncoding.EncodeToString(j.signature) + + return fmt.Sprintf("%s.%s.%s", headers, payload, signature), nil +} + +// Claims fills input c with JWT claims taken from payload of a token. +func (j *jsonWebToken) Claims(c interface{}) error { + pBytes, err := json.Marshal(j.payload) + if err != nil { + return err + } + + return json.Unmarshal(pBytes, c) +} + +// VerificationData build verification data which can be used for signing or verification. +func (j *jsonWebToken) VerificationData() ([]byte, error) { + return jwsVerificationData(j.headers, j.payload) +} + +// Header makes look up of particular header. +func (j *jsonWebToken) Header(name string) (interface{}, bool) { + v, ok := j.headers[name] + + return v, ok +} + +// StringHeader makes look up of particular header with string value. +func (j *jsonWebToken) StringHeader(name string) string { + if headerValue, ok := j.headers[name]; ok { + if headerStrValue, ok := headerValue.(string); ok { + return headerStrValue + } + } + + return "" +} + +// Signature returns a copy of JWT signature if defined. +func (j *jsonWebToken) Signature() []byte { + if j.signature == nil { + return nil + } + + return append(j.signature[:0:0], j.signature...) +} + +// Signer defines JWT Signer interface. It makes signing of data and provides custom JWT headers relevant to the signer. +type Signer interface { + // Sign signs. + Sign(data []byte) ([]byte, error) + + // Headers provides JWT headers. + Headers() map[string]interface{} +} + +// Builder construct JSON Web Token. +type Builder interface { + // Signed creates JWS. + Signed(signer Signer) (JSONWebToken, error) + + // Unsecured creates unsecured JWT. + Unsecured(headers map[string]interface{}) (JSONWebToken, error) +} + +// builder implements Builder. +type builder struct { + claims map[string]interface{} + + err error +} + +// Signed creates JWS. +func (b *builder) Signed(signer Signer) (JSONWebToken, error) { + if b.err != nil { + return nil, b.err + } + + // build headers + headers := signer.Headers() + + verificationData, err := jwsVerificationData(headers, b.claims) + if err != nil { + return nil, fmt.Errorf("prepare JWT verification data: %w", err) + } + + signature, err := signer.Sign(verificationData) + if err != nil { + return nil, fmt.Errorf("sign JWT verification data: %w", err) + } + + return &jsonWebToken{ + headers: headers, + payload: b.claims, + signature: signature, + }, nil +} + +// Unsecured creates unsecured JWT. +func (b *builder) Unsecured(headers map[string]interface{}) (JSONWebToken, error) { + if b.err != nil { + return nil, b.err + } + + return &jsonWebToken{ + headers: PrepareJWTUnsecuredHeaders(headers), + payload: b.claims, + }, nil +} + +// New initiates creation of a new JSON Web Tokne. claims could be anything marshallable to the map. +func New(claims interface{}) Builder { + m, err := toMap(claims) + + return &builder{ + claims: m, + err: err, + } +} + +// FromSigned parses JWS in compact serialization format. +// verifier can be nil, in this case signature verification is not made. +func FromSigned(jws string, verifier Verifier, detachedPayload []byte) (JSONWebToken, error) { + parts := strings.Split(jws, ".") + if len(parts) != jwtPartsNumber { + return nil, errors.New("invalid JWT compact format") + } + + headers, err := base64.RawURLEncoding.DecodeString(parts[jwtHeaderPart]) + if err != nil { + return nil, fmt.Errorf("decode base64 header: %w", err) + } + + var jwtHeaders map[string]interface{} + + err = json.Unmarshal(headers, &jwtHeaders) + if err != nil { + return nil, fmt.Errorf("unmarshal JSON headers: %w", err) + } + + var payload []byte + if len(detachedPayload) == 0 { + payload, err = base64.RawURLEncoding.DecodeString(parts[jwtPayloadPart]) + if err != nil { + return nil, fmt.Errorf("decode base64 payload: %w", err) + } + } else { + payload = detachedPayload + } + + var jwtPayload map[string]interface{} + + err = json.Unmarshal(payload, &jwtPayload) + if err != nil { + return nil, fmt.Errorf("unmarshal JSON payload: %w", err) + } + + signature, err := base64.RawURLEncoding.DecodeString(parts[jwtSignaturePart]) + if err != nil { + return nil, fmt.Errorf("decode base64 signature: %w", err) + } + + jwt := &jsonWebToken{ + headers: jwtHeaders, + payload: jwtPayload, + signature: signature, + } + + if verifier != nil { + err = verifier.Verify(jwt) + if err != nil { + return nil, err + } + } + + return jwt, nil +} + +// FromUnsecured parses unsecured JWT in compact serialization format. +func FromUnsecured(jwtUnsecured string, detachedPayload []byte) (JSONWebToken, error) { + parts := strings.Split(jwtUnsecured, ".") + if len(parts) != jwtPartsNumber { + return nil, errors.New("invalid JWT compact format") + } + + headers, err := base64.RawURLEncoding.DecodeString(parts[jwtHeaderPart]) + if err != nil { + return nil, fmt.Errorf("decode base64 header: %w", err) + } + + var jwtHeaders map[string]interface{} + + err = json.Unmarshal(headers, &jwtHeaders) + if err != nil { + return nil, fmt.Errorf("unmarshal JSON headers: %w", err) + } + + var payload []byte + if len(detachedPayload) == 0 { + payload, err = base64.RawURLEncoding.DecodeString(parts[jwtPayloadPart]) + if err != nil { + return nil, fmt.Errorf("decode base64 payload: %w", err) + } + } else { + payload = detachedPayload + } + + var jwtPayload map[string]interface{} + + err = json.Unmarshal(payload, &jwtPayload) + if err != nil { + return nil, fmt.Errorf("unmarshal JSON payload: %w", err) + } + + return &jsonWebToken{ + headers: jwtHeaders, + payload: jwtPayload, + signature: nil, + }, nil +} + +// PrepareJWSHeaders refines input headers with JWT typ and input alg. +func PrepareJWSHeaders(headers map[string]interface{}, alg string) map[string]interface{} { + newHeaders := make(map[string]interface{}) + + for k, v := range headers { + newHeaders[k] = v + } + + newHeaders[HeaderType] = TypeJWT + newHeaders[HeaderAlgorithm] = alg + + return newHeaders +} + +// PrepareJWTUnsecuredHeaders refines input headers with JWT typ and none alg. +func PrepareJWTUnsecuredHeaders(headers map[string]interface{}) map[string]interface{} { + newHeaders := make(map[string]interface{}) + + for k, v := range headers { + newHeaders[k] = v + } + + newHeaders[HeaderType] = TypeJWT + newHeaders[HeaderAlgorithm] = AlgorithmNone + + return newHeaders +} + +// IsCompactJWS checks that input is a JWS in compact form. +func IsCompactJWS(s string) bool { + parts := strings.Split(s, ".") + + return len(parts) == 3 && + isValidCompactJWTPart(parts[jwtHeaderPart]) && // headers JSON + (parts[jwtPayloadPart] == "" || isValidCompactJWTPart(parts[jwtPayloadPart])) && // could be detached + parts[jwtSignaturePart] != "" // not empty signature +} + +// IsCompactUnsecuredJWT checks that input is a u in compact form. +func IsCompactUnsecuredJWT(s string) bool { + parts := strings.Split(s, ".") + + return len(parts) == 3 && + isValidCompactJWTPart(parts[jwtHeaderPart]) && // headers JSON + (parts[jwtPayloadPart] == "" || isValidCompactJWTPart(parts[jwtPayloadPart])) && // could be detached + parts[jwtSignaturePart] == "" // empty signature +} + +func jwsVerificationData(headers, payload map[string]interface{}) ([]byte, error) { + headersBytes, err := json.Marshal(headers) + if err != nil { + return nil, fmt.Errorf("serialize JWT headers: %w", err) + } + + hBase64 := true + + if b64, ok := headers[HeaderB64]; ok { + if hBase64, ok = b64.(bool); !ok { + return nil, errors.New("invalid b64 header") + } + } + + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("serialize JWT payload: %w", err) + } + + headersStr := base64.RawURLEncoding.EncodeToString(headersBytes) + + var payloadStr string + + if hBase64 { + payloadStr = base64.RawURLEncoding.EncodeToString(payloadBytes) + } else { + payloadStr = string(payloadBytes) + } + + return []byte(fmt.Sprintf("%s.%s", headersStr, payloadStr)), nil +} + +func isValidCompactJWTPart(s string) bool { + b, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return false + } + + var j map[string]interface{} + err = json.Unmarshal(b, &j) + + return err == nil +} + +func toBase64(m map[string]interface{}) (string, error) { + b, err := json.Marshal(m) + if err != nil { + return "", err + } + + return base64.RawURLEncoding.EncodeToString(b), nil +} + +func toMap(i interface{}) (map[string]interface{}, error) { + if reflect.ValueOf(i).Kind() == reflect.Map { + return i.(map[string]interface{}), nil + } + + var ( + b []byte + err error + ) + + switch cv := i.(type) { + case []byte: + b = cv + case string: + b = []byte(cv) + default: + b, err = json.Marshal(i) + if err != nil { + return nil, fmt.Errorf("convert to bytes: ") + } + } + + var m map[string]interface{} + + err = json.Unmarshal(b, &m) + if err != nil { + return nil, fmt.Errorf("convert to map: %v", err) + } + + return m, nil +} diff --git a/pkg/doc/jwt/jwt_test.go b/pkg/doc/jwt/jwt_test.go new file mode 100644 index 0000000000..8ce3d01d56 --- /dev/null +++ b/pkg/doc/jwt/jwt_test.go @@ -0,0 +1,749 @@ +/* + * Copyright SecureKey Technologies Inc. All Rights Reserved. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package jwt + +import ( + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + "testing" + "time" + + "github.com/square/go-jose/v3/jwt" + + "github.com/stretchr/testify/require" +) + +type CustomClaim struct { + *Claims + + PrivateClaim1 string `json:"privateClaim1,omitempty"` +} + +func TestNew(t *testing.T) { + issued := time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC) + expiry := time.Date(2022, time.January, 1, 0, 0, 0, 0, time.UTC) + notBefore := time.Date(2021, time.January, 1, 0, 0, 0, 0, time.UTC) + + claims := &CustomClaim{ + Claims: &Claims{ + Issuer: "iss", + Subject: "sub", + Audience: []string{"aud"}, + Expiry: NewNumericDate(expiry), + NotBefore: NewNumericDate(notBefore), + IssuedAt: NewNumericDate(issued), + ID: "id", + }, + + PrivateClaim1: "private claim", + } + + t.Run("Create JWS signed by EdDSA", func(t *testing.T) { + r := require.New(t) + + pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + signer := newEd25519SingleKeySigner(privKey, nil) + + jwtToken, err := New(claims).Signed(signer) + r.NoError(err) + r.NotEmpty(jwtToken) + + jws, err := jwtToken.CompactSerialize(false) + r.NoError(err) + r.NotEmpty(jws) + + var parsedClaims CustomClaim + err = verifyEd25519ViaGoJose(jws, pubKey, &parsedClaims) + r.NoError(err) + r.Equal(*claims, parsedClaims) + + err = verifyEd25519(jws, pubKey) + r.NoError(err) + }) + + t.Run("Create JWS signed by RS256", func(t *testing.T) { + r := require.New(t) + + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + r.NoError(err) + + pubKey := &privKey.PublicKey + + signer := newRS256SingleKeySigner(privKey, nil) + + jwtToken, err := New(claims).Signed(signer) + r.NoError(err) + r.NotEmpty(jwtToken) + + jws, err := jwtToken.CompactSerialize(false) + r.NoError(err) + r.NotEmpty(jws) + + var parsedClaims CustomClaim + err = verifyRS256ViaGoJose(jws, pubKey, &parsedClaims) + r.NoError(err) + r.Equal(*claims, parsedClaims) + + err = verifyRS256(jws, pubKey) + r.NoError(err) + }) + + t.Run("Create unsecured JWT", func(t *testing.T) { + r := require.New(t) + + jwtToken, err := New(claims).Unsecured(nil) + r.NoError(err) + r.NotEmpty(jwtToken) + + jwtUnsecured, err := jwtToken.CompactSerialize(false) + r.NoError(err) + r.NotEmpty(jwtUnsecured) + + decodedJWTToken, err := FromUnsecured(jwtUnsecured, nil) + require.NoError(t, err) + require.Equal(t, jwtToken, decodedJWTToken) + }) +} + +func TestJsonWebToken_CompactSerialize(t *testing.T) { + token := getValidJSONWebToken() + + jwtSerialized, err := token.CompactSerialize(false) + require.NoError(t, err) + require.NotEmpty(t, jwtSerialized) + jwtParts := strings.Split(jwtSerialized, ".") + require.Len(t, jwtParts, 3) + require.NotEmpty(t, jwtParts[0]) + require.NotEmpty(t, jwtParts[1]) + require.NotEmpty(t, jwtParts[2]) + + jwtSerialized, err = token.CompactSerialize(true) + require.NoError(t, err) + require.NotEmpty(t, jwtSerialized) + jwtParts = strings.Split(jwtSerialized, ".") + require.Len(t, jwtParts, 3) + require.NotEmpty(t, jwtParts[0]) + require.Empty(t, jwtParts[1]) + require.NotEmpty(t, jwtParts[2]) + + jwtSerialized, err = getJSONWebTokenWithInvalidHeaders().CompactSerialize(true) + require.Error(t, err) + require.Contains(t, err.Error(), "base64 headers") + require.Empty(t, jwtSerialized) + + jwtSerialized, err = getJSONWebTokenWithInvalidPayload().CompactSerialize(false) + require.Error(t, err) + require.Contains(t, err.Error(), "base64 payload") + require.Empty(t, jwtSerialized) +} + +func TestJsonWebToken_Claims(t *testing.T) { + token := getValidJSONWebToken() + + var tokensMap map[string]interface{} + + err := token.Claims(&tokensMap) + require.NoError(t, err) + require.Equal(t, map[string]interface{}{"iss": "Albert"}, tokensMap) + + var claims Claims + + err = token.Claims(&claims) + require.NoError(t, err) + require.Equal(t, Claims{Issuer: "Albert"}, claims) + + err = getJSONWebTokenWithInvalidPayload().Claims(&claims) + require.Error(t, err) +} + +func TestJsonWebToken_VerificationData(t *testing.T) { + validToken := getValidJSONWebToken() + verificationData, err := validToken.VerificationData() + require.NoError(t, err) + require.NotEmpty(t, verificationData) + + validToken.headers["b64"] = false + verificationData, err = validToken.VerificationData() + require.NoError(t, err) + require.NotEmpty(t, verificationData) + + validToken.headers["b64"] = "not boolean" + verificationData, err = validToken.VerificationData() + require.Error(t, err) + require.EqualError(t, err, "invalid b64 header") + require.Empty(t, verificationData) + + verificationData, err = getJSONWebTokenWithInvalidHeaders().VerificationData() + require.Error(t, err) + require.Empty(t, verificationData) + + verificationData, err = getJSONWebTokenWithInvalidPayload().VerificationData() + require.Error(t, err) + require.Empty(t, verificationData) +} + +func TestJsonWebToken_Header(t *testing.T) { + token := getValidJSONWebToken() + + header, ok := token.Header("typ") + require.True(t, ok) + require.Equal(t, "JWT", header) + + header, ok = token.Header("undef") + require.False(t, ok) + require.Empty(t, header) +} + +func TestJsonWebToken_StringHeader(t *testing.T) { + token := getValidJSONWebToken() + + require.Equal(t, "JWT", token.StringHeader("typ")) + + require.Empty(t, token.StringHeader("undef")) + + token.headers["not_str"] = 55 + require.Empty(t, token.StringHeader("not_str")) +} + +type testSigner struct { + headers map[string]interface{} + signErr error + signResult []byte +} + +func (t testSigner) Sign(_ []byte) ([]byte, error) { + return t.signResult, t.signErr +} + +func (t testSigner) Headers() map[string]interface{} { + return t.headers +} + +func TestBuilder_Signed(t *testing.T) { + jwtBuilder := &builder{ + claims: map[string]interface{}{"iss": "Albert"}, + } + + signer := &testSigner{ + headers: map[string]interface{}{"typ": "JWT", "alg": "EdDSA"}, + signErr: nil, + signResult: []byte("sign result"), + } + + token, err := jwtBuilder.Signed(signer) + require.NoError(t, err) + require.NotNil(t, token) + + signerWithInvalidHeaders := &testSigner{ + headers: getUnmarshallableMap(), + signErr: nil, + signResult: []byte("sign result"), + } + token, err = jwtBuilder.Signed(signerWithInvalidHeaders) + require.Error(t, err) + require.Contains(t, err.Error(), "prepare JWT verification data") + require.Nil(t, token) + + signerWithSignError := &testSigner{ + headers: map[string]interface{}{"typ": "JWT", "alg": "EdDSA"}, + signErr: errors.New("signature error"), + signResult: []byte("sign result"), + } + token, err = jwtBuilder.Signed(signerWithSignError) + require.Error(t, err) + require.EqualError(t, err, "sign JWT verification data: signature error") + require.Nil(t, token) + + jwtBuilderWithError := &builder{ + err: errors.New("builder error"), + } + token, err = jwtBuilderWithError.Signed(signer) + require.Error(t, err) + require.EqualError(t, err, "builder error") + require.Nil(t, token) +} + +func TestBuilder_Unsecured(t *testing.T) { + jwtBuilder := &builder{ + claims: map[string]interface{}{"iss": "Albert"}, + } + headers := map[string]interface{}{"typ": "JWT", "alg": "EdDSA"} + + token, err := jwtBuilder.Unsecured(headers) + require.NoError(t, err) + require.NotNil(t, token) + + jwtBuilderWithError := &builder{ + err: errors.New("builder error"), + } + token, err = jwtBuilderWithError.Unsecured(headers) + require.Error(t, err) + require.EqualError(t, err, "builder error") + require.Nil(t, token) +} + +type testVerifier struct { + err error +} + +func (t testVerifier) Verify(_ JSONWebToken) error { + return t.err +} + +func TestFromSigned(t *testing.T) { + t.Run("Successful FromSigned() with not detached payload", func(t *testing.T) { + token, err := New(Claims{Issuer: "Albert"}).Signed(&testSigner{signResult: []byte("signature")}) + require.NoError(t, err) + + jws, err := token.CompactSerialize(false) + require.NoError(t, err) + + decodedToken, err := FromSigned(jws, &testVerifier{err: nil}, nil) + require.NoError(t, err) + require.Equal(t, token, decodedToken) + }) + + t.Run("Successful FromSigned() with detached payload", func(t *testing.T) { + token, err := New(Claims{Issuer: "Albert"}).Signed(&testSigner{signResult: []byte("signature")}) + require.NoError(t, err) + + jws, err := token.CompactSerialize(false) + require.NoError(t, err) + + // keep payload + jwsSplit := strings.Split(jws, ".") + jwsPayload, err := base64.RawURLEncoding.DecodeString(jwsSplit[1]) + require.NoError(t, err) + + jws, err = token.CompactSerialize(true) + require.NoError(t, err) + + decodedToken, err := FromSigned(jws, &testVerifier{err: nil}, jwsPayload) + require.NoError(t, err) + require.Equal(t, token, decodedToken) + }) + + t.Run("FromSigned() failures", func(t *testing.T) { + validToken, err := New(Claims{Issuer: "Albert"}).Signed(&testSigner{signResult: []byte("signature")}) + require.NoError(t, err) + + validJWS, err := validToken.CompactSerialize(false) + require.NoError(t, err) + + validJWSParts := strings.Split(validJWS, ".") + + token, err := FromSigned("invalid jws", nil, nil) + require.Error(t, err) + require.EqualError(t, err, "invalid JWT compact format") + require.Nil(t, token) + + jwsWithInvalidHeaders := fmt.Sprintf("%s.%s.%s", "invalid", validJWSParts[1], validJWSParts[2]) + token, err = FromSigned(jwsWithInvalidHeaders, nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "unmarshal JSON headers") + require.Nil(t, token) + + corruptedBased64 := "XXXXXaGVsbG8=" + + jwsWithInvalidHeaders = fmt.Sprintf("%s.%s.%s", corruptedBased64, validJWSParts[1], validJWSParts[2]) + token, err = FromSigned(jwsWithInvalidHeaders, nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "decode base64 header") + require.Nil(t, token) + + jwsWithInvalidPayload := fmt.Sprintf("%s.%s.%s", validJWSParts[0], "invalid", validJWSParts[2]) + token, err = FromSigned(jwsWithInvalidPayload, nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "unmarshal JSON payload") + require.Nil(t, token) + + jwsWithInvalidPayload = fmt.Sprintf("%s.%s.%s", validJWSParts[0], corruptedBased64, validJWSParts[2]) + token, err = FromSigned(jwsWithInvalidPayload, nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "decode base64 payload") + require.Nil(t, token) + + jwsWithInvalidSignature := fmt.Sprintf("%s.%s.%s", validJWSParts[0], validJWSParts[1], corruptedBased64) + token, err = FromSigned(jwsWithInvalidSignature, nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "decode base64 signature") + require.Nil(t, token) + + token, err = FromSigned(validJWS, &testVerifier{err: errors.New("verification error")}, nil) + require.Error(t, err) + require.EqualError(t, err, "verification error") + require.Nil(t, token) + }) +} + +func TestFromUnsecured(t *testing.T) { + t.Run("Successful FromUnsecured() with not detached payload", func(t *testing.T) { + token, err := New(Claims{Issuer: "Albert"}).Unsecured(nil) + require.NoError(t, err) + + jws, err := token.CompactSerialize(false) + require.NoError(t, err) + + decodedToken, err := FromUnsecured(jws, nil) + require.NoError(t, err) + require.Equal(t, token, decodedToken) + }) + + t.Run("Successful FromUnsecured() with detached payload", func(t *testing.T) { + token, err := New(Claims{Issuer: "Albert"}).Unsecured(nil) + require.NoError(t, err) + + jws, err := token.CompactSerialize(false) + require.NoError(t, err) + + // keep payload + jwsSplit := strings.Split(jws, ".") + jwsPayload, err := base64.RawURLEncoding.DecodeString(jwsSplit[1]) + require.NoError(t, err) + + jws, err = token.CompactSerialize(true) + require.NoError(t, err) + + decodedToken, err := FromUnsecured(jws, jwsPayload) + require.NoError(t, err) + require.Equal(t, token, decodedToken) + }) + + t.Run("FromUnsecured() failures", func(t *testing.T) { + validToken, err := New(Claims{Issuer: "Albert"}).Unsecured(nil) + require.NoError(t, err) + + validJWS, err := validToken.CompactSerialize(false) + require.NoError(t, err) + + validJWSParts := strings.Split(validJWS, ".") + + token, err := FromUnsecured("invalid jws", nil) + require.Error(t, err) + require.EqualError(t, err, "invalid JWT compact format") + require.Nil(t, token) + + jwsWithInvalidHeaders := fmt.Sprintf("%s.%s.%s", "invalid", validJWSParts[1], validJWSParts[2]) + token, err = FromUnsecured(jwsWithInvalidHeaders, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "unmarshal JSON headers") + require.Nil(t, token) + + corruptedBased64 := "XXXXXaGVsbG8=" + + jwsWithInvalidHeaders = fmt.Sprintf("%s.%s.%s", corruptedBased64, validJWSParts[1], validJWSParts[2]) + token, err = FromUnsecured(jwsWithInvalidHeaders, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "decode base64 header") + require.Nil(t, token) + + jwsWithInvalidPayload := fmt.Sprintf("%s.%s.%s", validJWSParts[0], "invalid", validJWSParts[2]) + token, err = FromUnsecured(jwsWithInvalidPayload, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "unmarshal JSON payload") + require.Nil(t, token) + + jwsWithInvalidPayload = fmt.Sprintf("%s.%s.%s", validJWSParts[0], corruptedBased64, validJWSParts[2]) + token, err = FromUnsecured(jwsWithInvalidPayload, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "decode base64 payload") + require.Nil(t, token) + }) +} + +func TestPrepareJWSHeaders(t *testing.T) { + headers := map[string]interface{}{ + "kid": "did:example:abfe13f712120431c276e12ecab#keys-1", + "b64": false, + } + + jwsHeaders := PrepareJWSHeaders(headers, "EdDSA") + + require.Equal(t, map[string]interface{}{ + "kid": "did:example:abfe13f712120431c276e12ecab#keys-1", + "b64": false, + "typ": "JWT", + "alg": "EdDSA", + }, jwsHeaders) +} + +func TestPrepareJWTUnsecuredHeaders(t *testing.T) { + headers := map[string]interface{}{ + "kid": "did:example:abfe13f712120431c276e12ecab#keys-1", + "b64": false, + } + + jwsHeaders := PrepareJWTUnsecuredHeaders(headers) + + require.Equal(t, map[string]interface{}{ + "kid": "did:example:abfe13f712120431c276e12ecab#keys-1", + "b64": false, + "typ": "JWT", + "alg": "none", + }, jwsHeaders) +} + +func getValidJSONWebToken() *jsonWebToken { + return &jsonWebToken{ + headers: map[string]interface{}{"typ": "JWT", "alg": "EdDSA"}, + payload: map[string]interface{}{"iss": "Albert"}, + signature: []byte("signature"), + } +} + +func getJSONWebTokenWithInvalidHeaders() *jsonWebToken { + return &jsonWebToken{ + headers: getUnmarshallableMap(), + payload: map[string]interface{}{"iss": "Albert"}, + signature: []byte("signature"), + } +} + +func getJSONWebTokenWithInvalidPayload() *jsonWebToken { + return &jsonWebToken{ + headers: map[string]interface{}{"typ": "JWT", "alg": "EdDSA"}, + payload: getUnmarshallableMap(), + signature: []byte("signature")} +} + +func verifyEd25519ViaGoJose(jws string, pubKey ed25519.PublicKey, claims interface{}) error { + jwtToken, err := jwt.ParseSigned(jws) + if err != nil { + return fmt.Errorf("parse VC from signed JWS: %w", err) + } + + if err = jwtToken.Claims(pubKey, claims); err != nil { + return fmt.Errorf("verify JWT signature: %w", err) + } + + return nil +} + +func verifyRS256ViaGoJose(jws string, pubKey *rsa.PublicKey, claims interface{}) error { + jwtToken, err := jwt.ParseSigned(jws) + if err != nil { + return fmt.Errorf("parse VC from signed JWS: %w", err) + } + + if err = jwtToken.Claims(pubKey, claims); err != nil { + return fmt.Errorf("verify JWT signature: %w", err) + } + + return nil +} + +func verifyEd25519(jws string, pubKey ed25519.PublicKey) error { + verifier, err := newEd25519SingleKeyVerifier(pubKey) + if err != nil { + return err + } + + _, err = FromSigned(jws, verifier, nil) + + if err != nil { + return err + } + + return nil +} + +func verifyRS256(jws string, pubKey *rsa.PublicKey) error { + verifier := newRS256SingleKeyVerifier(pubKey) + + _, err := FromSigned(jws, verifier, nil) + + if err != nil { + return err + } + + return nil +} + +func TestIsCompactJWS(t *testing.T) { + b64 := base64.RawURLEncoding.EncodeToString([]byte("not json")) + j, err := json.Marshal(map[string]string{"alg": "none"}) + require.NoError(t, err) + + jb64 := base64.RawURLEncoding.EncodeToString(j) + + type args struct { + data string + } + + tests := []struct { + name string + args args + want bool + }{ + { + name: "two parts only", + args: args{"two parts.only"}, + want: false, + }, + { + name: "empty third part", + args: args{"empty third.part."}, + want: false, + }, + { + name: "part 1 is not base64 decoded", + args: args{"not base64.part2.part3"}, + want: false, + }, + { + name: "part 1 is not JSON", + args: args{fmt.Sprintf("%s.part2.part3", b64)}, + want: false, + }, + { + name: "part 2 is not base64 decoded", + args: args{fmt.Sprintf("%s.not base64.part3", jb64)}, + want: false, + }, + { + name: "part 2 is not JSON", + args: args{fmt.Sprintf("%s.%s.part3", jb64, b64)}, + want: false, + }, + { + name: "is JWS", + args: args{fmt.Sprintf("%s.%s.signature", jb64, jb64)}, + want: true, + }, + } + + for i := range tests { + tt := tests[i] + t.Run(tt.name, func(t *testing.T) { + if got := IsCompactJWS(tt.args.data); got != tt.want { + t.Errorf("isJWS() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsCompactUnsecuredJWT(t *testing.T) { + b64 := base64.RawURLEncoding.EncodeToString([]byte("not json")) + j, err := json.Marshal(map[string]string{"alg": "none"}) + require.NoError(t, err) + + jb64 := base64.RawURLEncoding.EncodeToString(j) + + type args struct { + data string + } + + tests := []struct { + name string + args args + want bool + }{ + { + name: "two parts only", + args: args{"two parts.only"}, + want: false, + }, + { + name: "empty third part", + args: args{"empty third.part."}, + want: false, + }, + { + name: "part 1 is not base64 decoded", + args: args{"not base64.part2.part3"}, + want: false, + }, + { + name: "part 1 is not JSON", + args: args{fmt.Sprintf("%s.part2.part3", b64)}, + want: false, + }, + { + name: "part 2 is not base64 decoded", + args: args{fmt.Sprintf("%s.not base64.part3", jb64)}, + want: false, + }, + { + name: "part 2 is not JSON", + args: args{fmt.Sprintf("%s.%s.part3", jb64, b64)}, + want: false, + }, + { + name: "is JWT unsecured", + args: args{fmt.Sprintf("%s.%s.signature", jb64, jb64)}, + want: false, + }, + { + name: "is JWS, not JWT unsecured", + args: args{fmt.Sprintf("%s.%s.", jb64, jb64)}, + want: true, + }, + } + + for i := range tests { + tt := tests[i] + t.Run(tt.name, func(t *testing.T) { + if got := IsCompactUnsecuredJWT(tt.args.data); got != tt.want { + t.Errorf("isJWS() = %v, want %v", got, tt.want) + } + }) + } +} + +type testToMapStruct struct { + TestField string `json:"a"` +} + +func Test_toMap(t *testing.T) { + inputMap := map[string]interface{}{"a": "b"} + + r := require.New(t) + + // pass map + resultMap, err := toMap(inputMap) + r.NoError(err) + r.Equal(inputMap, resultMap) + + // pass []byte + inputMapBytes, err := json.Marshal(inputMap) + r.NoError(err) + resultMap, err = toMap(inputMapBytes) + r.NoError(err) + r.Equal(inputMap, resultMap) + + // pass string + inputMapStr := string(inputMapBytes) + resultMap, err = toMap(inputMapStr) + r.NoError(err) + r.Equal(inputMap, resultMap) + + // pass struct + s := testToMapStruct{TestField: "b"} + resultMap, err = toMap(s) + r.NoError(err) + r.Equal(inputMap, resultMap) + + // pass invalid []byte + resultMap, err = toMap([]byte("not JSON")) + r.Error(err) + r.Contains(err.Error(), "convert to map") + r.Nil(resultMap) + + // pass invalid structure + resultMap, err = toMap(make(chan int)) + r.Error(err) + r.Contains(err.Error(), "convert to bytes") + r.Nil(resultMap) +} diff --git a/pkg/doc/jwt/support_test.go b/pkg/doc/jwt/support_test.go new file mode 100644 index 0000000000..c1e4cc0567 --- /dev/null +++ b/pkg/doc/jwt/support_test.go @@ -0,0 +1,122 @@ +/* + * Copyright SecureKey Technologies Inc. All Rights Reserved. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package jwt + +import ( + "crypto" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "errors" +) + +type ed25519SingleKeySigner struct { + privKey []byte + headers map[string]interface{} +} + +func newEd25519SingleKeySigner(privKey []byte, headers map[string]interface{}) *ed25519SingleKeySigner { + return &ed25519SingleKeySigner{ + privKey: privKey, + headers: PrepareJWSHeaders(headers, SignatureEdDSA), + } +} + +func (s ed25519SingleKeySigner) Sign(data []byte) ([]byte, error) { + return ed25519.Sign(s.privKey, data), nil +} + +func (s ed25519SingleKeySigner) Headers() map[string]interface{} { + return s.headers +} + +type ed25519SingleKeyVerifier struct { + pubKey []byte +} + +func newEd25519SingleKeyVerifier(pubKey []byte) (*ed25519SingleKeyVerifier, error) { + if l := len(pubKey); l != ed25519.PublicKeySize { + return nil, errors.New("bad ed25519 public key length") + } + + return &ed25519SingleKeyVerifier{pubKey: pubKey}, nil +} + +func (v ed25519SingleKeyVerifier) Verify(j JSONWebToken) error { + pubKey := v.pubKey + signature := j.Signature() + + claims, err := j.VerificationData() + if err != nil { + return err + } + + if ok := ed25519.Verify(pubKey, claims, signature); !ok { + return errors.New("signature doesn't match") + } + + return nil +} + +type rs256SingleKeySigner struct { + privKey *rsa.PrivateKey + headers map[string]interface{} +} + +func newRS256SingleKeySigner(privKey *rsa.PrivateKey, headers map[string]interface{}) *rs256SingleKeySigner { + return &rs256SingleKeySigner{ + privKey: privKey, + headers: PrepareJWSHeaders(headers, SignatureRS256), + } +} + +func (s rs256SingleKeySigner) Sign(data []byte) ([]byte, error) { + hash := crypto.SHA256.New() + + _, err := hash.Write(data) + if err != nil { + return nil, err + } + + hashed := hash.Sum(nil) + + return rsa.SignPKCS1v15(rand.Reader, s.privKey, crypto.SHA256, hashed) +} + +func (s rs256SingleKeySigner) Headers() map[string]interface{} { + return s.headers +} + +type rs256SingleKeyVerifier struct { + pubKey *rsa.PublicKey +} + +func newRS256SingleKeyVerifier(pubKey *rsa.PublicKey) *rs256SingleKeyVerifier { + return &rs256SingleKeyVerifier{pubKey: pubKey} +} + +func (v rs256SingleKeyVerifier) Verify(j JSONWebToken) error { + verificationData, err := j.VerificationData() + if err != nil { + return err + } + + hash := crypto.SHA256.New() + + _, err = hash.Write(verificationData) + if err != nil { + return err + } + + hashed := hash.Sum(nil) + + return rsa.VerifyPKCS1v15(v.pubKey, crypto.SHA256, hashed, j.Signature()) +} + +func getUnmarshallableMap() map[string]interface{} { + return map[string]interface{}{"error": map[chan int]interface{}{make(chan int): 6}} +} diff --git a/pkg/doc/jwt/verifier.go b/pkg/doc/jwt/verifier.go new file mode 100644 index 0000000000..cab309f349 --- /dev/null +++ b/pkg/doc/jwt/verifier.go @@ -0,0 +1,116 @@ +/* + * Copyright SecureKey Technologies Inc. All Rights Reserved. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package jwt + +import ( + "crypto" + "crypto/ed25519" + "crypto/rsa" + "errors" + "fmt" +) + +// Verifier makes verification of JSON Web Token. +type Verifier interface { + + // Verify verifies JWT. + Verify(j JSONWebToken) error +} + +// KeyResolver resolves public key based on what and kid. +type KeyResolver interface { + + // Resolve resolves public key. + Resolve(what, kid string) (interface{}, error) +} + +type basicVerifier struct { + resolver KeyResolver +} + +// NewVerifier creates a new basic Verifier. +func NewVerifier(resolver KeyResolver) Verifier { + return &basicVerifier{resolver: resolver} +} + +// Verify verifies JSON Web Token. +func (v *basicVerifier) Verify(j JSONWebToken) error { + claims := Claims{} + + err := j.Claims(&claims) + if err != nil { + return fmt.Errorf("read claims from JSON Web Token: %w", err) + } + + kid := j.StringHeader(HeaderKeyID) + + pubKey, err := v.resolver.Resolve(claims.Issuer, kid) + if err != nil { + return err + } + + alg := j.StringHeader(HeaderAlgorithm) + if alg == "" { + return errors.New("JWS alg is not defined") + } + + verificationData, err := j.VerificationData() + if err != nil { + return fmt.Errorf("build JWS verification data: %w", err) + } + + switch alg { + case SignatureEdDSA: + return VerifyEdDSA(pubKey, verificationData, j.Signature()) + + case SignatureRS256: + return VerifyRS256(pubKey, verificationData, j.Signature()) + + default: + return fmt.Errorf("unsupported alg: %s", alg) + } +} + +// VerifyEdDSA verifies EdDSA signature. +func VerifyEdDSA(pubKey interface{}, message, signature []byte) error { + pubKeyEdDSA, ok := pubKey.([]byte) + if !ok { + pubKeyEdDSA, ok = pubKey.(ed25519.PublicKey) + if !ok { + return errors.New("not []byte or *ed25519.PublicKey public key") + } + } + + if l := len(pubKeyEdDSA); l != ed25519.PublicKeySize { + return errors.New("bad ed25519 public key length") + } + + if ok := ed25519.Verify(pubKeyEdDSA, message, signature); !ok { + return errors.New("signature doesn't match") + } + + return nil +} + +// VerifyRS256 verifies RS256 signature. +func VerifyRS256(pubKey interface{}, message, signature []byte) error { + pubKeyRsa, ok := pubKey.(*rsa.PublicKey) + if !ok { + return errors.New("not *rsa.PublicKey public key") + } + + hash := crypto.SHA256.New() + + _, err := hash.Write(message) + if err != nil { + return err + } + + hashed := hash.Sum(nil) + + return rsa.VerifyPKCS1v15(pubKeyRsa, crypto.SHA256, hashed, signature) +} diff --git a/pkg/doc/jwt/verifier_test.go b/pkg/doc/jwt/verifier_test.go new file mode 100644 index 0000000000..d4ffee82c4 --- /dev/null +++ b/pkg/doc/jwt/verifier_test.go @@ -0,0 +1,171 @@ +/* + * Copyright SecureKey Technologies Inc. All Rights Reserved. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package jwt + +import ( + "crypto" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +type testKeyResolver struct { + pubKey interface{} + err error +} + +func (r testKeyResolver) Resolve(_, _ string) (interface{}, error) { + return r.pubKey, r.err +} + +func TestNewVerifier(t *testing.T) { + r := require.New(t) + + t.Run("Verify JWT signed by EdDSA", func(t *testing.T) { + pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + signer := newEd25519SingleKeySigner(privKey, nil) + + token, err := New(&Claims{Issuer: "Mike"}).Signed(signer) + r.NoError(err) + + verifier := NewVerifier(&testKeyResolver{pubKey: pubKey}) + err = verifier.Verify(token) + r.NoError(err) + }) + + t.Run("Verify JWT signed by RS256", func(t *testing.T) { + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + r.NoError(err) + + pubKey := &privKey.PublicKey + + signer := newRS256SingleKeySigner(privKey, nil) + + token, err := New(&Claims{Issuer: "Mike"}).Signed(signer) + r.NoError(err) + + verifier := NewVerifier(&testKeyResolver{pubKey: pubKey}) + err = verifier.Verify(token) + r.NoError(err) + }) +} + +func TestBasicVerifier_Verify(t *testing.T) { // error corner cases + r := require.New(t) + + pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + signer := newEd25519SingleKeySigner(privKey, nil) + + validToken, err := New(&Claims{Issuer: "Mike"}).Signed(signer) + r.NoError(err) + + verifier := NewVerifier(&testKeyResolver{pubKey: pubKey}) + + // Invalid token claims + token := &jsonWebToken{payload: getUnmarshallableMap()} + err = verifier.Verify(token) + r.Error(err) + r.Contains(err.Error(), "read claims from JSON Web Token") + + // no JWS algorithm + token = &jsonWebToken{ + payload: validToken.(*jsonWebToken).payload, + headers: map[string]interface{}{}, // alg header is missing + } + err = verifier.Verify(token) + r.Error(err) + r.EqualError(err, "JWS alg is not defined") + + // failed to build verification data + token = &jsonWebToken{ + payload: validToken.(*jsonWebToken).payload, + headers: map[string]interface{}{"alg": "EdDSA", "b64": "incorrect value"}, + } + err = verifier.Verify(token) + r.Error(err) + r.Contains(err.Error(), "build JWS verification data") + + // unsupported algorithm + token = &jsonWebToken{ + payload: validToken.(*jsonWebToken).payload, + headers: map[string]interface{}{"alg": "unknown"}, + } + err = verifier.Verify(token) + r.Error(err) + r.EqualError(err, "unsupported alg: unknown") + + // key resolver error + verifier = NewVerifier(&testKeyResolver{err: errors.New("failed to resolve public key")}) + err = verifier.Verify(validToken) + r.Error(err) + r.Contains(err.Error(), "failed to resolve public key") +} + +func TestVerifyEdDSA(t *testing.T) { + r := require.New(t) + + pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + signature := ed25519.Sign(privKey, []byte("test message")) + + err = VerifyEdDSA(pubKey, []byte("test message"), signature) + r.NoError(err) + + err = VerifyEdDSA([]byte("invalid pub key"), []byte("test message"), signature) + r.Error(err) + r.EqualError(err, "bad ed25519 public key length") + + anotherPubKey, _, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + err = VerifyEdDSA(anotherPubKey, []byte("test message"), signature) + r.Error(err) + r.EqualError(err, "signature doesn't match") + + err = VerifyEdDSA("not EdDSA public key", []byte("test message"), signature) + r.Error(err) + r.EqualError(err, "not []byte or *ed25519.PublicKey public key") +} + +func TestVerifyRS256(t *testing.T) { + r := require.New(t) + + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + r.NoError(err) + + hash := crypto.SHA256.New() + + _, err = hash.Write([]byte("test message")) + r.NoError(err) + + hashed := hash.Sum(nil) + + signature, err := rsa.SignPKCS1v15(rand.Reader, privKey, crypto.SHA256, hashed) + r.NoError(err) + + err = VerifyRS256(&privKey.PublicKey, []byte("test message"), signature) + r.NoError(err) + + anotherPrivKey, err := rsa.GenerateKey(rand.Reader, 2048) + r.NoError(err) + + err = VerifyRS256(&anotherPrivKey.PublicKey, []byte("test message"), signature) + r.Error(err) + + err = VerifyRS256("not RS256 public key", []byte("test message"), signature) + r.Error(err) + r.EqualError(err, "not *rsa.PublicKey public key") +}