Skip to content

Commit

Permalink
Add JWK support for HMAC and RSA keys
Browse files Browse the repository at this point in the history
- JWKs for RSA and HMAC can be encoded / decoded using the .to_jwk() and
  .from_jwk() methods on their respective jwt.algorithms instances

- Replaced tests.utils ensure_unicode and ensure_bytes with jwt.utils versions
  • Loading branch information
mark-adams committed Aug 28, 2016
1 parent b35d522 commit 42b0114
Show file tree
Hide file tree
Showing 11 changed files with 554 additions and 130 deletions.
158 changes: 144 additions & 14 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import hashlib
import hmac
import json

from .compat import binary_type, constant_time_compare, is_string_type

from .compat import constant_time_compare, string_types
from .exceptions import InvalidKeyError
from .utils import der_to_raw_signature, raw_to_der_signature
from .utils import (
base64url_decode, base64url_encode, der_to_raw_signature,
force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature,
to_base64url_uint
)

try:
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization import (
load_pem_private_key, load_pem_public_key, load_ssh_public_key
)
from cryptography.hazmat.primitives.asymmetric.rsa import (
RSAPrivateKey, RSAPublicKey
RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers,
rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp
)
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurvePrivateKey, EllipticCurvePublicKey
Expand Down Expand Up @@ -77,6 +84,20 @@ def verify(self, msg, key, sig):
"""
raise NotImplementedError

@staticmethod
def to_jwk(key_obj):
"""
Serializes a given RSA key into a JWK
"""
raise NotImplementedError

@staticmethod
def from_jwk(jwk):
"""
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
"""
raise NotImplementedError


class NoneAlgorithm(Algorithm):
"""
Expand Down Expand Up @@ -112,11 +133,7 @@ def __init__(self, hash_alg):
self.hash_alg = hash_alg

def prepare_key(self, key):
if not is_string_type(key):
raise TypeError('Expecting a string- or bytes-formatted key.')

if not isinstance(key, binary_type):
key = key.encode('utf-8')
key = force_bytes(key)

invalid_strings = [
b'-----BEGIN PUBLIC KEY-----',
Expand All @@ -131,6 +148,22 @@ def prepare_key(self, key):

return key

@staticmethod
def to_jwk(key_obj):
return json.dumps({
'k': force_unicode(base64url_encode(force_bytes(key_obj))),
'kty': 'oct'
})

@staticmethod
def from_jwk(jwk):
obj = json.loads(jwk)

if obj.get('kty') != 'oct':
raise InvalidKeyError('Not an HMAC key')

return base64url_decode(obj['k'])

def sign(self, msg, key):
return hmac.new(key, msg, self.hash_alg).digest()

Expand All @@ -156,9 +189,8 @@ def prepare_key(self, key):
isinstance(key, RSAPublicKey):
return key

if is_string_type(key):
if not isinstance(key, binary_type):
key = key.encode('utf-8')
if isinstance(key, string_types):
key = force_bytes(key)

try:
if key.startswith(b'ssh-rsa'):
Expand All @@ -172,6 +204,105 @@ def prepare_key(self, key):

return key

@staticmethod
def to_jwk(key_obj):
obj = None

if getattr(key_obj, 'private_numbers', None):
# Private key
numbers = key_obj.private_numbers()

obj = {
'kty': 'RSA',
'key_ops': ['sign'],
'n': force_unicode(to_base64url_uint(numbers.public_numbers.n)),
'e': force_unicode(to_base64url_uint(numbers.public_numbers.e)),
'd': force_unicode(to_base64url_uint(numbers.d)),
'p': force_unicode(to_base64url_uint(numbers.p)),
'q': force_unicode(to_base64url_uint(numbers.q)),
'dp': force_unicode(to_base64url_uint(numbers.dmp1)),
'dq': force_unicode(to_base64url_uint(numbers.dmq1)),
'qi': force_unicode(to_base64url_uint(numbers.iqmp))
}

elif getattr(key_obj, 'verifier', None):
# Public key
numbers = key_obj.public_numbers()

obj = {
'kty': 'RSA',
'key_ops': ['verify'],
'n': force_unicode(to_base64url_uint(numbers.n)),
'e': force_unicode(to_base64url_uint(numbers.e))
}
else:
raise InvalidKeyError('Not a public or private key')

return json.dumps(obj)

@staticmethod
def from_jwk(jwk):
try:
obj = json.loads(jwk)
except ValueError:
raise InvalidKeyError('Key is not valid JSON')

if obj.get('kty') != 'RSA':
raise InvalidKeyError('Not an RSA key')

if 'd' in obj and 'e' in obj and 'n' in obj:
# Private key
if 'oth' in obj:
raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')

other_props = ['p', 'q', 'dp', 'dq', 'qi']
props_found = [prop in obj for prop in other_props]
any_props_found = any(props_found)

if any_props_found and not all(props_found):
raise InvalidKeyError('RSA key must include all parameters if any are present besides d')

public_numbers = RSAPublicNumbers(
from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
)

if any_props_found:
numbers = RSAPrivateNumbers(
d=from_base64url_uint(obj['d']),
p=from_base64url_uint(obj['p']),
q=from_base64url_uint(obj['q']),
dmp1=from_base64url_uint(obj['dp']),
dmq1=from_base64url_uint(obj['dq']),
iqmp=from_base64url_uint(obj['qi']),
public_numbers=public_numbers
)
else:
d = from_base64url_uint(obj['d'])
p, q = rsa_recover_prime_factors(
public_numbers.n, d, public_numbers.e
)

numbers = RSAPrivateNumbers(
d=d,
p=p,
q=q,
dmp1=rsa_crt_dmp1(d, p),
dmq1=rsa_crt_dmq1(d, q),
iqmp=rsa_crt_iqmp(p, q),
public_numbers=public_numbers
)

return numbers.private_key(default_backend())
elif 'n' in obj and 'e' in obj:
# Public key
numbers = RSAPublicNumbers(
from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
)

return numbers.public_key(default_backend())
else:
raise InvalidKeyError('Not a public or private key')

def sign(self, msg, key):
signer = key.signer(
padding.PKCS1v15(),
Expand Down Expand Up @@ -213,9 +344,8 @@ def prepare_key(self, key):
isinstance(key, EllipticCurvePublicKey):
return key

if is_string_type(key):
if not isinstance(key, binary_type):
key = key.encode('utf-8')
if isinstance(key, string_types):
key = force_bytes(key)

# Attempt to load key. We don't know if it's
# a Signing Key or a Verifying Key, so we try
Expand Down
14 changes: 8 additions & 6 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .algorithms import Algorithm, get_default_algorithms # NOQA
from .compat import binary_type, string_types, text_type
from .exceptions import DecodeError, InvalidAlgorithmError, InvalidTokenError
from .utils import base64url_decode, base64url_encode, merge_dict
from .utils import base64url_decode, base64url_encode, force_bytes, merge_dict


class PyJWS(object):
Expand Down Expand Up @@ -82,11 +82,13 @@ def encode(self, payload, key, algorithm='HS256', headers=None,
self._validate_headers(headers)
header.update(headers)

json_header = json.dumps(
header,
separators=(',', ':'),
cls=json_encoder
).encode('utf-8')
json_header = force_bytes(
json.dumps(
header,
separators=(',', ':'),
cls=json_encoder
)
)

segments.append(base64url_encode(json_header))
segments.append(base64url_encode(payload))
Expand Down
28 changes: 23 additions & 5 deletions jwt/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
versions of python, and compatibility wrappers around optional packages.
"""
# flake8: noqa
import sys
import hmac
import struct
import sys


PY3 = sys.version_info[0] == 3
Expand All @@ -20,10 +21,6 @@
string_types = (text_type, binary_type)


def is_string_type(val):
return any([isinstance(val, typ) for typ in string_types])


def timedelta_total_seconds(delta):
try:
delta.total_seconds
Expand Down Expand Up @@ -56,3 +53,24 @@ def constant_time_compare(val1, val2):
result |= ord(x) ^ ord(y)

return result == 0

# Use int.to_bytes if it exists (Python 3)
if getattr(int, 'to_bytes', None):
def bytes_from_int(val):
remaining = val
byte_length = 0

while remaining != 0:
remaining = remaining >> 8
byte_length += 1

return val.to_bytes(byte_length, 'big', signed=False)
else:
def bytes_from_int(val):
buf = []
while val:
val, remainder = divmod(val, 256)
buf.append(remainder)

buf.reverse()
return struct.pack('%sB' % len(buf), *buf)
46 changes: 46 additions & 0 deletions jwt/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import base64
import binascii
import struct

from .compat import binary_type, bytes_from_int, text_type

try:
from cryptography.hazmat.primitives.asymmetric.utils import (
Expand All @@ -9,7 +12,28 @@
pass


def force_unicode(value):
if isinstance(value, binary_type):
return value.decode('utf-8')
elif isinstance(value, text_type):
return value
else:
raise TypeError('Expected a string value')


def force_bytes(value):
if isinstance(value, text_type):
return value.encode('utf-8')
elif isinstance(value, binary_type):
return value
else:
raise TypeError('Expected a string value')


def base64url_decode(input):
if isinstance(input, text_type):
input = input.encode('ascii')

rem = len(input) % 4

if rem > 0:
Expand All @@ -22,6 +46,28 @@ def base64url_encode(input):
return base64.urlsafe_b64encode(input).replace(b'=', b'')


def to_base64url_uint(val):
if val < 0:
raise ValueError('Must be a positive integer')

int_bytes = bytes_from_int(val)

if len(int_bytes) == 0:
int_bytes = b'\x00'

return base64url_encode(int_bytes)


def from_base64url_uint(val):
if isinstance(val, text_type):
val = val.encode('ascii')

data = base64url_decode(val)

buf = struct.unpack('%sB' % len(data), data)
return int(''.join(["%02x" % byte for byte in buf]), 16)


def merge_dict(original, updates):
if not updates:
return original
Expand Down
Loading

0 comments on commit 42b0114

Please sign in to comment.