Skip to content

Commit

Permalink
Refactor error handling in Algorithm.prepare_key() methods
Browse files Browse the repository at this point in the history
Our error handling in Algorithm.prepare_key() was previously weird and
kind of inconsistent. This change makes a number of improvements:

* Refactors RSA and ECDSA prepare_key() methods to reduce nesting and
  make the code simpler to understand
* All calls to Algorithm.prepare_key() return InvalidKeyError (or a
  subclass) or a valid key instance.
* Created a new InvalidAsymmetricKeyError class that is used to provide
  a standard message when an invalid RSA or ECDSA key is used.
  • Loading branch information
mark-adams committed Mar 14, 2017
1 parent 1710c15 commit d04339d
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 40 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
-------------------------------------------------------------------------
### Changed
- Add support for ECDSA public keys in RFC 4253 (OpenSSH) format [#244][244]
- All Algorithm.prepare_key() calls now return either a valid key value or raise InvalidKeyError
- Renamed commandline script `jwt` to `jwt-cli` to avoid issues with the script clobbering the `jwt` module in some circumstances.
- Better error messages when using an algorithm that requires the cryptography package, but it isn't available [#230][230]

Expand Down
78 changes: 47 additions & 31 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@


from .compat import constant_time_compare, string_types
from .exceptions import InvalidKeyError
from .exceptions import (InvalidAsymmetricKeyError, InvalidJwkError,
InvalidKeyError)
from .utils import (
base64url_decode, base64url_encode, der_to_raw_signature,
force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature,
Expand Down Expand Up @@ -137,6 +138,9 @@ def __init__(self, hash_alg):
self.hash_alg = hash_alg

def prepare_key(self, key):
if not isinstance(key, string_types):
raise InvalidKeyError("HMAC secret key must be a string type.")

key = force_bytes(key)

invalid_strings = [
Expand Down Expand Up @@ -164,7 +168,7 @@ def from_jwk(jwk):
obj = json.loads(jwk)

if obj.get('kty') != 'oct':
raise InvalidKeyError('Not an HMAC key')
raise InvalidKeyError('Invalid key: Not an HMAC key')

return base64url_decode(obj['k'])

Expand Down Expand Up @@ -194,20 +198,28 @@ def prepare_key(self, key):
isinstance(key, RSAPublicKey):
return key

if isinstance(key, string_types):
key = force_bytes(key)
if not isinstance(key, string_types):
raise InvalidAsymmetricKeyError

key = force_bytes(key)

if key.startswith(b'ssh-rsa'):
try:
if key.startswith(b'ssh-rsa'):
key = load_ssh_public_key(key, backend=default_backend())
else:
key = load_pem_private_key(key, password=None, backend=default_backend())
return load_ssh_public_key(key, backend=default_backend())
except ValueError:
key = load_pem_public_key(key, backend=default_backend())
else:
raise TypeError('Expecting a PEM-formatted key.')
raise InvalidAsymmetricKeyError

try:
return load_pem_private_key(key, password=None, backend=default_backend())
except ValueError:
pass

try:
return load_pem_public_key(key, backend=default_backend())
except ValueError:
pass

return key
raise InvalidAsymmetricKeyError

@staticmethod
def to_jwk(key_obj):
Expand Down Expand Up @@ -241,7 +253,7 @@ def to_jwk(key_obj):
'e': force_unicode(to_base64url_uint(numbers.e))
}
else:
raise InvalidKeyError('Not a public or private key')
raise InvalidKeyError('Invalid key: Expecting a RSAPublicKey or RSAPrivateKey instance.')

return json.dumps(obj)

Expand All @@ -250,22 +262,22 @@ def from_jwk(jwk):
try:
obj = json.loads(jwk)
except ValueError:
raise InvalidKeyError('Key is not valid JSON')
raise InvalidJwkError('Key is not valid JSON')

if obj.get('kty') != 'RSA':
raise InvalidKeyError('Not an RSA key')
raise InvalidJwkError('Not an RSA key')

if 'd' in obj and 'e' in obj and 'n' in obj:
# Private key
if 'oth' in obj:
raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')
raise InvalidJwkError('Unsupported RSA private key: > 2 primes not supported')

other_props = ['p', 'q', 'dp', 'dq', 'qi']
props_found = [prop in obj for prop in other_props]
any_props_found = any(props_found)

if any_props_found and not all(props_found):
raise InvalidKeyError('RSA key must include all parameters if any are present besides d')
raise InvalidJwkError('RSA key must include all parameters if any are present besides d')

public_numbers = RSAPublicNumbers(
from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
Expand Down Expand Up @@ -306,7 +318,7 @@ def from_jwk(jwk):

return numbers.public_key(default_backend())
else:
raise InvalidKeyError('Not a public or private key')
raise InvalidKeyError('Not a valid JWK public or private key')

def sign(self, msg, key):
signer = key.signer(
Expand Down Expand Up @@ -349,24 +361,28 @@ def prepare_key(self, key):
isinstance(key, EllipticCurvePublicKey):
return key

if isinstance(key, string_types):
key = force_bytes(key)
if not isinstance(key, string_types):
raise InvalidAsymmetricKeyError

key = force_bytes(key)

# Attempt to load key. We don't know if it's
# a Signing Key or a Verifying Key, so we try
# the Verifying Key first.
if key.startswith(b'ecdsa-sha2-'):
try:
if key.startswith(b'ecdsa-sha2-'):
key = load_ssh_public_key(key, backend=default_backend())
else:
key = load_pem_public_key(key, backend=default_backend())
return load_ssh_public_key(key, backend=default_backend())
except ValueError:
key = load_pem_private_key(key, password=None, backend=default_backend())
raise InvalidAsymmetricKeyError

else:
raise TypeError('Expecting a PEM-formatted key.')
try:
return load_pem_public_key(key, backend=default_backend())
except ValueError:
pass

try:
return load_pem_private_key(key, password=None, backend=default_backend())
except ValueError:
pass

return key
raise InvalidAsymmetricKeyError

def sign(self, msg, key):
signer = key.signer(ec.ECDSA(self.hash_alg()))
Expand Down
3 changes: 2 additions & 1 deletion jwt/contrib/algorithms/py_ecdsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from jwt.algorithms import Algorithm
from jwt.compat import string_types, text_type
from jwt.exceptions import InvalidAsymmetricKeyError


class ECAlgorithm(Algorithm):
Expand Down Expand Up @@ -44,7 +45,7 @@ def prepare_key(self, key):
key = ecdsa.SigningKey.from_pem(key)

else:
raise TypeError('Expecting a PEM-formatted key.')
raise InvalidAsymmetricKeyError('Expecting a PEM-formatted key.')

return key

Expand Down
3 changes: 2 additions & 1 deletion jwt/contrib/algorithms/pycrypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from jwt.algorithms import Algorithm
from jwt.compat import string_types, text_type
from jwt.exceptions import InvalidAsymmetricKeyError


class RSAAlgorithm(Algorithm):
Expand Down Expand Up @@ -36,7 +37,7 @@ def prepare_key(self, key):

key = RSA.importKey(key)
else:
raise TypeError('Expecting a PEM- or RSA-formatted key.')
raise InvalidAsymmetricKeyError

return key

Expand Down
10 changes: 9 additions & 1 deletion jwt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,18 @@ class ImmatureSignatureError(InvalidTokenError):
pass


class InvalidKeyError(Exception):
class InvalidKeyError(ValueError):
pass


class InvalidAsymmetricKeyError(InvalidKeyError):
message = 'Invalid key: Keys must be in PEM or RFC 4253 format.'


class InvalidJwkError(InvalidKeyError):
message = 'Invalid key: Keys must be in JWK format.'


class InvalidAlgorithmError(InvalidTokenError):
pass

Expand Down
5 changes: 3 additions & 2 deletions tests/contrib/test_algorithms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64

from jwt.exceptions import InvalidAsymmetricKeyError
from jwt.utils import force_bytes, force_unicode

import pytest
Expand Down Expand Up @@ -36,7 +37,7 @@ def test_rsa_should_accept_unicode_key(self):
def test_rsa_should_reject_non_string_key(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)

with pytest.raises(TypeError):
with pytest.raises(InvalidAsymmetricKeyError):
algo.prepare_key(None)

def test_rsa_sign_should_generate_correct_signature_value(self):
Expand Down Expand Up @@ -117,7 +118,7 @@ class TestEcdsaAlgorithms:
def test_ec_should_reject_non_string_key(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)

with pytest.raises(TypeError):
with pytest.raises(InvalidAsymmetricKeyError):
algo.prepare_key(None)

def test_ec_should_accept_unicode_key(self):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ def test_none_algorithm_should_throw_exception_if_key_is_not_none(self):
def test_hmac_should_reject_nonstring_key(self):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)

with pytest.raises(TypeError) as context:
with pytest.raises(InvalidKeyError) as context:
algo.prepare_key(object())

exception = context.value
assert str(exception) == 'Expected a string value'
assert str(exception) == 'HMAC secret key must be a string type.'

def test_hmac_should_accept_unicode_key(self):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
Expand Down Expand Up @@ -144,7 +144,7 @@ def test_rsa_should_accept_unicode_key(self):
def test_rsa_should_reject_non_string_key(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)

with pytest.raises(TypeError):
with pytest.raises(InvalidKeyError):
algo.prepare_key(None)

@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
Expand Down Expand Up @@ -358,7 +358,7 @@ def test_rsa_from_jwk_raises_exception_on_invalid_key(self):
def test_ec_should_reject_non_string_key(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)

with pytest.raises(TypeError):
with pytest.raises(InvalidKeyError):
algo.prepare_key(None)

@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
Expand Down

0 comments on commit d04339d

Please sign in to comment.