Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ Release History
the ``auto_session_renewal`` functionality of ``BoxSession``, this means
that authentication for ``JWTAuth`` objects can be done completely
automatically, at the time of first API call.
- The constructor now supports passing the RSA private key in two different
ways: by file system path (existing functionality), or by passing the key
data directly (new functionality). The ``rsa_private_key_file_sys_path``
parameter is now optional, but it is required to pass exactly one of
``rsa_private_key_file_sys_path`` or ``rsa_private_key_data``.
- Document that the ``enterprise_id`` argument to ``JWTAuth`` is allowed to
be ``None``.
- ``authenticate_instance()`` now accepts an ``enterprise`` argument, which
Expand Down
85 changes: 74 additions & 11 deletions boxsdk/auth/jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
import jwt
from six import string_types, text_type
from six import binary_type, string_types, raise_from, text_type

from .oauth2 import OAuth2
from ..object.user import User
from ..util.compat import total_seconds
from ..util.compat import NoneType, total_seconds


class JWTAuth(OAuth2):
Expand All @@ -28,7 +29,7 @@ def __init__(
client_secret,
enterprise_id,
jwt_key_id,
rsa_private_key_file_sys_path,
rsa_private_key_file_sys_path=None,
rsa_private_key_passphrase=None,
user=None,
store_tokens=None,
Expand All @@ -37,9 +38,13 @@ def __init__(
access_token=None,
network_layer=None,
jwt_algorithm='RS256',
rsa_private_key_data=None,
):
"""Extends baseclass method.

Must pass exactly one of either `rsa_private_key_file_sys_path` or
`rsa_private_key_data`.

If both `enterprise_id` and `user` are non-`None`, the `user` takes
precedence when `refresh()` is called. This can be overruled with a
call to `authenticate_instance()`.
Expand Down Expand Up @@ -68,13 +73,13 @@ def __init__(
:type jwt_key_id:
`unicode`
:param rsa_private_key_file_sys_path:
Path to an RSA private key file, used for signing the JWT assertion.
(optional) Path to an RSA private key file, used for signing the JWT assertion.
:type rsa_private_key_file_sys_path:
`unicode`
:param rsa_private_key_passphrase:
Passphrase used to unlock the private key. Do not pass a unicode string - this must be bytes.
:type rsa_private_key_passphrase:
`str` or None
`bytes` or None
:param user:
(optional) The user to authenticate, expressed as a Box User ID or
as a :class:`User` instance.
Expand Down Expand Up @@ -120,8 +125,20 @@ def __init__(
Which algorithm to use for signing the JWT assertion. Must be one of 'RS256', 'RS384', 'RS512'.
:type jwt_algorithm:
`unicode`
:param rsa_private_key_data:
(optional) Contents of RSA private key, used for signing the JWT assertion. Do not pass a
unicode string. Can pass a byte string, or a file-like object that returns bytes, or an
already-loaded `RSAPrivateKey` object.
:type rsa_private_key_data: `bytes` or :class:`io.IOBase` or :class:`RSAPrivateKey`
"""
user_id = self._normalize_user_id(user)
rsa_private_key = self._normalize_rsa_private_key(
file_sys_path=rsa_private_key_file_sys_path,
data=rsa_private_key_data,
passphrase=rsa_private_key_passphrase,
)
del rsa_private_key_data
del rsa_private_key_file_sys_path
super(JWTAuth, self).__init__(
client_id,
client_secret,
Expand All @@ -132,12 +149,7 @@ def __init__(
refresh_token=None,
network_layer=network_layer,
)
with open(rsa_private_key_file_sys_path, 'rb') as key_file:
self._rsa_private_key = serialization.load_pem_private_key(
key_file.read(),
password=rsa_private_key_passphrase,
backend=default_backend(),
)
self._rsa_private_key = rsa_private_key
self._enterprise_id = enterprise_id
self._jwt_algorithm = jwt_algorithm
self._jwt_key_id = jwt_key_id
Expand Down Expand Up @@ -295,3 +307,54 @@ def _refresh(self, access_token):
else:
new_access_token = self.authenticate_user()
return new_access_token, None

@classmethod
def _normalize_rsa_private_key(cls, file_sys_path, data, passphrase=None):
if len(list(filter(None, [file_sys_path, data]))) != 1:
raise TypeError("must pass exactly one of either rsa_private_key_file_sys_path or rsa_private_key_data")
if file_sys_path:
with open(file_sys_path, 'rb') as key_file:
data = key_file.read()
if hasattr(data, 'read') and callable(data.read):
data = data.read()
if isinstance(data, text_type):
try:
data = data.encode('ascii')
except UnicodeError:
raise_from(
TypeError("rsa_private_key_data must contain binary data (bytes/str), not a text/unicode string"),
None,
)
if isinstance(data, binary_type):
passphrase = cls._normalize_rsa_private_key_passphrase(passphrase)
return serialization.load_pem_private_key(
data,
password=passphrase,
backend=default_backend(),
)
if isinstance(data, RSAPrivateKey):
return data
raise TypeError(
'rsa_private_key_data must be binary data (bytes/str), '
'a file-like object with a read() method, '
'or an instance of RSAPrivateKey, '
'but got {0!r}'
.format(data.__class__.__name__)
)

@staticmethod
def _normalize_rsa_private_key_passphrase(passphrase):
if isinstance(passphrase, text_type):
try:
return passphrase.encode('ascii')
except UnicodeError:
raise_from(
TypeError("rsa_private_key_passphrase must contain binary data (bytes/str), not a text/unicode string"),
None,
)
if not isinstance(passphrase, (binary_type, NoneType)):
raise TypeError(
"rsa_private_key_passphrase must contain binary data (bytes/str), got {0!r}"
.format(passphrase.__class__.__name__)
)
return passphrase
3 changes: 3 additions & 0 deletions boxsdk/util/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import six


NoneType = type(None)


if not hasattr(timedelta, 'total_seconds'):
def total_seconds(delta):
"""
Expand Down
107 changes: 97 additions & 10 deletions test/unit/auth/test_jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@

from contextlib import contextmanager
from datetime import datetime, timedelta
import io
from itertools import product
import json
import random
import string

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, generate_private_key as generate_rsa_private_key
from cryptography.hazmat.primitives import serialization
from mock import Mock, mock_open, patch, sentinel
import pytest
from six import string_types, text_type
from six import binary_type, string_types, text_type

from boxsdk.auth.jwt_auth import JWTAuth
from boxsdk.config import API
Expand All @@ -35,11 +38,26 @@ def jwt_key_id():
return 'jwt_key_id_1'


@pytest.fixture(scope='module')
def rsa_private_key_object():
return generate_rsa_private_key(public_exponent=65537, key_size=4096, backend=default_backend())


@pytest.fixture(params=(None, b'strong_password'))
def rsa_passphrase(request):
return request.param


@pytest.fixture
def rsa_private_key_bytes(rsa_private_key_object, rsa_passphrase):
encryption = serialization.BestAvailableEncryption(rsa_passphrase) if rsa_passphrase else serialization.NoEncryption()
return rsa_private_key_object.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=encryption,
)


@pytest.fixture(scope='function')
def successful_token_response(successful_token_mock, successful_token_json_response):
# pylint:disable=redefined-outer-name
Expand All @@ -52,8 +70,76 @@ def successful_token_response(successful_token_mock, successful_token_json_respo
return successful_token_mock


@pytest.mark.parametrize(('key_file', 'key_data'), [(None, None), ('fake sys path', 'fake key data')])
@pytest.mark.parametrize('rsa_passphrase', [None])
def test_jwt_auth_init_raises_type_error_unless_exactly_one_of_rsa_private_key_file_or_data_is_given(key_file, key_data, rsa_private_key_bytes):
kwargs = dict(
rsa_private_key_data=rsa_private_key_bytes,
client_id=None,
client_secret=None,
jwt_key_id=None,
enterprise_id=None,
)
JWTAuth(**kwargs)
kwargs.update(rsa_private_key_file_sys_path=key_file, rsa_private_key_data=key_data)
with pytest.raises(TypeError):
JWTAuth(**kwargs)


@pytest.mark.parametrize('key_data', [object(), u'ƒøø'])
@pytest.mark.parametrize('rsa_passphrase', [None])
def test_jwt_auth_init_raises_type_error_if_rsa_private_key_data_has_unexpected_type(key_data, rsa_private_key_bytes):
kwargs = dict(
rsa_private_key_data=rsa_private_key_bytes,
client_id=None,
client_secret=None,
jwt_key_id=None,
enterprise_id=None,
)
JWTAuth(**kwargs)
kwargs.update(rsa_private_key_data=key_data)
with pytest.raises(TypeError):
JWTAuth(**kwargs)


@pytest.mark.parametrize('rsa_private_key_data_type', [io.BytesIO, text_type, binary_type, RSAPrivateKey])
def test_jwt_auth_init_accepts_rsa_private_key_data(rsa_private_key_bytes, rsa_passphrase, rsa_private_key_data_type):
if rsa_private_key_data_type is text_type:
rsa_private_key_data = text_type(rsa_private_key_bytes.decode('ascii'))
elif rsa_private_key_data_type is RSAPrivateKey:
rsa_private_key_data = serialization.load_pem_private_key(
rsa_private_key_bytes,
password=rsa_passphrase,
backend=default_backend(),
)
else:
rsa_private_key_data = rsa_private_key_data_type(rsa_private_key_bytes)
JWTAuth(
rsa_private_key_data=rsa_private_key_data,
rsa_private_key_passphrase=rsa_passphrase,
client_id=None,
client_secret=None,
jwt_key_id=None,
enterprise_id=None,
)


@pytest.fixture(params=[False, True])
def pass_private_key_by_path(request):
"""For jwt_auth_init_mocks, whether to pass the private key via sys_path (True) or pass the data directly (False)."""
return request.param


@pytest.fixture
def jwt_auth_init_mocks(mock_network_layer, successful_token_response, jwt_algorithm, jwt_key_id, rsa_passphrase):
def jwt_auth_init_mocks(
mock_network_layer,
successful_token_response,
jwt_algorithm,
jwt_key_id,
rsa_passphrase,
rsa_private_key_bytes,
pass_private_key_by_path,
):
# pylint:disable=redefined-outer-name

@contextmanager
Expand All @@ -70,15 +156,14 @@ def _jwt_auth_init_mocks(**kwargs):
'box_device_id': '0',
'box_device_name': 'my_awesome_device',
}

mock_network_layer.request.return_value = successful_token_response
key_file_read_data = b'key_file_read_data'
with patch('boxsdk.auth.jwt_auth.open', mock_open(read_data=key_file_read_data), create=True) as jwt_auth_open:
with patch('boxsdk.auth.jwt_auth.open', mock_open(read_data=rsa_private_key_bytes), create=True) as jwt_auth_open:
with patch('cryptography.hazmat.primitives.serialization.load_pem_private_key') as load_pem_private_key:
oauth = JWTAuth(
client_id=fake_client_id,
client_secret=fake_client_secret,
rsa_private_key_file_sys_path=sentinel.rsa_path,
rsa_private_key_file_sys_path=(sentinel.rsa_path if pass_private_key_by_path else None),
rsa_private_key_data=(None if pass_private_key_by_path else rsa_private_key_bytes),
rsa_private_key_passphrase=rsa_passphrase,
network_layer=mock_network_layer,
box_device_name='my_awesome_device',
Expand All @@ -87,11 +172,13 @@ def _jwt_auth_init_mocks(**kwargs):
enterprise_id=kwargs.pop('enterprise_id', None),
**kwargs
)

jwt_auth_open.assert_called_once_with(sentinel.rsa_path, 'rb')
jwt_auth_open.return_value.read.assert_called_once_with() # pylint:disable=no-member
if pass_private_key_by_path:
jwt_auth_open.assert_called_once_with(sentinel.rsa_path, 'rb')
jwt_auth_open.return_value.read.assert_called_once_with() # pylint:disable=no-member
else:
jwt_auth_open.assert_not_called()
load_pem_private_key.assert_called_once_with(
key_file_read_data,
rsa_private_key_bytes,
password=rsa_passphrase,
backend=default_backend(),
)
Expand Down
4 changes: 2 additions & 2 deletions test/unit/util/test_api_call_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ def func():

def test_api_call_decorated_method_must_be_a_cloneable_method():

class Cls(object):
class NonCloneable(object):
@api_call
def func(self):
pass

obj = Cls()
obj = NonCloneable()
with pytest.raises(TypeError):
obj.func()

Expand Down