From 9832bc458afb70e69d3950c1d528b214d468a69a Mon Sep 17 00:00:00 2001 From: Dmitriy Kinoshenko Date: Fri, 14 Feb 2020 14:17:59 +0200 Subject: [PATCH] feat: JWT library with flexible signers closes #1264 Signed-off-by: Dmitriy Kinoshenko --- pkg/doc/jose/common.go | 59 ++++++ pkg/doc/jose/jws.go | 380 ++++++++++++++++++++++++++++++++++ pkg/doc/jose/jws_test.go | 300 +++++++++++++++++++++++++++ pkg/doc/jwt/jwt.go | 271 ++++++++++++++++++++++++ pkg/doc/jwt/jwt_test.go | 389 +++++++++++++++++++++++++++++++++++ pkg/doc/jwt/support_test.go | 177 ++++++++++++++++ pkg/doc/jwt/verifier.go | 150 ++++++++++++++ pkg/doc/jwt/verifier_test.go | 167 +++++++++++++++ 8 files changed, 1893 insertions(+) create mode 100644 pkg/doc/jose/common.go create mode 100644 pkg/doc/jose/jws.go create mode 100644 pkg/doc/jose/jws_test.go create mode 100644 pkg/doc/jwt/jwt.go create mode 100644 pkg/doc/jwt/jwt_test.go create mode 100644 pkg/doc/jwt/support_test.go create mode 100644 pkg/doc/jwt/verifier.go create mode 100644 pkg/doc/jwt/verifier_test.go diff --git a/pkg/doc/jose/common.go b/pkg/doc/jose/common.go new file mode 100644 index 0000000000..d436998e6b --- /dev/null +++ b/pkg/doc/jose/common.go @@ -0,0 +1,59 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package jose + +// IANA registered JOSE headers (https://tools.ietf.org/html/rfc7515#section-4.1) +const ( + // HeaderAlgorithm identifies the cryptographic algorithm used to secure the JWS. + HeaderAlgorithm = "alg" // string + + // HeaderJWKSetURL is a URI that refers to a resource for a set of JSON-encoded public keys, one of + // which corresponds to the key used to digitally sign the JWS. + HeaderJWKSetURL = "jku" // string + + // HeaderJSONWebKey is the public key that corresponds to the key used to digitally sign the JWS. + HeaderJSONWebKey = "jwk" // JSON + + // HeaderKeyID is a hint indicating which key was used to secure the JWS. + HeaderKeyID = "kid" // string + + // HeaderX509URL is a URI that refers to a resource for the X.509 public key certificate or certificate + // chain corresponding to the key used to digitally sign the JWS. + HeaderX509URL = "x5u" + + // HeaderX509CertificateChain contains the X.509 public key certificate or certificate chain + // corresponding to the key used to digitally sign the JWS. + HeaderX509CertificateChain = "x5c" + + // HeaderX509CertificateDigest (X.509 certificate SHA-1 thumbprint) is a base64url-encoded + // SHA-1 thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate corresponding to the key + // used to digitally sign the JWS. + HeaderX509CertificateDigestSha1 = "x5t" + + // HeaderX509CertificateDigestSha256 (X.509 certificate SHA-256 thumbprint) is a base64url-encoded SHA-256 + // thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate corresponding to the key used to + // digitally sign the JWS. + HeaderX509CertificateDigestSha256 = "x5t#S256" // string + + // HeaderType is used by JWS applications to declare the media type of this complete JWS. + HeaderType = "typ" // string + + // HeaderContentType is used by JWS applications to declare the media type of the + // secured content (the payload). + HeaderContentType = "cty" // string + + // HeaderCritical indicates that extensions to this specification and/or are being used that MUST be + // understood and processed. + HeaderCritical = "crit" // array +) + +// Header defined in https://tools.ietf.org/html/rfc7797 +const ( + // HeaderB64 determines whether the payload is represented in the JWS and the JWS Signing + // Input as ASCII(BASE64URL(JWS Payload)) or as the JWS Payload value itself with no encoding performed. + HeaderB64Payload = "b64" // bool +) diff --git a/pkg/doc/jose/jws.go b/pkg/doc/jose/jws.go new file mode 100644 index 0000000000..70e10cc697 --- /dev/null +++ b/pkg/doc/jose/jws.go @@ -0,0 +1,380 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package jose + +import ( + "encoding/base64" + "errors" + "fmt" + "strings" + + "github.com/square/go-jose/v3" + "github.com/square/go-jose/v3/json" +) + +const ( + jwsPartsCount = 3 + jwsHeaderPart = 0 + jwsPayloadPart = 1 + jwsSignaturePart = 2 +) + +// Headers represents JOSE headers. +type Headers map[string]interface{} + +// JWK (JSON Web Key) is a JSON data structure that represents a cryptographic key. +type JWK jose.JSONWebKey + +// GetJWK gets JWK from JOSE headers. +func (jh Headers) GetJWK() (*JWK, bool) { + jwkRaw, ok := jh[HeaderJSONWebKey] + if !ok { + return nil, false + } + + var jwk JWK + + err := convertMapToValue(jwkRaw, &jwk) + if err != nil { + return nil, false + } + + return &jwk, true +} + +// GetKeyID gets Key ID from JOSE headers. +func (jh Headers) GetKeyID() (string, bool) { + return jh.stringValue(HeaderKeyID) +} + +// GetAlgorithm gets Key ID from JOSE headers. +func (jh Headers) GetAlgorithm() (string, bool) { + return jh.stringValue(HeaderAlgorithm) +} + +func (jh Headers) stringValue(key string) (string, bool) { + kRaw, ok := jh[key] + if !ok { + return "", false + } + + kStr, ok := kRaw.(string) + + return kStr, ok +} + +// JSONWebSignature defines JSON Web Signature (https://tools.ietf.org/html/rfc7515) +type JSONWebSignature struct { + ProtectedHeaders Headers + UnprotectedHeaders Headers + Payload []byte + + signature []byte +} + +// SignatureVerifier makes verification of JSON Web Signature. +type SignatureVerifier interface { + // Verify verifies JWS based on the signing input. + Verify(joseHeaders Headers, payload, signingInput, signature []byte) error +} + +// SignatureVerifierFunc is a function wrapper for SignatureVerifier. +type SignatureVerifierFunc func(joseHeaders Headers, payload, signingInput, signature []byte) error + +// Verify verifies JWS signature. +func (s SignatureVerifierFunc) Verify(joseHeaders Headers, payload, signingInput, signature []byte) error { + return s(joseHeaders, payload, signingInput, signature) +} + +// CompositeAlgSignatureVerifier defines composite signature verifier based on the algorithm +// taken from JOSE header alg. +type CompositeAlgSignatureVerifier struct { + verifierByAlg map[string]SignatureVerifier +} + +// NewCompositeAlgSignatureVerifier creates a new CompositeAlgSignatureVerifier +func NewCompositeAlgSignatureVerifier() *CompositeAlgSignatureVerifier { + return &CompositeAlgSignatureVerifier{ + verifierByAlg: make(map[string]SignatureVerifier), + } +} + +// AddVerifier adds a new verifier of algorithm. +// If algorithm verifier is already present, it's get rewritten. +func (v *CompositeAlgSignatureVerifier) AddVerifier( + alg string, verifier SignatureVerifier) *CompositeAlgSignatureVerifier { + v.verifierByAlg[alg] = verifier + return v +} + +// Verify verifiers JWS signature. +func (v *CompositeAlgSignatureVerifier) Verify(joseHeaders Headers, payload, signingInput, signature []byte) error { + alg, ok := joseHeaders.GetAlgorithm() + if !ok { + return errors.New("'alg' JOSE header is not present") + } + + verifier, ok := v.verifierByAlg[alg] + if !ok { + return fmt.Errorf("no verifier found for %s algorithm", alg) + } + + return verifier.Verify(joseHeaders, payload, signingInput, signature) +} + +// Signer defines JWS Signer interface. It makes signing of data and provides custom JWS headers relevant to the signer. +type Signer interface { + // Sign signs. + Sign(data []byte) ([]byte, error) + + // Headers provides JWS headers. "alg" header must be provided (see https://tools.ietf.org/html/rfc7515#section-4.1) + Headers() Headers +} + +// NewJWS creates JSON Web Signature. +func NewJWS(protectedHeaders, unprotectedHeaders Headers, payload []byte) *JSONWebSignature { + return &JSONWebSignature{ + ProtectedHeaders: protectedHeaders, + UnprotectedHeaders: unprotectedHeaders, + Payload: payload, + } +} + +// SerializeCompact makes JWS Compact Serialization (https://tools.ietf.org/html/rfc7515#section-7.1) +func (s *JSONWebSignature) SerializeCompact(signer Signer, detached bool) (string, error) { + err := s.sign(signer) + if err != nil { + return "", fmt.Errorf("calculate compact JWS signature: %w", err) + } + + joseHeaders := mergeHeaders(s.ProtectedHeaders, signer.Headers()) + + byteHeaders, err := json.Marshal(joseHeaders) + if err != nil { + return "", fmt.Errorf("marshal Protected JWS Headers: %w", err) + } + + b64Headers := base64.RawURLEncoding.EncodeToString(byteHeaders) + + b64Payload := "" + if !detached { + b64Payload = base64.RawURLEncoding.EncodeToString(s.Payload) + } + + b64Signature := base64.RawURLEncoding.EncodeToString(s.signature) + + return fmt.Sprintf("%s.%s.%s", + b64Headers, + b64Payload, + b64Signature), nil +} + +// Signature returns a copy of JWS signature. +func (s *JSONWebSignature) Signature() []byte { + if s.signature == nil { + return nil + } + + return append(s.signature[:0:0], s.signature...) +} + +func mergeHeaders(h1, h2 Headers) Headers { + h := make(Headers, len(h1)+len(h2)) + + for k, v := range h2 { + h[k] = v + } + + for k, v := range h1 { + h[k] = v + } + + return h +} + +func (s *JSONWebSignature) sign(signer Signer) error { + // build headers + headers := signer.Headers() + + err := checkJWSHeaders(headers) + if err != nil { + return fmt.Errorf("check JOSE headers: %w", err) + } + + sigInput, err := signingInput(headers, s.Payload) + if err != nil { + return fmt.Errorf("prepare JWS verification data: %w", err) + } + + signature, err := signer.Sign(sigInput) + if err != nil { + return fmt.Errorf("sign JWS verification data: %w", err) + } + + s.signature = signature + + return nil +} + +// jwsParseOpts holds options for the JWS Parsing. +type jwsParseOpts struct { + detachedPayload []byte +} + +// JWSParseOpt is the JWS Parser option. +type JWSParseOpt func(opts *jwsParseOpts) + +// WithJWSDetachedPayload option is for definition of JWS detached payload. +func WithJWSDetachedPayload(payload []byte) JWSParseOpt { + return func(opts *jwsParseOpts) { + opts.detachedPayload = payload + } +} + +// ParseJWS parses serialized JWS. Currently only JWS Compact Serialization parsing is supported. +func ParseJWS(jws string, verifier SignatureVerifier, opts ...JWSParseOpt) (*JSONWebSignature, error) { + pOpts := &jwsParseOpts{} + + for _, opt := range opts { + opt(pOpts) + } + + if strings.HasPrefix(jws, "{") { + // TODO support JWS JSON serialization format + // https://github.com/hyperledger/aries-framework-go/issues/1331 + return nil, errors.New("JWS JSON serialization is not supported") + } + + return parseCompacted(jws, verifier, pOpts) +} + +// IsCompactJWS checks weather input is a compact JWS (based on https://tools.ietf.org/html/rfc7516#section-9) +func IsCompactJWS(s string) bool { + parts := strings.Split(s, ".") + + return len(parts) == jwsPartsCount +} + +func parseCompacted(jwsCompact string, verifier SignatureVerifier, opts *jwsParseOpts) (*JSONWebSignature, error) { + parts := strings.Split(jwsCompact, ".") + if len(parts) != jwsPartsCount { + return nil, errors.New("invalid JWS compact format") + } + + joseHeaders, err := parseCompactedHeaders(parts) + if err != nil { + return nil, err + } + + payload, err := parseCompactedPayload(parts[jwsPayloadPart], opts) + if err != nil { + return nil, err + } + + sInput, err := signingInput(joseHeaders, payload) + if err != nil { + return nil, fmt.Errorf("build signing input: %w", err) + } + + signature, err := base64.RawURLEncoding.DecodeString(parts[jwsSignaturePart]) + if err != nil { + return nil, fmt.Errorf("decode base64 signature: %w", err) + } + + err = verifier.Verify(joseHeaders, payload, sInput, signature) + if err != nil { + return nil, err + } + + return &JSONWebSignature{ + ProtectedHeaders: joseHeaders, + Payload: payload, + signature: signature, + }, nil +} + +func parseCompactedPayload(jwsPayload string, opts *jwsParseOpts) ([]byte, error) { + if len(opts.detachedPayload) > 0 { + return opts.detachedPayload, nil + } + + payload, err := base64.RawURLEncoding.DecodeString(jwsPayload) + if err != nil { + return nil, fmt.Errorf("decode base64 payload: %w", err) + } + + return payload, nil +} + +func parseCompactedHeaders(parts []string) (Headers, error) { + headersBytes, err := base64.RawURLEncoding.DecodeString(parts[jwsHeaderPart]) + if err != nil { + return nil, fmt.Errorf("decode base64 header: %w", err) + } + + var joseHeaders Headers + + err = json.Unmarshal(headersBytes, &joseHeaders) + if err != nil { + return nil, fmt.Errorf("unmarshal JSON headers: %w", err) + } + + err = checkJWSHeaders(joseHeaders) + if err != nil { + return nil, err + } + + return joseHeaders, nil +} + +func signingInput(headers Headers, payload []byte) ([]byte, error) { + headersBytes, err := json.Marshal(headers) + if err != nil { + return nil, fmt.Errorf("serialize JWS headers: %w", err) + } + + hBase64 := true + + if b64, ok := headers[HeaderB64Payload]; ok { + if hBase64, ok = b64.(bool); !ok { + return nil, errors.New("invalid b64 header") + } + } + + headersStr := base64.RawURLEncoding.EncodeToString(headersBytes) + + var payloadStr string + + if hBase64 { + payloadStr = base64.RawURLEncoding.EncodeToString(payload) + } else { + payloadStr = string(payload) + } + + return []byte(fmt.Sprintf("%s.%s", headersStr, payloadStr)), nil +} + +func checkJWSHeaders(headers Headers) error { + if _, ok := headers[HeaderAlgorithm]; !ok { + return fmt.Errorf("%s JWS header is not defined", HeaderAlgorithm) + } + + return nil +} + +func convertMapToValue(vOriginToBeMap, vDest interface{}) error { + if _, ok := vOriginToBeMap.(map[string]interface{}); !ok { + return errors.New("expected value to be a map") + } + + mBytes, err := json.Marshal(vOriginToBeMap) + if err != nil { + return err + } + + return json.Unmarshal(mBytes, vDest) +} diff --git a/pkg/doc/jose/jws_test.go b/pkg/doc/jose/jws_test.go new file mode 100644 index 0000000000..d98862e74f --- /dev/null +++ b/pkg/doc/jose/jws_test.go @@ -0,0 +1,300 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package jose + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "strings" + "testing" + + "github.com/square/go-jose/v3/json" + "github.com/stretchr/testify/require" +) + +func TestHeaders_GetJWK(t *testing.T) { + headers := Headers{} + + pubKey, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + jwk := JWK{ + Key: pubKey, + KeyID: "kid", + Algorithm: "EdDSA", + } + + jwkBytes, err := json.Marshal(jwk) + require.NoError(t, err) + + var jwkMap map[string]interface{} + + err = json.Unmarshal(jwkBytes, &jwkMap) + require.NoError(t, err) + + headers["jwk"] = jwkMap + + parsedJWK, ok := headers.GetJWK() + require.True(t, ok) + require.NotNil(t, parsedJWK) + + // jwk is not present + delete(headers, "jwk") + parsedJWK, ok = headers.GetJWK() + require.False(t, ok) + require.Nil(t, parsedJWK) + + // jwk is not a map + headers["jwk"] = "not a map" + parsedJWK, ok = headers.GetJWK() + require.False(t, ok) + require.Nil(t, parsedJWK) +} + +func TestHeaders_GetKeyID(t *testing.T) { + kid, ok := Headers{"kid": "key id"}.GetKeyID() + require.True(t, ok) + require.Equal(t, "key id", kid) + + kid, ok = Headers{"kid": 777}.GetKeyID() + require.False(t, ok) + require.Empty(t, kid) + + kid, ok = Headers{}.GetKeyID() + require.False(t, ok) + require.Empty(t, kid) +} + +func TestHeaders_GetAlgorithm(t *testing.T) { + kid, ok := Headers{"alg": "EdDSA"}.GetAlgorithm() + require.True(t, ok) + require.Equal(t, "EdDSA", kid) + + kid, ok = Headers{"alg": 777}.GetAlgorithm() + require.False(t, ok) + require.Empty(t, kid) + + kid, ok = Headers{}.GetAlgorithm() + require.False(t, ok) + require.Empty(t, kid) +} + +func TestNewCompositeAlgSignatureVerifier(t *testing.T) { + verifier := NewCompositeAlgSignatureVerifier() + + verifier.AddVerifier("EdDSA", SignatureVerifierFunc( + func(joseHeaders Headers, payload, signingInput, signature []byte) error { + return errors.New("signature is invalid") + }, + )) + + err := verifier.Verify(Headers{"alg": "EdDSA"}, nil, nil, nil) + require.Error(t, err) + require.EqualError(t, err, "signature is invalid") + + // alg is not defined + err = verifier.Verify(Headers{}, nil, nil, nil) + require.Error(t, err) + require.EqualError(t, err, "'alg' JOSE header is not present") + + // not supported alg + err = verifier.Verify(Headers{"alg": "RS256"}, nil, nil, nil) + require.Error(t, err) + require.EqualError(t, err, "no verifier found for RS256 algorithm") +} + +func TestJSONWebSignature_SerializeCompact(t *testing.T) { + jws := NewJWS(Headers{"alg": "EdSDA", "typ": "JWT"}, nil, []byte("payload")) + + jwsCompact, err := jws.SerializeCompact(&testSigner{ + headers: Headers{"alg": "dummy"}, + signature: []byte("signature"), + }, false) + require.NoError(t, err) + require.NotEmpty(t, jwsCompact) + + // b64=false + jwsCompact, err = jws.SerializeCompact(&testSigner{ + headers: Headers{"alg": "dummy", "b64": false}, + signature: []byte("signature"), + }, false) + require.NoError(t, err) + require.NotEmpty(t, jwsCompact) + + // signer error + jwsCompact, err = jws.SerializeCompact(&testSigner{ + headers: Headers{"alg": "dummy"}, + err: errors.New("signer error"), + }, false) + require.Error(t, err) + require.Contains(t, err.Error(), "sign JWS verification data") + require.Empty(t, jwsCompact) + + // no alg defined + jwsCompact, err = jws.SerializeCompact(&testSigner{ + headers: Headers{}, + }, false) + require.Error(t, err) + require.Contains(t, err.Error(), "alg JWS header is not defined") + require.Empty(t, jwsCompact) + + // jose headers marshalling error + jwsCompact, err = jws.SerializeCompact(&testSigner{ + headers: getUnmarshallableMap(), + }, false) + require.Error(t, err) + require.Contains(t, err.Error(), "serialize JWS headers") + require.Empty(t, jwsCompact) + + // invalid b64 + jwsCompact, err = jws.SerializeCompact(&testSigner{ + headers: Headers{"alg": "dummy", "b64": "invalid"}, + signature: []byte("signature"), + }, false) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid b64 header") + require.Empty(t, jwsCompact) + + // invalid protected JWS headers + jws.ProtectedHeaders = getUnmarshallableMap() + jwsCompact, err = jws.SerializeCompact(&testSigner{ + headers: Headers{"alg": "dummy"}, + signature: []byte("signature"), + }, false) + require.Error(t, err) + require.Contains(t, err.Error(), "marshal Protected JWS Headers") + require.Empty(t, jwsCompact) +} + +func TestJSONWebSignature_Signature(t *testing.T) { + jws := &JSONWebSignature{ + signature: []byte("signature"), + } + require.NotEmpty(t, jws.Signature()) + + jws.signature = nil + require.Empty(t, jws.Signature()) +} + +func TestParseJWS(t *testing.T) { + corruptedBased64 := "XXXXXaGVsbG8=" + + jws := NewJWS(Headers{"alg": "EdSDA", "typ": "JWT"}, nil, []byte("payload")) + + jwsCompact, err := jws.SerializeCompact(&testSigner{ + headers: Headers{"alg": "dummy"}, + signature: []byte("signature"), + }, false) + require.NoError(t, err) + require.NotEmpty(t, jwsCompact) + + validJWSParts := strings.Split(jwsCompact, ".") + + parsedJWS, err := ParseJWS(jwsCompact, &testVerifier{}) + require.NoError(t, err) + require.NotNil(t, parsedJWS) + require.Equal(t, jws, parsedJWS) + + jwsDetached := fmt.Sprintf("%s.%s.%s", validJWSParts[0], "", validJWSParts[2]) + + detachedPayload, err := base64.RawURLEncoding.DecodeString(validJWSParts[1]) + require.NoError(t, err) + + parsedJWS, err = ParseJWS(jwsDetached, &testVerifier{}, WithJWSDetachedPayload(detachedPayload)) + require.NoError(t, err) + require.NotNil(t, parsedJWS) + require.Equal(t, jws, parsedJWS) + + // Parse not compact JWS format + parsedJWS, err = ParseJWS(`{"some": "JSON"}`, &testVerifier{}) + require.Error(t, err) + require.EqualError(t, err, "JWS JSON serialization is not supported") + require.Nil(t, parsedJWS) + + // Parse invalid compact JWS format + parsedJWS, err = ParseJWS("two_parts.only", &testVerifier{}) + require.Error(t, err) + require.EqualError(t, err, "invalid JWS compact format") + require.Nil(t, parsedJWS) + + // invalid headers + jwsWithInvalidHeaders := fmt.Sprintf("%s.%s.%s", "invalid", validJWSParts[1], validJWSParts[2]) + parsedJWS, err = ParseJWS(jwsWithInvalidHeaders, &testVerifier{}) + require.Error(t, err) + require.Contains(t, err.Error(), "unmarshal JSON headers") + require.Nil(t, parsedJWS) + + jwsWithInvalidHeaders = fmt.Sprintf("%s.%s.%s", corruptedBased64, validJWSParts[1], validJWSParts[2]) + parsedJWS, err = ParseJWS(jwsWithInvalidHeaders, &testVerifier{}) + require.Error(t, err) + require.Contains(t, err.Error(), "decode base64 header") + require.Nil(t, parsedJWS) + + emptyHeaders := base64.RawURLEncoding.EncodeToString([]byte("{}")) + + jwsWithInvalidHeaders = fmt.Sprintf("%s.%s.%s", emptyHeaders, validJWSParts[1], validJWSParts[2]) + parsedJWS, err = ParseJWS(jwsWithInvalidHeaders, &testVerifier{}) + require.Error(t, err) + require.Contains(t, err.Error(), "alg JWS header is not defined") + require.Nil(t, parsedJWS) + + // invalid payload + jwsWithInvalidPayload := fmt.Sprintf("%s.%s.%s", validJWSParts[0], corruptedBased64, validJWSParts[2]) + parsedJWS, err = ParseJWS(jwsWithInvalidPayload, &testVerifier{}) + require.Error(t, err) + require.Contains(t, err.Error(), "decode base64 payload") + require.Nil(t, parsedJWS) + + // invalid signature + jwsWithInvalidSignature := fmt.Sprintf("%s.%s.%s", validJWSParts[0], validJWSParts[1], corruptedBased64) + parsedJWS, err = ParseJWS(jwsWithInvalidSignature, &testVerifier{}) + require.Error(t, err) + require.Contains(t, err.Error(), "decode base64 signature") + require.Nil(t, parsedJWS) + + // verifier error + parsedJWS, err = ParseJWS(jwsCompact, &testVerifier{err: errors.New("bad signature")}) + require.Error(t, err) + require.EqualError(t, err, "bad signature") + require.Nil(t, parsedJWS) +} + +func TestIsCompactJWS(t *testing.T) { + require.True(t, IsCompactJWS("a.b.c")) + require.False(t, IsCompactJWS("a.b")) + require.False(t, IsCompactJWS(`{"some": "JSON"}`)) + require.False(t, IsCompactJWS("")) +} + +type testSigner struct { + headers Headers + signature []byte + err error +} + +func (s testSigner) Sign(_ []byte) ([]byte, error) { + return s.signature, s.err +} + +func (s testSigner) Headers() Headers { + return s.headers +} + +type testVerifier struct { + err error +} + +func (v testVerifier) Verify(_ Headers, _, _, _ []byte) error { + return v.err +} + +func getUnmarshallableMap() map[string]interface{} { + return map[string]interface{}{"alg": "JWS", "error": map[chan int]interface{}{make(chan int): 6}} +} diff --git a/pkg/doc/jwt/jwt.go b/pkg/doc/jwt/jwt.go new file mode 100644 index 0000000000..4d3b0ab9ea --- /dev/null +++ b/pkg/doc/jwt/jwt.go @@ -0,0 +1,271 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package jwt + +import ( + "bytes" + "errors" + "fmt" + "reflect" + + "github.com/square/go-jose/v3/json" + + "github.com/hyperledger/aries-framework-go/pkg/doc/jose" +) + +const ( + // TypeJWT defines JWT type + TypeJWT = "JWT" + + // AlgorithmNone used to indicate unsecured JWT + AlgorithmNone = "none" +) + +// jwtParseOpts holds options for the JWT parsing. +type parseOpts struct { + detachedPayload []byte + sigVerifier jose.SignatureVerifier +} + +// ParseOpt is the JWT Parser option. +type ParseOpt func(opts *parseOpts) + +// WithJWTDetachedPayload option is for definition of JWT detached payload. +func WithJWTDetachedPayload(payload []byte) ParseOpt { + return func(opts *parseOpts) { + opts.detachedPayload = payload + } +} + +// WithSignatureVerifier option is for definition of JWT detached payload. +func WithSignatureVerifier(signatureVerifier jose.SignatureVerifier) ParseOpt { + return func(opts *parseOpts) { + opts.sigVerifier = signatureVerifier + } +} + +type signatureVerifierFunc func(joseHeaders jose.Headers, payload, signingInput, signature []byte) error + +func (v signatureVerifierFunc) Verify(joseHeaders jose.Headers, payload, signingInput, signature []byte) error { + return v(joseHeaders, payload, signingInput, signature) +} + +func verifyUnsecuredJWT(joseHeaders jose.Headers, _, _, signature []byte) error { + alg, ok := joseHeaders.GetAlgorithm() + if !ok { + return errors.New("alg is not defined") + } + + if alg != AlgorithmNone { + return errors.New("alg value is not 'none'") + } + + if len(signature) > 0 { + return errors.New("not empty signature") + } + + return nil +} + +// UnsecuredJWTVerifier provides verifier for unsecured JWT. +func UnsecuredJWTVerifier() jose.SignatureVerifier { + return signatureVerifierFunc(verifyUnsecuredJWT) +} + +type unsecuredJWTSigner struct { + extraHeaders map[string]interface{} +} + +func (s unsecuredJWTSigner) Sign(_ []byte) ([]byte, error) { + return []byte(""), nil +} + +func (s unsecuredJWTSigner) Headers() jose.Headers { + jHeaders := map[string]interface{}{ + jose.HeaderAlgorithm: AlgorithmNone, + jose.HeaderType: TypeJWT, + } + + for k, v := range s.extraHeaders { + if _, ok := jHeaders[k]; !ok { + jHeaders[k] = v + } + } + + return jHeaders +} + +// JSONWebToken defines JSON Web Token (https://tools.ietf.org/html/rfc7519) +type JSONWebToken struct { + Headers jose.Headers + + Payload map[string]interface{} + + signature []byte +} + +// Parse parses input JWT in serialized form into JSON Web Token. +// Currently JWS and unsecured JWT is supported.9 +func Parse(jwt string, opts ...ParseOpt) (*JSONWebToken, error) { + if !jose.IsCompactJWS(jwt) { + return nil, errors.New("JWT of compacted JWS form is supported only") + } + + pOpts := &parseOpts{} + + for _, opt := range opts { + opt(pOpts) + } + + return parseJWS(jwt, pOpts) +} + +// DecodeClaims fills input c with claims of a token. +func (j *JSONWebToken) DecodeClaims(c interface{}) error { + pBytes, err := json.Marshal(j.Payload) + if err != nil { + return err + } + + return json.Unmarshal(pBytes, c) +} + +// LookupStringHeader makes look up of particular header with string value. +func (j *JSONWebToken) LookupStringHeader(name string) string { + if headerValue, ok := j.Headers[name]; ok { + if headerStrValue, ok := headerValue.(string); ok { + return headerStrValue + } + } + + return "" +} + +// SerializeSigned makes (compact) serialization of token. +func (j *JSONWebToken) SerializeSigned(signer jose.Signer, detached bool) (string, error) { + return j.serialize(signer, detached) +} + +// SerializeUnsecured build unsecured JWT. +func (j *JSONWebToken) SerializeUnsecured(extraHeaders map[string]interface{}, detached bool) (string, error) { + return j.serialize(&unsecuredJWTSigner{extraHeaders}, detached) +} + +func (j *JSONWebToken) serialize(signer jose.Signer, detached bool) (string, error) { + payloadBytes, err := j.marshalPayload() + if err != nil { + return "", fmt.Errorf("marshal JWT claims: %w", err) + } + + jws := jose.NewJWS(j.Headers, nil, payloadBytes) + j.signature = jws.Signature() + + return jws.SerializeCompact(signer, detached) +} + +func (j *JSONWebToken) marshalPayload() ([]byte, error) { + return json.Marshal(j.Payload) +} + +func parseJWS(jwt string, opts *parseOpts) (*JSONWebToken, error) { + jwsOpts := make([]jose.JWSParseOpt, 0) + + if opts.detachedPayload != nil { + jwsOpts = append(jwsOpts, jose.WithJWSDetachedPayload(opts.detachedPayload)) + } + + jws, err := jose.ParseJWS(jwt, opts.sigVerifier, jwsOpts...) + if err != nil { + return nil, fmt.Errorf("parse JWT from compact JWS: %w", err) + } + + return mapJWSToJWT(jws) +} + +func mapJWSToJWT(jws *jose.JSONWebSignature) (*JSONWebToken, error) { + headers := jws.ProtectedHeaders + + err := checkHeaders(headers) + if err != nil { + return nil, fmt.Errorf("check JWT headers: %w", err) + } + + claims, err := toMap(jws.Payload) + if err != nil { + return nil, fmt.Errorf("read JWT claims from JWS payload: %w", err) + } + + return &JSONWebToken{ + Headers: headers, + Payload: claims, + signature: jws.Signature(), + }, nil +} + +// New creates new JSON Web Token based on input claims. +func New(claims interface{}) (*JSONWebToken, error) { + m, err := toMap(claims) + if err != nil { + return nil, fmt.Errorf("unmarshallable claims: %w", err) + } + + return &JSONWebToken{ + Payload: m, + }, nil +} + +func checkHeaders(headers map[string]interface{}) error { + if _, ok := headers[jose.HeaderAlgorithm]; !ok { + return errors.New("alg header is not defined") + } + + typ, ok := headers[jose.HeaderType] + if ok && typ != TypeJWT { + return errors.New("typ is not JWT") + } + + cty, ok := headers[jose.HeaderContentType] + if ok && cty == TypeJWT { + return errors.New("nested JWT is not supported") + } + + return nil +} + +func toMap(i interface{}) (map[string]interface{}, error) { + if reflect.ValueOf(i).Kind() == reflect.Map { + return i.(map[string]interface{}), nil + } + + var ( + b []byte + err error + ) + + switch cv := i.(type) { + case []byte: + b = cv + case string: + b = []byte(cv) + default: + b, err = json.Marshal(i) + if err != nil { + return nil, fmt.Errorf("convert to bytes: ") + } + } + + var m map[string]interface{} + + d := json.NewDecoder(bytes.NewReader(b)) + d.UseNumber() + + if err := d.Decode(&m); err != nil { + return nil, fmt.Errorf("convert to map: %v", err) + } + + return m, nil +} diff --git a/pkg/doc/jwt/jwt_test.go b/pkg/doc/jwt/jwt_test.go new file mode 100644 index 0000000000..3f2157194e --- /dev/null +++ b/pkg/doc/jwt/jwt_test.go @@ -0,0 +1,389 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package jwt + +import ( + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "fmt" + "strings" + "testing" + "time" + + "github.com/square/go-jose/v3/json" + "github.com/square/go-jose/v3/jwt" + "github.com/stretchr/testify/require" + + "github.com/hyperledger/aries-framework-go/pkg/doc/jose" +) + +type CustomClaim struct { + *jwt.Claims + + PrivateClaim1 string `json:"privateClaim1,omitempty"` +} + +func TestNew(t *testing.T) { + issued := time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC) + expiry := time.Date(2022, time.January, 1, 0, 0, 0, 0, time.UTC) + notBefore := time.Date(2021, time.January, 1, 0, 0, 0, 0, time.UTC) + + claims := &CustomClaim{ + Claims: &jwt.Claims{ + Issuer: "iss", + Subject: "sub", + Audience: []string{"aud"}, + Expiry: jwt.NewNumericDate(expiry), + NotBefore: jwt.NewNumericDate(notBefore), + IssuedAt: jwt.NewNumericDate(issued), + ID: "id", + }, + + PrivateClaim1: "private claim", + } + + t.Run("Create JWS signed by EdDSA", func(t *testing.T) { + r := require.New(t) + + pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + token, err := New(claims) + r.NoError(err) + jws, err := token.SerializeSigned(newEd25519Signer(privKey), false) + require.NoError(t, err) + + var parsedClaims CustomClaim + err = verifyEd25519ViaGoJose(jws, pubKey, &parsedClaims) + r.NoError(err) + r.Equal(*claims, parsedClaims) + + err = verifyEd25519(jws, pubKey) + r.NoError(err) + }) + + t.Run("Create JWS signed by RS256", func(t *testing.T) { + r := require.New(t) + + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + r.NoError(err) + + pubKey := &privKey.PublicKey + + token, err := New(claims) + r.NoError(err) + jws, err := token.SerializeSigned(newRS256Signer(privKey, nil), false) + require.NoError(t, err) + + var parsedClaims CustomClaim + err = verifyRS256ViaGoJose(jws, pubKey, &parsedClaims) + r.NoError(err) + r.Equal(*claims, parsedClaims) + + err = verifyRS256(jws, pubKey) + r.NoError(err) + }) + + t.Run("Create unsecured JWT", func(t *testing.T) { + r := require.New(t) + + token, err := New(claims) + r.NoError(err) + jwtUnsecured, err := token.SerializeUnsecured(map[string]interface{}{"custom": "ok"}, false) + r.NoError(err) + r.NotEmpty(jwtUnsecured) + + parsedJWT, err := Parse(jwtUnsecured, WithSignatureVerifier(UnsecuredJWTVerifier())) + r.NoError(err) + r.NotNil(parsedJWT) + + var parsedClaims CustomClaim + err = parsedJWT.DecodeClaims(&parsedClaims) + r.NoError(err) + r.Equal(*claims, parsedClaims) + }) + + t.Run("Invalid claims", func(t *testing.T) { + token, err := New("not JSON claims") + require.Error(t, err) + require.Nil(t, token) + }) +} + +func TestWithJWTDetachedPayload(t *testing.T) { + detachedPayloadOpt := WithJWTDetachedPayload([]byte("payload")) + require.NotNil(t, detachedPayloadOpt) + + opts := &parseOpts{} + detachedPayloadOpt(opts) + require.Equal(t, []byte("payload"), opts.detachedPayload) +} + +func TestParse(t *testing.T) { + r := require.New(t) + + pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + signer := newEd25519Signer(privKey) + claims := map[string]interface{}{"iss": "Albert"} + + token, err := New(claims) + r.NoError(err) + jws, err := token.SerializeSigned(signer, false) + r.NoError(err) + + verifier, err := newEd25519Verifier(pubKey) + r.NoError(err) + + jsonWebToken, err := Parse(jws, WithSignatureVerifier(verifier)) + r.NoError(err) + + var parsedClaims map[string]interface{} + err = jsonWebToken.DecodeClaims(&parsedClaims) + r.NoError(err) + + r.Equal(claims, parsedClaims) + + // parse detached JWT + jwsParts := strings.Split(jws, ".") + jwsDetached := fmt.Sprintf("%s..%s", jwsParts[0], jwsParts[2]) + + jwsPayload, err := base64.RawURLEncoding.DecodeString(jwsParts[1]) + require.NoError(t, err) + + jsonWebToken, err = Parse(jwsDetached, + WithSignatureVerifier(verifier), WithJWTDetachedPayload(jwsPayload)) + r.NoError(err) + r.NotNil(r, jsonWebToken) + + // claims is not JSON + jws, err = buildJWS(signer, "not JSON") + r.NoError(err) + token, err = Parse(jws, WithSignatureVerifier(verifier)) + r.Error(err) + r.Contains(err.Error(), "read JWT claims from JWS payload") + r.Nil(token) + + // type is not JWT + signer.headers = map[string]interface{}{"alg": "EdDSA", "typ": "JWM"} + jws, err = buildJWS(signer, map[string]interface{}{"iss": "Albert"}) + r.NoError(err) + token, err = Parse(jws, WithSignatureVerifier(verifier)) + r.Error(err) + r.Contains(err.Error(), "typ is not JWT") + r.Nil(token) + + // content type is not empty (equals to JWT) + signer.headers = map[string]interface{}{"alg": "EdDSA", "typ": "JWT", "cty": "JWT"} + jws, err = buildJWS(signer, map[string]interface{}{"iss": "Albert"}) + r.NoError(err) + token, err = Parse(jws, WithSignatureVerifier(verifier)) + r.Error(err) + r.Contains(err.Error(), "nested JWT is not supported") + r.Nil(token) + + // handle compact JWS of invalid form + token, err = Parse("invalid.compact.JWS") + r.Error(err) + r.Contains(err.Error(), "parse JWT from compact JWS") + r.Nil(token) + + // pass not compact JWS + token, err = Parse("invalid jws") + r.Error(err) + r.EqualError(err, "JWT of compacted JWS form is supported only") + r.Nil(token) +} + +func buildJWS(signer jose.Signer, claims interface{}) (string, error) { + claimsBytes, err := json.Marshal(claims) + if err != nil { + return "", err + } + + jws := jose.NewJWS(nil, nil, claimsBytes) + + return jws.SerializeCompact(signer, false) +} + +func TestJSONWebToken_DecodeClaims(t *testing.T) { + token := getValidJSONWebToken() + + var tokensMap map[string]interface{} + + err := token.DecodeClaims(&tokensMap) + require.NoError(t, err) + require.Equal(t, map[string]interface{}{"iss": "Albert"}, tokensMap) + + var claims jwt.Claims + + err = token.DecodeClaims(&claims) + require.NoError(t, err) + require.Equal(t, jwt.Claims{Issuer: "Albert"}, claims) + + err = getJSONWebTokenWithInvalidPayload().DecodeClaims(&claims) + require.Error(t, err) +} + +func TestJSONWebToken_LookupStringHeader(t *testing.T) { + token := getValidJSONWebToken() + + require.Equal(t, "JWT", token.LookupStringHeader("typ")) + + require.Empty(t, token.LookupStringHeader("undef")) + + token.Headers["not_str"] = 55 + require.Empty(t, token.LookupStringHeader("not_str")) +} + +func TestJSONWebToken_SerializeSigned(t *testing.T) { + token := getValidJSONWebToken() + + _, privKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + signer := newEd25519Signer(privKey) + + jws, err := token.SerializeSigned(signer, false) + require.NoError(t, err) + require.NotEmpty(t, jws) + + // unmarshallable claims case + token = getJSONWebTokenWithInvalidPayload() + jws, err = token.SerializeSigned(signer, false) + require.Error(t, err) + require.Contains(t, err.Error(), "marshal JWT claims") + require.Empty(t, jws) +} + +func TestJSONWebToken_SerializeUnsecured(t *testing.T) { + token := getValidJSONWebToken() + + jws, err := token.SerializeUnsecured(nil, false) + require.NoError(t, err) + require.NotEmpty(t, jws) + + // unmarshallable claims case + token = getJSONWebTokenWithInvalidPayload() + jws, err = token.SerializeUnsecured(nil, false) + require.Error(t, err) + require.Contains(t, err.Error(), "marshal JWT claims") + require.Empty(t, jws) +} + +func TestUnsecuredJWTVerifier(t *testing.T) { + verifier := UnsecuredJWTVerifier() + + err := verifier.Verify(map[string]interface{}{"alg": "none"}, nil, nil, nil) + require.NoError(t, err) + + err = verifier.Verify(map[string]interface{}{}, nil, nil, nil) + require.Error(t, err) + require.EqualError(t, err, "alg is not defined") + + err = verifier.Verify(map[string]interface{}{"alg": "EdDSA"}, nil, nil, nil) + require.Error(t, err) + require.EqualError(t, err, "alg value is not 'none'") + + err = verifier.Verify(map[string]interface{}{"alg": "none"}, nil, nil, []byte("unexpected signature")) + require.Error(t, err) + require.EqualError(t, err, "not empty signature") +} + +type testToMapStruct struct { + TestField string `json:"a"` +} + +func Test_toMap(t *testing.T) { + inputMap := map[string]interface{}{"a": "b"} + + r := require.New(t) + + // pass map + resultMap, err := toMap(inputMap) + r.NoError(err) + r.Equal(inputMap, resultMap) + + // pass []byte + inputMapBytes, err := json.Marshal(inputMap) + r.NoError(err) + resultMap, err = toMap(inputMapBytes) + r.NoError(err) + r.Equal(inputMap, resultMap) + + // pass string + inputMapStr := string(inputMapBytes) + resultMap, err = toMap(inputMapStr) + r.NoError(err) + r.Equal(inputMap, resultMap) + + // pass struct + s := testToMapStruct{TestField: "b"} + resultMap, err = toMap(s) + r.NoError(err) + r.Equal(inputMap, resultMap) + + // pass invalid []byte + resultMap, err = toMap([]byte("not JSON")) + r.Error(err) + r.Contains(err.Error(), "convert to map") + r.Nil(resultMap) + + // pass invalid structure + resultMap, err = toMap(make(chan int)) + r.Error(err) + r.Contains(err.Error(), "convert to bytes") + r.Nil(resultMap) +} + +func getValidJSONWebToken() *JSONWebToken { + return &JSONWebToken{ + Headers: map[string]interface{}{"typ": "JWT", "alg": "EdDSA"}, + Payload: map[string]interface{}{"iss": "Albert"}, + signature: []byte("signature"), + } +} + +func getJSONWebTokenWithInvalidPayload() *JSONWebToken { + return &JSONWebToken{ + Headers: map[string]interface{}{"typ": "JWT", "alg": "EdDSA"}, + Payload: getUnmarshallableMap(), + signature: []byte("signature")} +} + +func verifyEd25519ViaGoJose(jws string, pubKey ed25519.PublicKey, claims interface{}) error { + jwtToken, err := jwt.ParseSigned(jws) + if err != nil { + return fmt.Errorf("parse VC from signed JWS: %w", err) + } + + if err = jwtToken.Claims(pubKey, claims); err != nil { + return fmt.Errorf("verify JWT signature: %w", err) + } + + return nil +} + +func verifyRS256ViaGoJose(jws string, pubKey *rsa.PublicKey, claims interface{}) error { + jwtToken, err := jwt.ParseSigned(jws) + if err != nil { + return fmt.Errorf("parse VC from signed JWS: %w", err) + } + + if err = jwtToken.Claims(pubKey, claims); err != nil { + return fmt.Errorf("verify JWT signature: %w", err) + } + + return nil +} + +func getUnmarshallableMap() map[string]interface{} { + return map[string]interface{}{"error": map[chan int]interface{}{make(chan int): 6}} +} diff --git a/pkg/doc/jwt/support_test.go b/pkg/doc/jwt/support_test.go new file mode 100644 index 0000000000..e4e7087d5f --- /dev/null +++ b/pkg/doc/jwt/support_test.go @@ -0,0 +1,177 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package jwt + +import ( + "crypto" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "errors" + + "github.com/hyperledger/aries-framework-go/pkg/doc/jose" +) + +type ed25519Signer struct { + privKey []byte + headers map[string]interface{} +} + +func (s ed25519Signer) Sign(data []byte) ([]byte, error) { + return ed25519.Sign(s.privKey, data), nil +} + +func (s ed25519Signer) Headers() jose.Headers { + return s.headers +} + +func newEd25519Signer(privKey []byte) *ed25519Signer { + return &ed25519Signer{ + privKey: privKey, + headers: prepareJWSHeaders(nil, signatureEdDSA), + } +} + +type ed25519Verifier struct { + pubKey []byte +} + +func (v ed25519Verifier) Verify(joseHeaders jose.Headers, _, signingInput, signature []byte) error { + alg, ok := joseHeaders.GetAlgorithm() + if !ok { + return errors.New("alg is not defined") + } + + if alg != "EdDSA" { + return errors.New("alg is not EdDSA") + } + + if ok := ed25519.Verify(v.pubKey, signingInput, signature); !ok { + return errors.New("signature doesn't match") + } + + return nil +} + +func newEd25519Verifier(pubKey []byte) (*ed25519Verifier, error) { + if l := len(pubKey); l != ed25519.PublicKeySize { + return nil, errors.New("bad ed25519 public key length") + } + + return &ed25519Verifier{pubKey: pubKey}, nil +} + +type rs256Signer struct { + privKey *rsa.PrivateKey + headers map[string]interface{} +} + +func newRS256Signer(privKey *rsa.PrivateKey, headers map[string]interface{}) *rs256Signer { + return &rs256Signer{ + privKey: privKey, + headers: prepareJWSHeaders(headers, signatureRS256), + } +} + +func (s rs256Signer) Sign(data []byte) ([]byte, error) { + hash := crypto.SHA256.New() + + _, err := hash.Write(data) + if err != nil { + return nil, err + } + + hashed := hash.Sum(nil) + + return rsa.SignPKCS1v15(rand.Reader, s.privKey, crypto.SHA256, hashed) +} + +func (s rs256Signer) Headers() jose.Headers { + return s.headers +} + +type rs256Verifier struct { + pubKey *rsa.PublicKey +} + +func newRS256Verifier(pubKey *rsa.PublicKey) *rs256Verifier { + return &rs256Verifier{pubKey: pubKey} +} + +func (v rs256Verifier) Verify(joseHeaders jose.Headers, _, signingInput, signature []byte) error { + alg, ok := joseHeaders.GetAlgorithm() + if !ok { + return errors.New("alg is not defined") + } + + if alg != "RS256" { + return errors.New("alg is not RS256") + } + + hash := crypto.SHA256.New() + + _, err := hash.Write(signingInput) + if err != nil { + return err + } + + hashed := hash.Sum(nil) + + return rsa.VerifyPKCS1v15(v.pubKey, crypto.SHA256, hashed, signature) +} + +func verifyEd25519(jws string, pubKey ed25519.PublicKey) error { + verifier, err := newEd25519Verifier(pubKey) + if err != nil { + return err + } + + sVerifier := jose.NewCompositeAlgSignatureVerifier() + sVerifier.AddVerifier("EdDSA", verifier) + + token, err := Parse(jws, WithSignatureVerifier(sVerifier)) + if err != nil { + return err + } + + if token == nil { + return errors.New("nil token") + } + + return nil +} + +func verifyRS256(jws string, pubKey *rsa.PublicKey) error { + verifier := newRS256Verifier(pubKey) + + sVerifier := jose.NewCompositeAlgSignatureVerifier() + sVerifier.AddVerifier("RS256", verifier) + + token, err := Parse(jws, WithSignatureVerifier(sVerifier)) + if err != nil { + return err + } + + if token == nil { + return errors.New("nil token") + } + + return nil +} + +func prepareJWSHeaders(headers map[string]interface{}, alg string) map[string]interface{} { + newHeaders := make(map[string]interface{}) + + for k, v := range headers { + newHeaders[k] = v + } + + newHeaders[jose.HeaderType] = TypeJWT + newHeaders[jose.HeaderAlgorithm] = alg + + return newHeaders +} diff --git a/pkg/doc/jwt/verifier.go b/pkg/doc/jwt/verifier.go new file mode 100644 index 0000000000..24e1f5d3ac --- /dev/null +++ b/pkg/doc/jwt/verifier.go @@ -0,0 +1,150 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package jwt + +import ( + "crypto" + "crypto/rsa" + "errors" + "fmt" + + "github.com/square/go-jose/v3/json" + "golang.org/x/crypto/ed25519" + + "github.com/hyperledger/aries-framework-go/pkg/doc/jose" +) + +const ( + // signatureEdDSA defines EdDSA alg + signatureEdDSA = "EdDSA" + + // signatureRS256 defines RS256 alg + signatureRS256 = "RS256" +) + +const issuerClaim = "iss" + +// KeyResolver resolves public key based on what and kid. +type KeyResolver interface { + + // Resolve resolves public key. + Resolve(what, kid string) (interface{}, error) +} + +// BasicVerifier defines basic Signed JWT verifier based on Issuer Claim and Key ID JOSE Header. +type BasicVerifier struct { + resolver KeyResolver + compositeVerifier *jose.CompositeAlgSignatureVerifier +} + +// NewVerifier creates a new basic Verifier. +func NewVerifier(resolver KeyResolver) *BasicVerifier { + // TODO Support pluggable JWS verifiers + // (https://github.com/hyperledger/aries-framework-go/issues/1267) + compositeVerifier := jose.NewCompositeAlgSignatureVerifier() + compositeVerifier.AddVerifier(signatureEdDSA, getVerifier(resolver, VerifyEdDSA)) + compositeVerifier.AddVerifier(signatureRS256, getVerifier(resolver, VerifyRS256)) + // TODO ECDSA to support NIST P256 curve + // https://github.com/hyperledger/aries-framework-go/issues/1266 + + return &BasicVerifier{resolver: resolver, compositeVerifier: compositeVerifier} +} + +type verifier func(pubKey interface{}, message, signature []byte) error + +func getVerifier(resolver KeyResolver, verifier verifier) jose.SignatureVerifier { + return jose.SignatureVerifierFunc(func(joseHeaders jose.Headers, payload, signingInput, signature []byte) error { + return verifySignature(resolver, verifier, joseHeaders, payload, signingInput, signature) + }) +} + +func verifySignature(resolver KeyResolver, verifier verifier, + joseHeaders jose.Headers, payload, signingInput, signature []byte) error { + claims := make(map[string]interface{}) + + err := json.Unmarshal(payload, &claims) + if err != nil { + return fmt.Errorf("read claims from JSON Web Token: %w", err) + } + + issuer, err := getIssuerClaim(claims) + if err != nil { + return fmt.Errorf("read issuer claim: %w", err) + } + + kid, _ := joseHeaders.GetKeyID() + + pubKey, err := resolver.Resolve(issuer, kid) + if err != nil { + return err + } + + return verifier(pubKey, signingInput, signature) +} + +// Verify verifies JSON Web Token. Public key is fetched using Issuer Claim and Key ID JOSE Header. +func (v BasicVerifier) Verify(joseHeaders jose.Headers, payload, signingInput, signature []byte) error { + return v.compositeVerifier.Verify(joseHeaders, payload, signingInput, signature) +} + +// VerifyEdDSA verifies EdDSA signature. +func VerifyEdDSA(pubKey interface{}, message, signature []byte) error { + // TODO Use crypto for signing/verification logic + // https://github.com/hyperledger/aries-framework-go/issues/1278 + pubKeyEdDSA, ok := pubKey.([]byte) + if !ok { + pubKeyEdDSA, ok = pubKey.(ed25519.PublicKey) + if !ok { + return errors.New("not []byte or ed25519.PublicKey public key") + } + } + + if l := len(pubKeyEdDSA); l != ed25519.PublicKeySize { + return errors.New("bad ed25519 public key length") + } + + if ok := ed25519.Verify(pubKeyEdDSA, message, signature); !ok { + return errors.New("signature doesn't match") + } + + return nil +} + +// VerifyRS256 verifies RS256 signature. +func VerifyRS256(pubKey interface{}, message, signature []byte) error { + // TODO Use crypto for signing/verification logic + // https://github.com/hyperledger/aries-framework-go/issues/1278 + pubKeyRsa, ok := pubKey.(*rsa.PublicKey) + if !ok { + return errors.New("not *rsa.PublicKey public key") + } + + hash := crypto.SHA256.New() + + _, err := hash.Write(message) + if err != nil { + return err + } + + hashed := hash.Sum(nil) + + return rsa.VerifyPKCS1v15(pubKeyRsa, crypto.SHA256, hashed, signature) +} + +func getIssuerClaim(claims map[string]interface{}) (string, error) { + v, ok := claims[issuerClaim] + if !ok { + return "", errors.New("issuer claim is not defined") + } + + s, ok := v.(string) + if !ok { + return "", errors.New("issuer claim is not a string") + } + + return s, nil +} diff --git a/pkg/doc/jwt/verifier_test.go b/pkg/doc/jwt/verifier_test.go new file mode 100644 index 0000000000..9823d87a2a --- /dev/null +++ b/pkg/doc/jwt/verifier_test.go @@ -0,0 +1,167 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package jwt + +import ( + "crypto" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "errors" + "testing" + + "github.com/square/go-jose/v3/json" + "github.com/square/go-jose/v3/jwt" + "github.com/stretchr/testify/require" + + "github.com/hyperledger/aries-framework-go/pkg/doc/jose" +) + +type testKeyResolver struct { + pubKey interface{} + err error +} + +func (r testKeyResolver) Resolve(_, _ string) (interface{}, error) { + return r.pubKey, r.err +} + +func TestNewVerifier(t *testing.T) { + r := require.New(t) + + t.Run("Verify JWT signed by EdDSA", func(t *testing.T) { + pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + signer := newEd25519Signer(privKey) + + token, err := New(&jwt.Claims{Issuer: "Mike"}) + r.NoError(err) + jws, err := token.SerializeSigned(signer, false) + r.NoError(err) + + verifier := NewVerifier(&testKeyResolver{pubKey: pubKey}) + _, err = jose.ParseJWS(jws, verifier) + r.NoError(err) + }) + + t.Run("Verify JWT signed by RS256", func(t *testing.T) { + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + r.NoError(err) + + pubKey := &privKey.PublicKey + + signer := newRS256Signer(privKey, nil) + + token, err := New(&jwt.Claims{Issuer: "Mike"}) + r.NoError(err) + jws, err := token.SerializeSigned(signer, false) + r.NoError(err) + + verifier := NewVerifier(&testKeyResolver{pubKey: pubKey}) + _, err = jose.ParseJWS(jws, verifier) + r.NoError(err) + }) +} + +func TestBasicVerifier_Verify(t *testing.T) { // error corner cases + r := require.New(t) + + pubKey, _, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + verifier := NewVerifier(&testKeyResolver{pubKey: pubKey}) + + validHeaders := map[string]interface{}{ + "alg": "EdDSA", + } + + // Invalid claims + err = verifier.Verify(validHeaders, []byte("invalid JSON claims"), nil, nil) + r.Error(err) + r.Contains(err.Error(), "read claims from JSON Web Token") + + // Issuer claim is not defined + claimsWithoutIssuer, err := json.Marshal(map[string]interface{}{}) + r.NoError(err) + err = verifier.Verify(validHeaders, claimsWithoutIssuer, nil, nil) + r.Error(err) + r.Contains(err.Error(), "issuer claim is not defined") + + // Issuer claim is not a string + claimsWithInvalidIssuer, err := json.Marshal(map[string]interface{}{"iss": 444}) + r.NoError(err) + err = verifier.Verify(validHeaders, claimsWithInvalidIssuer, nil, nil) + r.Error(err) + r.Contains(err.Error(), "issuer claim is not a string") + + validClaims, err := json.Marshal(map[string]interface{}{"iss": "Bob"}) + r.NoError(err) + + // key resolver error + verifier = NewVerifier(&testKeyResolver{err: errors.New("failed to resolve public key")}) + err = verifier.Verify(validHeaders, validClaims, nil, nil) + r.Error(err) + r.Contains(err.Error(), "failed to resolve public key") +} + +func TestVerifyEdDSA(t *testing.T) { + r := require.New(t) + + pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + signature := ed25519.Sign(privKey, []byte("test message")) + + err = VerifyEdDSA(pubKey, []byte("test message"), signature) + r.NoError(err) + + err = VerifyEdDSA([]byte("invalid pub key"), []byte("test message"), signature) + r.Error(err) + r.EqualError(err, "bad ed25519 public key length") + + anotherPubKey, _, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + err = VerifyEdDSA(anotherPubKey, []byte("test message"), signature) + r.Error(err) + r.EqualError(err, "signature doesn't match") + + err = VerifyEdDSA("not EdDSA public key", []byte("test message"), signature) + r.Error(err) + r.EqualError(err, "not []byte or ed25519.PublicKey public key") +} + +func TestVerifyRS256(t *testing.T) { + r := require.New(t) + + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + r.NoError(err) + + hash := crypto.SHA256.New() + + _, err = hash.Write([]byte("test message")) + r.NoError(err) + + hashed := hash.Sum(nil) + + signature, err := rsa.SignPKCS1v15(rand.Reader, privKey, crypto.SHA256, hashed) + r.NoError(err) + + err = VerifyRS256(&privKey.PublicKey, []byte("test message"), signature) + r.NoError(err) + + anotherPrivKey, err := rsa.GenerateKey(rand.Reader, 2048) + r.NoError(err) + + err = VerifyRS256(&anotherPrivKey.PublicKey, []byte("test message"), signature) + r.Error(err) + + err = VerifyRS256("not RS256 public key", []byte("test message"), signature) + r.Error(err) + r.EqualError(err, "not *rsa.PublicKey public key") +}