Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions boxsdk/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
client_id,
client_secret,
store_tokens=None,
retrieve_tokens=None,
box_device_id='0',
box_device_name='',
access_token=None,
Expand All @@ -42,6 +43,10 @@ def __init__(
Optional callback for getting access to tokens for storing them.
:type store_tokens:
`callable`
:param retrieve_tokens:
Optional callback for retrieving tokens prior to refresh.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's expand on this slightly.

We should specify that the callback takes no arguments, and that it should either return None (if no credentials are currently stored), or a (access token, refresh token) tuple.

And maybe we should add a brief description of what this feature would be used for. Something like "When using multiple processes or multiple separate OAuth2 objects, provide a retrieve_tokens function so that the same credentials can be reused across all your OAuth2 objects. The first object that needs to authenticate will do so, and store the tokens with the store_tokens callback (or the authentication and storage might happen entirely outside of the SDK). All subsequent objects, rather than re-authenticating, will use the tokens provided by retrieve_tokens. And the same behavior will occur when they all need to do token refresh." Feel free to reword or make that more concise.

:type retrieve_tokens:
`callable`
:param box_device_id:
Optional unique ID of this device. Used for applications that want to support device-pinning.
:type box_device_id:
Expand All @@ -66,6 +71,7 @@ def __init__(
self._client_id = client_id
self._client_secret = client_secret
self._store_tokens = store_tokens
self._retrieve_tokens = retrieve_tokens
self._access_token = access_token
self._refresh_token = refresh_token
self._network_layer = network_layer if network_layer else DefaultNetwork()
Expand Down Expand Up @@ -171,6 +177,12 @@ def refresh(self, access_token_to_refresh):
with self._refresh_lock:
# The lock here is for handling that case that multiple requests fail, due to access token expired, at the
# same time to avoid multiple session renewals.
if self._retrieve_tokens:
# If the token retrieving callback returns True-ish, reset the tokens to the result and
# proceed with the usual processing flow.
tokens = self._retrieve_tokens()
if tokens:
self._access_token, self._refresh_token = tokens
if access_token_to_refresh == self._access_token:
# If the active access token is the same as the token needs to be refreshed, we make the request to
# refresh the token.
Expand Down
113 changes: 113 additions & 0 deletions test/unit/auth/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,116 @@ def test_token_request_allows_missing_refresh_token(mock_network_layer):
network_layer=mock_network_layer,
)
oauth.send_token_request({}, access_token=None, expect_refresh_token=False)


def test_retrieve_tokens_false_refreshes_tokens(
mock_network_layer,
successful_token_response):
# pylint:disable=redefined-outer-name
fake_client_id = 'fake_client_id'
fake_client_secret = 'fake_client_secret'
fake_refresh_token = 'fake_refresh_token'
fake_access_token = 'fake_access_token'
data = {
'grant_type': 'refresh_token',
'refresh_token': fake_refresh_token,
'client_id': fake_client_id,
'client_secret': fake_client_secret,
'box_device_id': '0',
}

def retrieve_tokens():
return False

mock_network_layer.request.return_value = successful_token_response
oauth = OAuth2(
client_id=fake_client_id,
client_secret=fake_client_secret,
access_token=fake_access_token,
refresh_token=fake_refresh_token,
network_layer=mock_network_layer,
retrieve_tokens=retrieve_tokens,
)

access_token, refresh_token = oauth.refresh(fake_access_token)
assert access_token == successful_token_response.json()['access_token']
assert refresh_token == successful_token_response.json()['refresh_token']
mock_network_layer.request.assert_called_once_with(
'POST',
'{0}/token'.format(API.OAUTH2_API_URL),
data=data,
headers={'content-type': 'application/x-www-form-urlencoded'},
access_token=fake_access_token,
)


def test_retrieve_tokens_same_refreshes_tokens(
mock_network_layer,
successful_token_response):
# pylint:disable=redefined-outer-name
fake_client_id = 'fake_client_id'
fake_client_secret = 'fake_client_secret'
fake_refresh_token = 'fake_refresh_token'
fake_access_token = 'fake_access_token'
different_refresh_token = 'different_refresh_token'
data = {
'grant_type': 'refresh_token',
'refresh_token': different_refresh_token,
'client_id': fake_client_id,
'client_secret': fake_client_secret,
'box_device_id': '0',
}

def retrieve_tokens():
return fake_access_token, different_refresh_token

mock_network_layer.request.return_value = successful_token_response
oauth = OAuth2(
client_id=fake_client_id,
client_secret=fake_client_secret,
access_token=fake_access_token,
refresh_token=fake_refresh_token,
network_layer=mock_network_layer,
retrieve_tokens=retrieve_tokens,
)

access_token, refresh_token = oauth.refresh(fake_access_token)
assert access_token == successful_token_response.json()['access_token']
assert refresh_token == successful_token_response.json()['refresh_token']
mock_network_layer.request.assert_called_once_with(
'POST',
'{0}/token'.format(API.OAUTH2_API_URL),
data=data,
headers={'content-type': 'application/x-www-form-urlencoded'},
access_token=fake_access_token,
)


def test_retrieve_tokens_different_does_not_refresh_tokens(
mock_network_layer,
successful_token_response):
# pylint:disable=redefined-outer-name
fake_client_id = 'fake_client_id'
fake_client_secret = 'fake_client_secret'
fake_refresh_token = 'fake_refresh_token'
fake_access_token = 'fake_access_token'
different_access_token = 'different_access_token'
different_refresh_token = 'different_refresh_token'

def retrieve_tokens():
return different_access_token, different_refresh_token

mock_network_layer.request.return_value = successful_token_response
oauth = OAuth2(
client_id=fake_client_id,
client_secret=fake_client_secret,
access_token=fake_access_token,
refresh_token=fake_refresh_token,
network_layer=mock_network_layer,
retrieve_tokens=retrieve_tokens,
)

access_token, refresh_token = oauth.refresh(fake_access_token)
assert access_token == different_access_token
assert refresh_token == different_refresh_token
assert not mock_network_layer.request.called