diff --git a/README.rst b/README.rst index 21b3f0630..b157c9d55 100644 --- a/README.rst +++ b/README.rst @@ -323,12 +323,12 @@ These users can then be authenticated: ned_auth = JWTAuth( client_id='YOUR_CLIENT_ID', client_secret='YOUR_CLIENT_SECRET', - enterprise_id='YOUR_ENTERPRISE_ID', + user=ned_stark_user, jwt_key_id='YOUR_JWT_KEY_ID', rsa_private_key_file_sys_path='CERT.PEM', store_tokens=your_store_tokens_callback_method, ) - ned_auth.authenticate_app_user(ned_stark_user) + ned_auth.authenticate_user() ned_client = Client(ned_auth) Requests made with ``ned_client`` (or objects returned from ``ned_client``'s methods) @@ -396,7 +396,7 @@ Customization Custom Subclasses ~~~~~~~~~~~~~~~~~ -Custom subclasses of any SDK object with an ``_item_type`` field can be defined: +Custom object subclasses can be defined: .. code-block:: pycon @@ -407,12 +407,13 @@ Custom subclasses of any SDK object with an ``_item_type`` field can be defined: pass client = Client(oauth) + client.translator.register('folder', MyFolderSubclass) folder = client.folder('0') >>> print folder >>> -If a subclass of an SDK object with an ``_item_type`` field is defined, instances of this subclass will be +If an object subclass is registered in this way, instances of this subclass will be returned from all SDK methods that previously returned an instance of the parent. See ``BaseAPIJSONObjectMeta`` and ``Translator`` to see how the SDK performs dynamic lookups to determine return types. diff --git a/boxsdk/auth/jwt_auth.py b/boxsdk/auth/jwt_auth.py index 2d1f89598..e536e853a 100644 --- a/boxsdk/auth/jwt_auth.py +++ b/boxsdk/auth/jwt_auth.py @@ -9,8 +9,10 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization import jwt +from six import string_types, text_type from .oauth2 import OAuth2 +from ..object.user import User from ..util.compat import total_seconds @@ -28,6 +30,7 @@ def __init__( jwt_key_id, rsa_private_key_file_sys_path, rsa_private_key_passphrase=None, + user=None, store_tokens=None, box_device_id='0', box_device_name='', @@ -35,7 +38,12 @@ def __init__( network_layer=None, jwt_algorithm='RS256', ): - """ + """Extends baseclass method. + + If both `enterprise_id` and `user` are non-`None`, the `user` takes + precedence when `refresh()` is called. This can be overruled with a + call to `authenticate_instance()`. + :param client_id: Box API key used for identifying the application the user is authenticating with. :type client_id: @@ -46,8 +54,15 @@ def __init__( `unicode` :param enterprise_id: The ID of the Box Developer Edition enterprise. + + May be `None`, if the caller knows that it will not be + authenticating as an enterprise instance / service account. + + If `user` is passed, this value is not used, unless + `authenticate_instance()` is called to clear the user and + authenticate as the enterprise instance. :type enterprise_id: - `unicode` + `unicode` or `None` :param jwt_key_id: Key ID for the JWT assertion. :type jwt_key_id: @@ -60,6 +75,27 @@ def __init__( Passphrase used to unlock the private key. Do not pass a unicode string - this must be bytes. :type rsa_private_key_passphrase: `str` or None + :param user: + (optional) The user to authenticate, expressed as a Box User ID or + as a :class:`User` instance. + + This value is not required. But if it is provided, then the user + will be auto-authenticated at the time of the first API call or + when calling `authenticate_user()` without any arguments. + + Should be `None` if the intention is to authenticate as the + enterprise instance / service account. If both `enterprise_id` and + `user` are non-`None`, the `user` takes precedense when `refresh()` + is called. + + May be one of this application's created App User. Depending on the + configured User Access Level, may also be any other App User or + Managed User in the enterprise. + + + + :type user: + `unicode` or :class:`User` or `None` :param store_tokens: Optional callback for getting access to tokens for storing them. :type store_tokens: @@ -85,6 +121,7 @@ def __init__( :type jwt_algorithm: `unicode` """ + user_id = self._normalize_user_id(user) super(JWTAuth, self).__init__( client_id, client_secret, @@ -104,12 +141,12 @@ def __init__( self._enterprise_id = enterprise_id self._jwt_algorithm = jwt_algorithm self._jwt_key_id = jwt_key_id - self._user_id = None + self._user_id = user_id def _auth_with_jwt(self, sub, sub_type): """ Get an access token for use with Box Developer Edition. Pass an enterprise ID to get an enterprise token - (which can be used to provision/deprovision users), or a user ID to get an app user token. + (which can be used to provision/deprovision users), or a user ID to get a user token. :param sub: The enterprise ID or user ID to auth. @@ -157,31 +194,92 @@ def _auth_with_jwt(self, sub, sub_type): data['box_device_name'] = self._box_device_name return self.send_token_request(data, access_token=None, expect_refresh_token=False)[0] - def authenticate_app_user(self, user): + def authenticate_user(self, user=None): """ - Get an access token for an App User (part of Box Developer Edition). + Get an access token for a User. + + May be one of this application's created App User. Depending on the + configured User Access Level, may also be any other App User or Managed + User in the enterprise. + + + :param user: - The user to authenticate. + (optional) The user to authenticate, expressed as a Box User ID or + as a :class:`User` instance. + + If not given, then the most recently provided user ID, if + available, will be used. :type user: - :class:`User` + `unicode` or :class:`User` + :raises: + :exc:`ValueError` if no user ID was passed and the object is not + currently configured with one. :return: - The access token for the app user. + The access token for the user. :rtype: `unicode` """ - sub = self._user_id = user.object_id + sub = self._normalize_user_id(user) or self._user_id + if not sub: + raise ValueError("authenticate_user: Requires the user ID, but it was not provided.") + self._user_id = sub return self._auth_with_jwt(sub, 'user') - def authenticate_instance(self): + authenticate_app_user = authenticate_user + + @classmethod + def _normalize_user_id(cls, user): + """Get a Box user ID from a selection of supported param types. + + :param user: + An object representing the user or user ID. + + Currently supported types are `unicode` (which represents the user + ID) and :class:`User`. + + If `None`, returns `None`. + :raises: :exc:`TypeError` for unsupported types. + :rtype: `unicode` or `None` + """ + if user is None: + return None + if isinstance(user, User): + return user.object_id + if isinstance(user, string_types): + return text_type(user) + raise TypeError("Got unsupported type {0!r} for user.".format(user.__class__.__name__)) + + def authenticate_instance(self, enterprise=None): """ Get an access token for a Box Developer Edition enterprise. + :param enterprise: + The ID of the Box Developer Edition enterprise. + + Optional if the value was already given to `__init__`, + otherwise required. + :type enterprise: `unicode` or `None` + :raises: + :exc:`ValueError` if `None` was passed for the enterprise ID here + and in `__init__`, or if the non-`None` value passed here does not + match the non-`None` value passed to `__init__`. :return: The access token for the enterprise which can provision/deprovision app users. :rtype: `unicode` """ + enterprises = [enterprise, self._enterprise_id] + if not any(enterprises): + raise ValueError("authenticate_instance: Requires the enterprise ID, but it was not provided.") + if all(enterprises) and (enterprise != self._enterprise_id): + raise ValueError( + "authenticate_instance: Given enterprise ID {given_enterprise!r}, but {auth} already has ID {existing_enterprise!r}" + .format(auth=self, given_enterprise=enterprise, existing_enterprise=self._enterprise_id) + ) + if not self._enterprise_id: + self._enterprise_id = enterprise self._user_id = None return self._auth_with_jwt(self._enterprise_id, 'enterprise') @@ -195,4 +293,4 @@ def _refresh(self, access_token): if self._user_id is None: return self.authenticate_instance() else: - return self._auth_with_jwt(self._user_id, 'user') + return self.authenticate_user() diff --git a/boxsdk/session/box_session.py b/boxsdk/session/box_session.py index fbccabf34..a166a7888 100644 --- a/boxsdk/session/box_session.py +++ b/boxsdk/session/box_session.py @@ -193,7 +193,8 @@ def _renew_session(self, access_token_used): :type access_token_used: `unicode` """ - self._oauth.refresh(access_token_used) + new_access_token, _ = self._oauth.refresh(access_token_used) + return new_access_token @staticmethod def _is_json_response(network_response): @@ -390,6 +391,9 @@ def _make_request( # Since there can be session renewal happening in the middle of preparing the request, it's important to be # consistent with the access_token being used in the request. access_token_will_be_used = self._oauth.access_token + if auto_session_renewal and (access_token_will_be_used is None): + access_token_will_be_used = self._renew_session(None) + auto_session_renewal = False authorization_header = {'Authorization': 'Bearer {0}'.format(access_token_will_be_used)} if headers is None: headers = self._default_headers.copy() diff --git a/requirements.txt b/requirements.txt index 0e0ef6918..5e33286f8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ pyjwt>=1.3.0 requests>=2.4.3 requests-toolbelt>=0.4.0 six >= 1.4.0 --e . +-e .[all] diff --git a/test/unit/auth/test_jwt_auth.py b/test/unit/auth/test_jwt_auth.py index 63b553c7c..e32a6868b 100644 --- a/test/unit/auth/test_jwt_auth.py +++ b/test/unit/auth/test_jwt_auth.py @@ -1,9 +1,10 @@ # coding: utf-8 -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals from contextlib import contextmanager from datetime import datetime, timedelta +from itertools import product import json import random import string @@ -11,6 +12,7 @@ from cryptography.hazmat.backends import default_backend from mock import Mock, mock_open, patch, sentinel import pytest +from six import string_types, text_type from boxsdk.auth.jwt_auth import JWTAuth from boxsdk.config import API @@ -50,68 +52,138 @@ def successful_token_response(successful_token_mock, successful_token_json_respo return successful_token_mock -@contextmanager -def jwt_auth_init_mocks( - mock_network_layer, - successful_token_response, - jwt_algorithm, - jwt_key_id, - rsa_passphrase, - enterprise_id=None, -): +@pytest.fixture +def jwt_auth_init_mocks(mock_network_layer, successful_token_response, jwt_algorithm, jwt_key_id, rsa_passphrase): # pylint:disable=redefined-outer-name - fake_client_id = 'fake_client_id' - fake_client_secret = 'fake_client_secret' - assertion = Mock() - data = { - 'grant_type': JWTAuth._GRANT_TYPE, # pylint:disable=protected-access - 'client_id': fake_client_id, - 'client_secret': fake_client_secret, - 'assertion': assertion, - 'box_device_id': '0', - 'box_device_name': 'my_awesome_device', - } - - mock_network_layer.request.return_value = successful_token_response - key_file_read_data = b'key_file_read_data' - with patch('boxsdk.auth.jwt_auth.open', mock_open(read_data=key_file_read_data), create=True) as jwt_auth_open: - with patch('cryptography.hazmat.primitives.serialization.load_pem_private_key') as load_pem_private_key: - oauth = JWTAuth( - client_id=fake_client_id, - client_secret=fake_client_secret, - enterprise_id=enterprise_id, - rsa_private_key_file_sys_path=sentinel.rsa_path, - rsa_private_key_passphrase=rsa_passphrase, - network_layer=mock_network_layer, - box_device_name='my_awesome_device', - jwt_algorithm=jwt_algorithm, - jwt_key_id=jwt_key_id, - ) - jwt_auth_open.assert_called_once_with(sentinel.rsa_path, 'rb') - jwt_auth_open.return_value.read.assert_called_once_with() # pylint:disable=no-member - load_pem_private_key.assert_called_once_with( - key_file_read_data, - password=rsa_passphrase, - backend=default_backend(), + @contextmanager + def _jwt_auth_init_mocks(**kwargs): + assert_authed = kwargs.pop('assert_authed', True) + fake_client_id = 'fake_client_id' + fake_client_secret = 'fake_client_secret' + assertion = Mock() + data = { + 'grant_type': JWTAuth._GRANT_TYPE, # pylint:disable=protected-access + 'client_id': fake_client_id, + 'client_secret': fake_client_secret, + 'assertion': assertion, + 'box_device_id': '0', + 'box_device_name': 'my_awesome_device', + } + + mock_network_layer.request.return_value = successful_token_response + key_file_read_data = b'key_file_read_data' + with patch('boxsdk.auth.jwt_auth.open', mock_open(read_data=key_file_read_data), create=True) as jwt_auth_open: + with patch('cryptography.hazmat.primitives.serialization.load_pem_private_key') as load_pem_private_key: + oauth = JWTAuth( + client_id=fake_client_id, + client_secret=fake_client_secret, + rsa_private_key_file_sys_path=sentinel.rsa_path, + rsa_private_key_passphrase=rsa_passphrase, + network_layer=mock_network_layer, + box_device_name='my_awesome_device', + jwt_algorithm=jwt_algorithm, + jwt_key_id=jwt_key_id, + enterprise_id=kwargs.pop('enterprise_id', None), + **kwargs + ) + + jwt_auth_open.assert_called_once_with(sentinel.rsa_path, 'rb') + jwt_auth_open.return_value.read.assert_called_once_with() # pylint:disable=no-member + load_pem_private_key.assert_called_once_with( + key_file_read_data, + password=rsa_passphrase, + backend=default_backend(), + ) + + yield oauth, assertion, fake_client_id, load_pem_private_key.return_value + + if assert_authed: + mock_network_layer.request.assert_called_once_with( + 'POST', + '{0}/token'.format(API.OAUTH2_API_URL), + data=data, + headers={'content-type': 'application/x-www-form-urlencoded'}, + access_token=None, ) + assert oauth.access_token == successful_token_response.json()['access_token'] - yield oauth, assertion, fake_client_id, load_pem_private_key.return_value + return _jwt_auth_init_mocks - mock_network_layer.request.assert_called_once_with( - 'POST', - '{0}/token'.format(API.OAUTH2_API_URL), - data=data, - headers={'content-type': 'application/x-www-form-urlencoded'}, - access_token=None, - ) - assert oauth.access_token == successful_token_response.json()['access_token'] +def test_refresh_authenticates_with_user_if_enterprise_id_and_user_both_passed_to_constructor(jwt_auth_init_and_auth_mocks): + user = 'fake_user_id' + with jwt_auth_init_and_auth_mocks(sub=user, sub_type='user', enterprise_id='fake_enterprise_id', user=user) as oauth: + oauth.refresh(None) -@contextmanager -def jwt_auth_auth_mocks(jti_length, jwt_algorithm, jwt_key_id, sub, sub_type, oauth, assertion, client_id, secret): - # pylint:disable=redefined-outer-name - with patch('jwt.encode') as jwt_encode: + +@pytest.mark.parametrize('jwt_auth_method_name', ['authenticate_user', 'authenticate_instance']) +def test_authenticate_raises_value_error_if_sub_was_never_given(jwt_auth_init_mocks, jwt_auth_method_name): + with jwt_auth_init_mocks(assert_authed=False) as params: + auth = params[0] + authenticate_method = getattr(auth, jwt_auth_method_name) + with pytest.raises(ValueError): + authenticate_method() + + +def test_jwt_auth_constructor_raises_type_error_if_user_is_unsupported_type(jwt_auth_init_mocks): + with pytest.raises(TypeError): + with jwt_auth_init_mocks(user=object()): + assert False + + +def test_authenticate_user_raises_type_error_if_user_is_unsupported_type(jwt_auth_init_mocks): + with jwt_auth_init_mocks(assert_authed=False) as params: + auth = params[0] + with pytest.raises(TypeError): + auth.authenticate_user(object()) + + +@pytest.mark.parametrize('user_id_for_init', [None, 'fake_user_id_1']) +def test_authenticate_user_saves_user_id_for_future_calls(jwt_auth_init_and_auth_mocks, user_id_for_init, jwt_encode): + + def assert_jwt_encode_call_args(user_id): + assert jwt_encode.call_args[0][0]['sub'] == user_id + assert jwt_encode.call_args[0][0]['box_sub_type'] == 'user' + jwt_encode.call_args = None + + with jwt_auth_init_and_auth_mocks(sub=None, sub_type=None, assert_authed=False, user=user_id_for_init) as auth: + for new_user_id in ['fake_user_id_2', 'fake_user_id_3']: + auth.authenticate_user(new_user_id) + assert_jwt_encode_call_args(new_user_id) + auth.authenticate_user() + assert_jwt_encode_call_args(new_user_id) + + +def test_authenticate_instance_raises_value_error_if_different_enterprise_id_is_given(jwt_auth_init_mocks): + with jwt_auth_init_mocks(enterprise_id='fake_enterprise_id_1', assert_authed=False) as params: + auth = params[0] + with pytest.raises(ValueError): + auth.authenticate_instance('fake_enterprise_id_2') + + +def test_authenticate_instance_saves_enterprise_id_for_future_calls(jwt_auth_init_and_auth_mocks): + enterprise_id = 'fake_enterprise_id' + with jwt_auth_init_and_auth_mocks(sub=enterprise_id, sub_type='enterprise', assert_authed=False) as auth: + auth.authenticate_instance(enterprise_id) + auth.authenticate_instance() + auth.authenticate_instance(enterprise_id) + with pytest.raises(ValueError): + auth.authenticate_instance('fake_enterprise_id_2') + + +@pytest.yield_fixture +def jwt_encode(): + with patch('jwt.encode') as patched_jwt_encode: + yield patched_jwt_encode + + +@pytest.fixture +def jwt_auth_auth_mocks(jti_length, jwt_algorithm, jwt_key_id, jwt_encode): + + @contextmanager + def _jwt_auth_auth_mocks(sub, sub_type, oauth, assertion, client_id, secret, assert_authed=True): + # pylint:disable=redefined-outer-name with patch('boxsdk.auth.jwt_auth.datetime') as mock_datetime: with patch('boxsdk.auth.jwt_auth.random.SystemRandom') as mock_system_random: jwt_encode.return_value = assertion @@ -129,88 +201,79 @@ def jwt_auth_auth_mocks(jti_length, jwt_algorithm, jwt_key_id, sub, sub_type, oa yield oauth - system_random.randint.assert_called_once_with(16, 128) - assert len(system_random.random.mock_calls) == jti_length - jwt_encode.assert_called_once_with({ - 'iss': client_id, - 'sub': sub, - 'box_sub_type': sub_type, - 'aud': 'https://api.box.com/oauth2/token', - 'jti': jti, - 'exp': exp, - }, secret, algorithm=jwt_algorithm, headers={'kid': jwt_key_id}) - - -def test_authenticate_app_user_sends_post_request_with_correct_params( - mock_network_layer, - successful_token_response, - jti_length, - jwt_algorithm, - jwt_key_id, - rsa_passphrase, -): + if assert_authed: + system_random.randint.assert_called_once_with(16, 128) + assert len(system_random.random.mock_calls) == jti_length + jwt_encode.assert_called_once_with({ + 'iss': client_id, + 'sub': sub, + 'box_sub_type': sub_type, + 'aud': 'https://api.box.com/oauth2/token', + 'jti': jti, + 'exp': exp, + }, secret, algorithm=jwt_algorithm, headers={'kid': jwt_key_id}) + + return _jwt_auth_auth_mocks + + +@pytest.fixture +def jwt_auth_init_and_auth_mocks(jwt_auth_init_mocks, jwt_auth_auth_mocks): + + @contextmanager + def _jwt_auth_init_and_auth_mocks(sub, sub_type, *jwt_auth_init_mocks_args, **jwt_auth_init_mocks_kwargs): + assert_authed = jwt_auth_init_mocks_kwargs.pop('assert_authed', True) + with jwt_auth_init_mocks(*jwt_auth_init_mocks_args, assert_authed=assert_authed, **jwt_auth_init_mocks_kwargs) as params: + with jwt_auth_auth_mocks(sub, sub_type, *params, assert_authed=assert_authed) as oauth: + yield oauth + + return _jwt_auth_init_and_auth_mocks + + +@pytest.mark.parametrize( + ('user', 'pass_in_init'), + list(product([str('fake_user_id'), text_type('fake_user_id'), User(None, 'fake_user_id')], [False, True])), +) +def test_authenticate_user_sends_post_request_with_correct_params(jwt_auth_init_and_auth_mocks, user, pass_in_init): # pylint:disable=redefined-outer-name - fake_user_id = 'fake_user_id' - with jwt_auth_init_mocks(mock_network_layer, successful_token_response, jwt_algorithm, jwt_key_id, rsa_passphrase) as params: - with jwt_auth_auth_mocks(jti_length, jwt_algorithm, jwt_key_id, fake_user_id, 'user', *params) as oauth: - oauth.authenticate_app_user(User(None, fake_user_id)) - - -def test_authenticate_instance_sends_post_request_with_correct_params( - mock_network_layer, - successful_token_response, - jti_length, - jwt_algorithm, - jwt_key_id, - rsa_passphrase, -): + if isinstance(user, User): + user_id = user.object_id + elif isinstance(user, string_types): + user_id = user + else: + raise NotImplementedError + init_kwargs = {} + authenticate_params = [] + if pass_in_init: + init_kwargs['user'] = user + else: + authenticate_params.append(user) + with jwt_auth_init_and_auth_mocks(user_id, 'user', **init_kwargs) as oauth: + oauth.authenticate_user(*authenticate_params) + + +@pytest.mark.parametrize(('pass_in_init', 'pass_in_auth'), [(True, False), (False, True), (True, True)]) +def test_authenticate_instance_sends_post_request_with_correct_params(jwt_auth_init_and_auth_mocks, pass_in_init, pass_in_auth): # pylint:disable=redefined-outer-name enterprise_id = 'fake_enterprise_id' - with jwt_auth_init_mocks( - mock_network_layer, - successful_token_response, - jwt_algorithm, - jwt_key_id, - rsa_passphrase, - enterprise_id, - ) as params: - with jwt_auth_auth_mocks(jti_length, jwt_algorithm, jwt_key_id, enterprise_id, 'enterprise', *params) as oauth: - oauth.authenticate_instance() - - -def test_refresh_app_user_sends_post_request_with_correct_params( - mock_network_layer, - successful_token_response, - jti_length, - jwt_algorithm, - jwt_key_id, - rsa_passphrase, -): + init_kwargs = {} + auth_params = [] + if pass_in_init: + init_kwargs['enterprise_id'] = enterprise_id + if pass_in_auth: + auth_params.append(enterprise_id) + with jwt_auth_init_and_auth_mocks(enterprise_id, 'enterprise', **init_kwargs) as oauth: + oauth.authenticate_instance(*auth_params) + + +def test_refresh_app_user_sends_post_request_with_correct_params(jwt_auth_init_and_auth_mocks): # pylint:disable=redefined-outer-name fake_user_id = 'fake_user_id' - with jwt_auth_init_mocks(mock_network_layer, successful_token_response, jwt_algorithm, jwt_key_id, rsa_passphrase) as params: - with jwt_auth_auth_mocks(jti_length, jwt_algorithm, jwt_key_id, fake_user_id, 'user', *params) as oauth: - oauth._user_id = fake_user_id # pylint:disable=protected-access - oauth.refresh(None) - - -def test_refresh_instance_sends_post_request_with_correct_params( - mock_network_layer, - successful_token_response, - jti_length, - jwt_algorithm, - jwt_key_id, - rsa_passphrase, -): + with jwt_auth_init_and_auth_mocks(fake_user_id, 'user', user=fake_user_id) as oauth: + oauth.refresh(None) + + +def test_refresh_instance_sends_post_request_with_correct_params(jwt_auth_init_and_auth_mocks): # pylint:disable=redefined-outer-name enterprise_id = 'fake_enterprise_id' - with jwt_auth_init_mocks( - mock_network_layer, - successful_token_response, - jwt_algorithm, - jwt_key_id, - rsa_passphrase, - enterprise_id, - ) as params: - with jwt_auth_auth_mocks(jti_length, jwt_algorithm, jwt_key_id, enterprise_id, 'enterprise', *params) as oauth: - oauth.refresh(None) + with jwt_auth_init_and_auth_mocks(enterprise_id, 'enterprise', enterprise_id=enterprise_id) as oauth: + oauth.refresh(None) diff --git a/test/unit/session/test_box_session.py b/test/unit/session/test_box_session.py index 9b837af2b..ae6b24fb2 100644 --- a/test/unit/session/test_box_session.py +++ b/test/unit/session/test_box_session.py @@ -6,7 +6,7 @@ from io import IOBase from numbers import Number -from mock import MagicMock, Mock, call +from mock import MagicMock, Mock, PropertyMock, call import pytest from boxsdk.auth.oauth2 import OAuth2 @@ -23,13 +23,26 @@ def translator(default_translator, request): # pylint:disable=unused-argument @pytest.fixture -def box_session(translator): - mock_oauth = Mock(OAuth2) - mock_oauth.access_token = 'fake_access_token' +def initial_access_token(): + return 'fake_access_token' - mock_network_layer = Mock(DefaultNetwork) - return BoxSession(mock_oauth, mock_network_layer, translator=translator) +@pytest.fixture +def mock_oauth(initial_access_token): + mock_oauth = MagicMock(OAuth2) + mock_oauth.access_token = initial_access_token + return mock_oauth + + +@pytest.fixture +def mock_network_layer(): + return Mock(DefaultNetwork) + + +@pytest.fixture +def box_session(mock_oauth, mock_network_layer, translator): + # pylint:disable=redefined-outer-name + return BoxSession(oauth=mock_oauth, network_layer=mock_network_layer, translator=translator) @pytest.mark.parametrize('test_method', [ @@ -42,18 +55,68 @@ def box_session(translator): def test_box_session_handles_unauthorized_response( test_method, box_session, + mock_oauth, + mock_network_layer, unauthorized_response, generic_successful_response, test_url, ): - # pylint:disable=redefined-outer-name, protected-access - mock_network_layer = box_session._network_layer - mock_network_layer.request.side_effect = [unauthorized_response, generic_successful_response] + # pylint:disable=redefined-outer-name + + def get_access_token_from_auth_object(): + return mock_oauth.access_token + + mock_network_layer.request.side_effect = mock_responses = [unauthorized_response, generic_successful_response] + for mock_response in mock_responses: + type(mock_response).access_token_used = PropertyMock(side_effect=get_access_token_from_auth_object) + + def refresh(access_token_used): + assert access_token_used == mock_oauth.access_token + mock_oauth.access_token = 'fake_new_access_token' + return (mock_oauth.access_token, None) + + mock_oauth.refresh.side_effect = refresh box_response = test_method(box_session, url=test_url) assert box_response.status_code == 200 +@pytest.mark.parametrize('test_method', [ + BoxSession.get, + BoxSession.post, + BoxSession.put, + BoxSession.delete, + BoxSession.options, +]) +@pytest.mark.parametrize('initial_access_token', [None]) +def test_box_session_gets_access_token_before_request( + test_method, + box_session, + mock_oauth, + mock_network_layer, + generic_successful_response, + test_url, +): + # pylint:disable=redefined-outer-name + + def get_access_token_from_auth_object(): + return mock_oauth.access_token + + mock_network_layer.request.side_effect = mock_responses = [generic_successful_response] + for mock_response in mock_responses: + type(mock_response).access_token_used = PropertyMock(side_effect=get_access_token_from_auth_object) + + def refresh(access_token_used): + assert access_token_used == mock_oauth.access_token + mock_oauth.access_token = 'fake_new_access_token' + return (mock_oauth.access_token, None) + + mock_oauth.refresh.side_effect = refresh + + box_response = test_method(box_session, url=test_url, auto_session_renewal=True) + assert box_response.status_code == 200 + + @pytest.mark.parametrize('test_method', [ BoxSession.get, BoxSession.post, @@ -65,12 +128,12 @@ def test_box_session_handles_unauthorized_response( def test_box_session_retries_response_after_retry_after( test_method, box_session, + mock_network_layer, retry_after_response, generic_successful_response, test_url, ): - # pylint:disable=redefined-outer-name, protected-access - mock_network_layer = box_session._network_layer + # pylint:disable=redefined-outer-name mock_network_layer.request.side_effect = [retry_after_response, generic_successful_response] mock_network_layer.retry_after.side_effect = lambda delay, request, *args, **kwargs: request(*args, **kwargs) @@ -92,12 +155,12 @@ def test_box_session_retries_response_after_retry_after( def test_box_session_retries_request_after_server_error( test_method, box_session, + mock_network_layer, server_error_response, generic_successful_response, test_url, ): - # pylint:disable=redefined-outer-name, protected-access - mock_network_layer = box_session._network_layer + # pylint:disable=redefined-outer-name mock_network_layer.request.side_effect = [server_error_response, server_error_response, generic_successful_response] mock_network_layer.retry_after.side_effect = lambda delay, request, *args, **kwargs: request(*args, **kwargs) @@ -113,9 +176,8 @@ def test_box_session_retries_request_after_server_error( assert mock_network_layer.retry_after.call_args_list[1][0][0] == 2 -def test_box_session_seeks_file_after_retry(box_session, server_error_response, generic_successful_response, test_url): - # pylint:disable=redefined-outer-name, protected-access - mock_network_layer = box_session._network_layer +def test_box_session_seeks_file_after_retry(box_session, mock_network_layer, server_error_response, generic_successful_response, test_url): + # pylint:disable=redefined-outer-name mock_network_layer.request.side_effect = [server_error_response, generic_successful_response] mock_network_layer.retry_after.side_effect = lambda delay, request, *args, **kwargs: request(*args, **kwargs) mock_file_1, mock_file_2 = MagicMock(IOBase), MagicMock(IOBase) @@ -137,27 +199,24 @@ def test_box_session_seeks_file_after_retry(box_session, server_error_response, assert mock_file_2.seek.has_calls(call(3) * 2) -def test_box_session_raises_for_non_json_response(box_session, non_json_response, test_url): - # pylint:disable=redefined-outer-name, protected-access - mock_network_layer = box_session._network_layer +def test_box_session_raises_for_non_json_response(box_session, mock_network_layer, non_json_response, test_url): + # pylint:disable=redefined-outer-name mock_network_layer.request.side_effect = [non_json_response] with pytest.raises(BoxAPIException): box_session.get(url=test_url) -def test_box_session_raises_for_failed_response(box_session, bad_network_response, test_url): - # pylint:disable=redefined-outer-name, protected-access - mock_network_layer = box_session._network_layer +def test_box_session_raises_for_failed_response(box_session, mock_network_layer, bad_network_response, test_url): + # pylint:disable=redefined-outer-name mock_network_layer.request.side_effect = [bad_network_response] with pytest.raises(BoxAPIException): box_session.get(url=test_url) -def test_box_session_raises_for_failed_non_json_response(box_session, failed_non_json_response, test_url): - # pylint:disable=redefined-outer-name, protected-access - mock_network_layer = box_session._network_layer +def test_box_session_raises_for_failed_non_json_response(box_session, mock_network_layer, failed_non_json_response, test_url): + # pylint:disable=redefined-outer-name mock_network_layer.request.side_effect = [failed_non_json_response] with pytest.raises(BoxAPIException):