Skip to content

Commit

Permalink
Merge pull request #188 from blag/support-loading-jwks
Browse files Browse the repository at this point in the history
Support loading JWKs directly
  • Loading branch information
blag committed Sep 8, 2020
2 parents eb7c5fe + 60fa95d commit 82bd8aa
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
12 changes: 9 additions & 3 deletions jose/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import Mapping, Iterable # Python 2, will be deprecated in Python 3.8

from jose import jwk
from jose.backends.base import Key
from jose.constants import ALGORITHMS
from jose.exceptions import JWSError
from jose.exceptions import JWSSignatureError
Expand Down Expand Up @@ -163,10 +164,11 @@ def _encode_payload(payload):
return base64url_encode(payload)


def _sign_header_and_claims(encoded_header, encoded_claims, algorithm, key_data):
def _sign_header_and_claims(encoded_header, encoded_claims, algorithm, key):
signing_input = b'.'.join([encoded_header, encoded_claims])
try:
key = jwk.construct(key_data, algorithm)
if not isinstance(key, Key):
key = jwk.construct(key, algorithm)
signature = key.sign(signing_input)
except Exception as e:
raise JWSError(e)
Expand Down Expand Up @@ -213,7 +215,8 @@ def _load(jwt):

def _sig_matches_keys(keys, signing_input, signature, alg):
for key in keys:
key = jwk.construct(key, alg)
if not isinstance(key, Key):
key = jwk.construct(key, alg)
try:
if key.verify(signing_input, signature):
return True
Expand All @@ -224,6 +227,9 @@ def _sig_matches_keys(keys, signing_input, signature, alg):

def _get_keys(key):

if isinstance(key, Key):
return (key,)

try:
key = json.loads(key, parse_int=str, parse_float=str)
except Exception:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ def test_round_trip_with_different_key_types(self, key):
assert verified_data['testkey'] == 'testvalue'


class TestJWK(object):
def test_jwk(self, payload):
key_data = 'key'
key = jwk.construct(key_data, algorithm='HS256')
token = jws.sign(payload, key, algorithm=ALGORITHMS.HS256)
assert jws.verify(token, key_data, ALGORITHMS.HS256) == payload


class TestHMAC(object):

def testHMAC256(self, payload):
Expand Down Expand Up @@ -272,6 +280,10 @@ def test_tuple(self):
def test_list(self):
assert ['test', 'key'] == jws._get_keys(['test', 'key'])

def test_jwk(self):
jwkey = jwk.construct('key', algorithm='HS256')
assert (jwkey,) == jws._get_keys(jwkey)


class TestRSA(object):

Expand Down

0 comments on commit 82bd8aa

Please sign in to comment.