diff --git a/boxsdk/auth/oauth2.py b/boxsdk/auth/oauth2.py index 0d5203582..dc08ee6c8 100644 --- a/boxsdk/auth/oauth2.py +++ b/boxsdk/auth/oauth2.py @@ -23,6 +23,7 @@ def __init__( client_id, client_secret, store_tokens=None, + retrieve_tokens=None, box_device_id='0', box_device_name='', access_token=None, @@ -42,6 +43,10 @@ def __init__( Optional callback for getting access to tokens for storing them. :type store_tokens: `callable` + :param retrieve_tokens: + Optional callback for retrieving tokens prior to refresh. + :type retrieve_tokens: + `callable` :param box_device_id: Optional unique ID of this device. Used for applications that want to support device-pinning. :type box_device_id: @@ -66,6 +71,7 @@ def __init__( self._client_id = client_id self._client_secret = client_secret self._store_tokens = store_tokens + self._retrieve_tokens = retrieve_tokens self._access_token = access_token self._refresh_token = refresh_token self._network_layer = network_layer if network_layer else DefaultNetwork() @@ -171,6 +177,12 @@ def refresh(self, access_token_to_refresh): with self._refresh_lock: # The lock here is for handling that case that multiple requests fail, due to access token expired, at the # same time to avoid multiple session renewals. + if self._retrieve_tokens: + # If the token retrieving callback returns True-ish, reset the tokens to the result and + # proceed with the usual processing flow. + tokens = self._retrieve_tokens() + if tokens: + self._access_token, self._refresh_token = tokens if access_token_to_refresh == self._access_token: # If the active access token is the same as the token needs to be refreshed, we make the request to # refresh the token. diff --git a/test/unit/auth/test_oauth2.py b/test/unit/auth/test_oauth2.py index bc0c7b6ef..0024afbfd 100644 --- a/test/unit/auth/test_oauth2.py +++ b/test/unit/auth/test_oauth2.py @@ -274,3 +274,116 @@ def test_token_request_allows_missing_refresh_token(mock_network_layer): network_layer=mock_network_layer, ) oauth.send_token_request({}, access_token=None, expect_refresh_token=False) + + +def test_retrieve_tokens_false_refreshes_tokens( + mock_network_layer, + successful_token_response): + # pylint:disable=redefined-outer-name + fake_client_id = 'fake_client_id' + fake_client_secret = 'fake_client_secret' + fake_refresh_token = 'fake_refresh_token' + fake_access_token = 'fake_access_token' + data = { + 'grant_type': 'refresh_token', + 'refresh_token': fake_refresh_token, + 'client_id': fake_client_id, + 'client_secret': fake_client_secret, + 'box_device_id': '0', + } + + def retrieve_tokens(): + return False + + mock_network_layer.request.return_value = successful_token_response + oauth = OAuth2( + client_id=fake_client_id, + client_secret=fake_client_secret, + access_token=fake_access_token, + refresh_token=fake_refresh_token, + network_layer=mock_network_layer, + retrieve_tokens=retrieve_tokens, + ) + + access_token, refresh_token = oauth.refresh(fake_access_token) + assert access_token == successful_token_response.json()['access_token'] + assert refresh_token == successful_token_response.json()['refresh_token'] + mock_network_layer.request.assert_called_once_with( + 'POST', + '{0}/token'.format(API.OAUTH2_API_URL), + data=data, + headers={'content-type': 'application/x-www-form-urlencoded'}, + access_token=fake_access_token, + ) + + +def test_retrieve_tokens_same_refreshes_tokens( + mock_network_layer, + successful_token_response): + # pylint:disable=redefined-outer-name + fake_client_id = 'fake_client_id' + fake_client_secret = 'fake_client_secret' + fake_refresh_token = 'fake_refresh_token' + fake_access_token = 'fake_access_token' + different_refresh_token = 'different_refresh_token' + data = { + 'grant_type': 'refresh_token', + 'refresh_token': different_refresh_token, + 'client_id': fake_client_id, + 'client_secret': fake_client_secret, + 'box_device_id': '0', + } + + def retrieve_tokens(): + return fake_access_token, different_refresh_token + + mock_network_layer.request.return_value = successful_token_response + oauth = OAuth2( + client_id=fake_client_id, + client_secret=fake_client_secret, + access_token=fake_access_token, + refresh_token=fake_refresh_token, + network_layer=mock_network_layer, + retrieve_tokens=retrieve_tokens, + ) + + access_token, refresh_token = oauth.refresh(fake_access_token) + assert access_token == successful_token_response.json()['access_token'] + assert refresh_token == successful_token_response.json()['refresh_token'] + mock_network_layer.request.assert_called_once_with( + 'POST', + '{0}/token'.format(API.OAUTH2_API_URL), + data=data, + headers={'content-type': 'application/x-www-form-urlencoded'}, + access_token=fake_access_token, + ) + + +def test_retrieve_tokens_different_does_not_refresh_tokens( + mock_network_layer, + successful_token_response): + # pylint:disable=redefined-outer-name + fake_client_id = 'fake_client_id' + fake_client_secret = 'fake_client_secret' + fake_refresh_token = 'fake_refresh_token' + fake_access_token = 'fake_access_token' + different_access_token = 'different_access_token' + different_refresh_token = 'different_refresh_token' + + def retrieve_tokens(): + return different_access_token, different_refresh_token + + mock_network_layer.request.return_value = successful_token_response + oauth = OAuth2( + client_id=fake_client_id, + client_secret=fake_client_secret, + access_token=fake_access_token, + refresh_token=fake_refresh_token, + network_layer=mock_network_layer, + retrieve_tokens=retrieve_tokens, + ) + + access_token, refresh_token = oauth.refresh(fake_access_token) + assert access_token == different_access_token + assert refresh_token == different_refresh_token + assert not mock_network_layer.request.called