diff --git a/HISTORY.rst b/HISTORY.rst index fbd55733d..a300dcd1d 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -71,6 +71,11 @@ Release History the ``auto_session_renewal`` functionality of ``BoxSession``, this means that authentication for ``JWTAuth`` objects can be done completely automatically, at the time of first API call. + - The constructor now supports passing the RSA private key in two different + ways: by file system path (existing functionality), or by passing the key + data directly (new functionality). The ``rsa_private_key_file_sys_path`` + parameter is now optional, but it is required to pass exactly one of + ``rsa_private_key_file_sys_path`` or ``rsa_private_key_data``. - Document that the ``enterprise_id`` argument to ``JWTAuth`` is allowed to be ``None``. - ``authenticate_instance()`` now accepts an ``enterprise`` argument, which diff --git a/boxsdk/auth/jwt_auth.py b/boxsdk/auth/jwt_auth.py index 6633c4ae9..c6e1ff5ac 100644 --- a/boxsdk/auth/jwt_auth.py +++ b/boxsdk/auth/jwt_auth.py @@ -8,12 +8,13 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey import jwt -from six import string_types, text_type +from six import binary_type, string_types, raise_from, text_type from .oauth2 import OAuth2 from ..object.user import User -from ..util.compat import total_seconds +from ..util.compat import NoneType, total_seconds class JWTAuth(OAuth2): @@ -28,7 +29,7 @@ def __init__( client_secret, enterprise_id, jwt_key_id, - rsa_private_key_file_sys_path, + rsa_private_key_file_sys_path=None, rsa_private_key_passphrase=None, user=None, store_tokens=None, @@ -37,9 +38,13 @@ def __init__( access_token=None, network_layer=None, jwt_algorithm='RS256', + rsa_private_key_data=None, ): """Extends baseclass method. + Must pass exactly one of either `rsa_private_key_file_sys_path` or + `rsa_private_key_data`. + If both `enterprise_id` and `user` are non-`None`, the `user` takes precedence when `refresh()` is called. This can be overruled with a call to `authenticate_instance()`. @@ -68,13 +73,13 @@ def __init__( :type jwt_key_id: `unicode` :param rsa_private_key_file_sys_path: - Path to an RSA private key file, used for signing the JWT assertion. + (optional) Path to an RSA private key file, used for signing the JWT assertion. :type rsa_private_key_file_sys_path: `unicode` :param rsa_private_key_passphrase: Passphrase used to unlock the private key. Do not pass a unicode string - this must be bytes. :type rsa_private_key_passphrase: - `str` or None + `bytes` or None :param user: (optional) The user to authenticate, expressed as a Box User ID or as a :class:`User` instance. @@ -120,8 +125,20 @@ def __init__( Which algorithm to use for signing the JWT assertion. Must be one of 'RS256', 'RS384', 'RS512'. :type jwt_algorithm: `unicode` + :param rsa_private_key_data: + (optional) Contents of RSA private key, used for signing the JWT assertion. Do not pass a + unicode string. Can pass a byte string, or a file-like object that returns bytes, or an + already-loaded `RSAPrivateKey` object. + :type rsa_private_key_data: `bytes` or :class:`io.IOBase` or :class:`RSAPrivateKey` """ user_id = self._normalize_user_id(user) + rsa_private_key = self._normalize_rsa_private_key( + file_sys_path=rsa_private_key_file_sys_path, + data=rsa_private_key_data, + passphrase=rsa_private_key_passphrase, + ) + del rsa_private_key_data + del rsa_private_key_file_sys_path super(JWTAuth, self).__init__( client_id, client_secret, @@ -132,12 +149,7 @@ def __init__( refresh_token=None, network_layer=network_layer, ) - with open(rsa_private_key_file_sys_path, 'rb') as key_file: - self._rsa_private_key = serialization.load_pem_private_key( - key_file.read(), - password=rsa_private_key_passphrase, - backend=default_backend(), - ) + self._rsa_private_key = rsa_private_key self._enterprise_id = enterprise_id self._jwt_algorithm = jwt_algorithm self._jwt_key_id = jwt_key_id @@ -295,3 +307,54 @@ def _refresh(self, access_token): else: new_access_token = self.authenticate_user() return new_access_token, None + + @classmethod + def _normalize_rsa_private_key(cls, file_sys_path, data, passphrase=None): + if len(list(filter(None, [file_sys_path, data]))) != 1: + raise TypeError("must pass exactly one of either rsa_private_key_file_sys_path or rsa_private_key_data") + if file_sys_path: + with open(file_sys_path, 'rb') as key_file: + data = key_file.read() + if hasattr(data, 'read') and callable(data.read): + data = data.read() + if isinstance(data, text_type): + try: + data = data.encode('ascii') + except UnicodeError: + raise_from( + TypeError("rsa_private_key_data must contain binary data (bytes/str), not a text/unicode string"), + None, + ) + if isinstance(data, binary_type): + passphrase = cls._normalize_rsa_private_key_passphrase(passphrase) + return serialization.load_pem_private_key( + data, + password=passphrase, + backend=default_backend(), + ) + if isinstance(data, RSAPrivateKey): + return data + raise TypeError( + 'rsa_private_key_data must be binary data (bytes/str), ' + 'a file-like object with a read() method, ' + 'or an instance of RSAPrivateKey, ' + 'but got {0!r}' + .format(data.__class__.__name__) + ) + + @staticmethod + def _normalize_rsa_private_key_passphrase(passphrase): + if isinstance(passphrase, text_type): + try: + return passphrase.encode('ascii') + except UnicodeError: + raise_from( + TypeError("rsa_private_key_passphrase must contain binary data (bytes/str), not a text/unicode string"), + None, + ) + if not isinstance(passphrase, (binary_type, NoneType)): + raise TypeError( + "rsa_private_key_passphrase must contain binary data (bytes/str), got {0!r}" + .format(passphrase.__class__.__name__) + ) + return passphrase diff --git a/boxsdk/util/compat.py b/boxsdk/util/compat.py index 9f5897a47..2243af924 100644 --- a/boxsdk/util/compat.py +++ b/boxsdk/util/compat.py @@ -7,6 +7,9 @@ import six +NoneType = type(None) + + if not hasattr(timedelta, 'total_seconds'): def total_seconds(delta): """ diff --git a/test/unit/auth/test_jwt_auth.py b/test/unit/auth/test_jwt_auth.py index e32a6868b..830dfe063 100644 --- a/test/unit/auth/test_jwt_auth.py +++ b/test/unit/auth/test_jwt_auth.py @@ -4,15 +4,18 @@ from contextlib import contextmanager from datetime import datetime, timedelta +import io from itertools import product import json import random import string from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, generate_private_key as generate_rsa_private_key +from cryptography.hazmat.primitives import serialization from mock import Mock, mock_open, patch, sentinel import pytest -from six import string_types, text_type +from six import binary_type, string_types, text_type from boxsdk.auth.jwt_auth import JWTAuth from boxsdk.config import API @@ -35,11 +38,26 @@ def jwt_key_id(): return 'jwt_key_id_1' +@pytest.fixture(scope='module') +def rsa_private_key_object(): + return generate_rsa_private_key(public_exponent=65537, key_size=4096, backend=default_backend()) + + @pytest.fixture(params=(None, b'strong_password')) def rsa_passphrase(request): return request.param +@pytest.fixture +def rsa_private_key_bytes(rsa_private_key_object, rsa_passphrase): + encryption = serialization.BestAvailableEncryption(rsa_passphrase) if rsa_passphrase else serialization.NoEncryption() + return rsa_private_key_object.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=encryption, + ) + + @pytest.fixture(scope='function') def successful_token_response(successful_token_mock, successful_token_json_response): # pylint:disable=redefined-outer-name @@ -52,8 +70,76 @@ def successful_token_response(successful_token_mock, successful_token_json_respo return successful_token_mock +@pytest.mark.parametrize(('key_file', 'key_data'), [(None, None), ('fake sys path', 'fake key data')]) +@pytest.mark.parametrize('rsa_passphrase', [None]) +def test_jwt_auth_init_raises_type_error_unless_exactly_one_of_rsa_private_key_file_or_data_is_given(key_file, key_data, rsa_private_key_bytes): + kwargs = dict( + rsa_private_key_data=rsa_private_key_bytes, + client_id=None, + client_secret=None, + jwt_key_id=None, + enterprise_id=None, + ) + JWTAuth(**kwargs) + kwargs.update(rsa_private_key_file_sys_path=key_file, rsa_private_key_data=key_data) + with pytest.raises(TypeError): + JWTAuth(**kwargs) + + +@pytest.mark.parametrize('key_data', [object(), u'ƒøø']) +@pytest.mark.parametrize('rsa_passphrase', [None]) +def test_jwt_auth_init_raises_type_error_if_rsa_private_key_data_has_unexpected_type(key_data, rsa_private_key_bytes): + kwargs = dict( + rsa_private_key_data=rsa_private_key_bytes, + client_id=None, + client_secret=None, + jwt_key_id=None, + enterprise_id=None, + ) + JWTAuth(**kwargs) + kwargs.update(rsa_private_key_data=key_data) + with pytest.raises(TypeError): + JWTAuth(**kwargs) + + +@pytest.mark.parametrize('rsa_private_key_data_type', [io.BytesIO, text_type, binary_type, RSAPrivateKey]) +def test_jwt_auth_init_accepts_rsa_private_key_data(rsa_private_key_bytes, rsa_passphrase, rsa_private_key_data_type): + if rsa_private_key_data_type is text_type: + rsa_private_key_data = text_type(rsa_private_key_bytes.decode('ascii')) + elif rsa_private_key_data_type is RSAPrivateKey: + rsa_private_key_data = serialization.load_pem_private_key( + rsa_private_key_bytes, + password=rsa_passphrase, + backend=default_backend(), + ) + else: + rsa_private_key_data = rsa_private_key_data_type(rsa_private_key_bytes) + JWTAuth( + rsa_private_key_data=rsa_private_key_data, + rsa_private_key_passphrase=rsa_passphrase, + client_id=None, + client_secret=None, + jwt_key_id=None, + enterprise_id=None, + ) + + +@pytest.fixture(params=[False, True]) +def pass_private_key_by_path(request): + """For jwt_auth_init_mocks, whether to pass the private key via sys_path (True) or pass the data directly (False).""" + return request.param + + @pytest.fixture -def jwt_auth_init_mocks(mock_network_layer, successful_token_response, jwt_algorithm, jwt_key_id, rsa_passphrase): +def jwt_auth_init_mocks( + mock_network_layer, + successful_token_response, + jwt_algorithm, + jwt_key_id, + rsa_passphrase, + rsa_private_key_bytes, + pass_private_key_by_path, +): # pylint:disable=redefined-outer-name @contextmanager @@ -70,15 +156,14 @@ def _jwt_auth_init_mocks(**kwargs): 'box_device_id': '0', 'box_device_name': 'my_awesome_device', } - mock_network_layer.request.return_value = successful_token_response - key_file_read_data = b'key_file_read_data' - with patch('boxsdk.auth.jwt_auth.open', mock_open(read_data=key_file_read_data), create=True) as jwt_auth_open: + with patch('boxsdk.auth.jwt_auth.open', mock_open(read_data=rsa_private_key_bytes), create=True) as jwt_auth_open: with patch('cryptography.hazmat.primitives.serialization.load_pem_private_key') as load_pem_private_key: oauth = JWTAuth( client_id=fake_client_id, client_secret=fake_client_secret, - rsa_private_key_file_sys_path=sentinel.rsa_path, + rsa_private_key_file_sys_path=(sentinel.rsa_path if pass_private_key_by_path else None), + rsa_private_key_data=(None if pass_private_key_by_path else rsa_private_key_bytes), rsa_private_key_passphrase=rsa_passphrase, network_layer=mock_network_layer, box_device_name='my_awesome_device', @@ -87,11 +172,13 @@ def _jwt_auth_init_mocks(**kwargs): enterprise_id=kwargs.pop('enterprise_id', None), **kwargs ) - - jwt_auth_open.assert_called_once_with(sentinel.rsa_path, 'rb') - jwt_auth_open.return_value.read.assert_called_once_with() # pylint:disable=no-member + if pass_private_key_by_path: + jwt_auth_open.assert_called_once_with(sentinel.rsa_path, 'rb') + jwt_auth_open.return_value.read.assert_called_once_with() # pylint:disable=no-member + else: + jwt_auth_open.assert_not_called() load_pem_private_key.assert_called_once_with( - key_file_read_data, + rsa_private_key_bytes, password=rsa_passphrase, backend=default_backend(), ) diff --git a/test/unit/util/test_api_call_decorator.py b/test/unit/util/test_api_call_decorator.py index 9c0209274..ab395f08c 100644 --- a/test/unit/util/test_api_call_decorator.py +++ b/test/unit/util/test_api_call_decorator.py @@ -66,12 +66,12 @@ def func(): def test_api_call_decorated_method_must_be_a_cloneable_method(): - class Cls(object): + class NonCloneable(object): @api_call def func(self): pass - obj = Cls() + obj = NonCloneable() with pytest.raises(TypeError): obj.func()