Skip to content

Commit

Permalink
Add JWK support for HMAC and RSA keys
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-adams committed May 7, 2016
1 parent dac1dba commit 41c24b2
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 22 deletions.
134 changes: 132 additions & 2 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import hashlib
import hmac
import json

from .compat import constant_time_compare, string_types, text_type
from .exceptions import InvalidKeyError
from .utils import der_to_raw_signature, raw_to_der_signature
from .utils import (
base64url_decode, base64url_encode, der_to_raw_signature,
from_base64url_uint, raw_to_der_signature, to_base64url_uint
)

try:
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization import (
load_pem_private_key, load_pem_public_key, load_ssh_public_key
)
from cryptography.hazmat.primitives.asymmetric.rsa import (
RSAPrivateKey, RSAPublicKey
RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers,
rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp
)
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurvePrivateKey, EllipticCurvePublicKey
Expand Down Expand Up @@ -77,6 +82,20 @@ def verify(self, msg, key, sig):
"""
raise NotImplementedError

@staticmethod
def to_jwk(key_obj):
"""
Serializes a given RSA key into a JWK
"""
raise NotImplementedError

@staticmethod
def from_jwk(jwk):
"""
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
"""
raise NotImplementedError


class NoneAlgorithm(Algorithm):
"""
Expand Down Expand Up @@ -131,6 +150,22 @@ def prepare_key(self, key):

return key

@staticmethod
def to_jwk(key_obj):
return json.dumps({
'k': base64url_encode(key_obj),
'typ': 'oct'
})

@staticmethod
def from_jwk(jwk):
obj = json.loads(jwk)

if obj.get('kty') != 'oct':
raise InvalidKeyError('Not an HMAC key')

return base64url_decode(obj['k'])

def sign(self, msg, key):
return hmac.new(key, msg, self.hash_alg).digest()

Expand Down Expand Up @@ -172,6 +207,101 @@ def prepare_key(self, key):

return key

@staticmethod
def to_jwk(key_obj):
obj = None

if getattr(key_obj, 'private_numbers', None):
# Private key
numbers = key_obj.private_numbers()

obj = {
'kty': 'RSA',
'key_ops': ['sign'],
'd': to_base64url_uint(numbers.d),
'p': to_base64url_uint(numbers.p),
'q': to_base64url_uint(numbers.q),
'dp': to_base64url_uint(numbers.dmp1),
'dq': to_base64url_uint(numbers.dmq1),
'qi': to_base64url_uint(numbers.iqmp)
}

elif getattr(key_obj, 'verifier', None):
# Public key
numbers = key_obj.public_numbers()

obj = {
'kty': 'RSA',
'use': 'sig',
'key_ops': ['verify'],
'n': to_base64url_uint(numbers.n),
'e': to_base64url_uint(numbers.e)
}
else:
raise InvalidKeyError('Not a public or private key')

return json.dumps(obj)

@staticmethod
def from_jwk(jwk):
obj = json.loads(jwk)

if obj.get('kty') != 'RSA':
raise InvalidKeyError('Not an RSA key')

if 'd' in obj and 'e' in obj and 'n' in obj:
# Private key
if 'oth' in obj:
raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')

other_props = ['p', 'q', 'dp', 'dq', 'qi']
props_found = [True for prop in other_props if prop in obj]
any_props_found = any(props_found)

if any_props_found and not all(props_found):
raise InvalidKeyError('RSA key must include all parameters if any are present besides d')

public_numbers = RSAPublicNumbers(
from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
)

if any_props_found:
numbers = RSAPrivateNumbers(
d=from_base64url_uint(obj['d']),
p=from_base64url_uint(obj['p']),
q=from_base64url_uint(obj['q']),
dmp1=from_base64url_uint(obj['dp']),
dmq1=from_base64url_uint(obj['dq']),
iqmp=from_base64url_uint(obj['qi']),
public_numbers=public_numbers
)
else:
p, q = rsa_recover_prime_factors(
public_numbers.n, public_numbers.d, public_numbers.e
)
d = from_base64url_uint(obj['d'])

numbers = RSAPrivateNumbers(
d=d,
p=p,
q=q,
dmp1=rsa_crt_dmp1(d, p),
dmq1=rsa_crt_dmq1(d, q),
iqmp=rsa_crt_iqmp(p, q),
public_numbers=public_numbers
)

return numbers.private_key(default_backend())
elif 'n' in obj and 'e' in obj:
# Public key
numbers = RSAPublicNumbers(
from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
)

return numbers.public_key(default_backend())
else:
raise InvalidKeyError('Not a public or private key')

def sign(self, msg, key):
signer = key.signer(
padding.PKCS1v15(),
Expand Down
23 changes: 22 additions & 1 deletion jwt/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# flake8: noqa
import sys
import hmac

import struct

PY3 = sys.version_info[0] == 3

Expand Down Expand Up @@ -52,3 +52,24 @@ def constant_time_compare(val1, val2):
result |= ord(x) ^ ord(y)

return result == 0

# Use int.to_bytes if it exists (Python 3)
if getattr(int, 'to_bytes', None):
def bytes_from_int(val):
remaining = val
byte_length = 0

while remaining != 0:
remaining = remaining >> 8
byte_length += 1

return val.to_bytes(byte_length, 'big', signed=False)
else:
def bytes_from_int(val):
buf = []
while val:
val, remainder = divmod(val, 256)
buf.append(remainder)

buf.reverse()
return struct.pack('%sB' % len(buf), *buf)
28 changes: 28 additions & 0 deletions jwt/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import base64
import binascii
import struct

from .compat import bytes_from_int, text_type

try:
from cryptography.hazmat.primitives.asymmetric.utils import (
Expand All @@ -10,6 +13,9 @@


def base64url_decode(input):
if isinstance(input, text_type):
input = input.encode('ascii')

rem = len(input) % 4

if rem > 0:
Expand All @@ -22,6 +28,28 @@ def base64url_encode(input):
return base64.urlsafe_b64encode(input).replace(b'=', b'')


def to_base64url_uint(val):
if val < 0:
raise ValueError('Must be a positive integer')

int_bytes = bytes_from_int(val)

if len(int_bytes) == 0:
int_bytes = b'\x00'

return base64url_encode(int_bytes)


def from_base64url_uint(val):
if isinstance(val, text_type):
val = val.encode('ascii')

data = base64url_decode(val)

buf = struct.unpack('%sB' % len(data), data)
return int(''.join(["%02x" % byte for byte in buf]), 16)


def merge_dict(original, updates):
if not updates:
return original
Expand Down
22 changes: 3 additions & 19 deletions tests/keys/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,37 +20,21 @@ def load_hmac_key():
return base64url_decode(ensure_bytes(keyobj['k']))

try:
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.backends import default_backend

from jwt.algorithms import RSAAlgorithm
has_crypto = True
except ImportError:
has_crypto = False

if has_crypto:
def load_rsa_key():
with open(os.path.join(BASE_PATH, 'jwk_rsa_key.json'), 'r') as infile:
keyobj = json.load(infile)

return rsa.RSAPrivateNumbers(
p=decode_value(keyobj['p']),
q=decode_value(keyobj['q']),
d=decode_value(keyobj['d']),
dmp1=decode_value(keyobj['dp']),
dmq1=decode_value(keyobj['dq']),
iqmp=decode_value(keyobj['qi']),
public_numbers=load_rsa_pub_key().public_numbers()
).private_key(default_backend())
return RSAAlgorithm.from_jwk(infile.read())

def load_rsa_pub_key():
with open(os.path.join(BASE_PATH, 'jwk_rsa_pub.json'), 'r') as infile:
keyobj = json.load(infile)

return rsa.RSAPublicNumbers(
n=decode_value(keyobj['n']),
e=decode_value(keyobj['e'])
).public_key(default_backend())
return RSAAlgorithm.from_jwk(infile.read())

def load_ec_key():
with open(os.path.join(BASE_PATH, 'jwk_ec_key.json'), 'r') as infile:
Expand Down
34 changes: 34 additions & 0 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ def test_algorithm_should_throw_exception_if_verify_not_impl(self):
with pytest.raises(NotImplementedError):
algo.verify('message', 'key', 'signature')

def test_algorithm_should_throw_exception_if_to_jwk_not_impl(self):
algo = Algorithm()

with pytest.raises(NotImplementedError):
algo.from_jwk('value')

def test_algorithm_should_throw_exception_if_from_jwk_not_impl(self):
algo = Algorithm()

with pytest.raises(NotImplementedError):
algo.to_jwk('value')

def test_none_algorithm_should_throw_exception_if_key_is_not_none(self):
algo = NoneAlgorithm()

Expand Down Expand Up @@ -84,6 +96,15 @@ def test_hmac_should_throw_exception_if_key_is_x509_cert(self):
with open(key_path('testkey2_rsa.pub.pem'), 'r') as keyfile:
algo.prepare_key(keyfile.read())

def test_hmac_jwk_public_and_private_keys_should_parse_and_verify(self):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)

with open(key_path('jwk_hmac.json'), 'r') as keyfile:
key = algo.from_jwk(keyfile.read())

signature = algo.sign(b'Hello World!', key)
assert algo.verify(b'Hello World!', key, signature)

@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
def test_rsa_should_parse_pem_public_key(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
Expand Down Expand Up @@ -127,6 +148,19 @@ def test_rsa_verify_should_return_false_if_signature_invalid(self):
result = algo.verify(message, pub_key, sig)
assert not result

@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
def test_rsa_jwk_public_and_private_keys_should_parse_and_verify(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)

with open(key_path('jwk_rsa_pub.json'), 'r') as keyfile:
pub_key = algo.from_jwk(keyfile.read())

with open(key_path('jwk_rsa_key.json'), 'r') as keyfile:
priv_key = algo.from_jwk(keyfile.read())

signature = algo.sign(ensure_bytes('Hello World!'), priv_key)
assert algo.verify(ensure_bytes('Hello World!'), pub_key, signature)

@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
def test_ec_should_reject_non_string_key(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)
Expand Down
28 changes: 28 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from jwt.utils import from_base64url_uint, to_base64url_uint

import pytest


@pytest.mark.parametrize("inputval,expected", [
(0, b'AA'),
(1, b'AQ'),
(255, b'_w'),
(65537, b'AQAB'),
(123456789, b'B1vNFQ'),
pytest.mark.xfail((-1, ''), raises=ValueError)
])
def test_to_base64url_uint(inputval, expected):
actual = to_base64url_uint(inputval)
assert actual == expected


@pytest.mark.parametrize("inputval,expected", [
(b'AA', 0),
(b'AQ', 1),
(b'_w', 255),
(b'AQAB', 65537),
(b'B1vNFQ', 123456789, ),
])
def test_from_base64url_uint(inputval, expected):
actual = from_base64url_uint(inputval)
assert actual == expected

0 comments on commit 41c24b2

Please sign in to comment.