From ce48e3f7d75955a36e2a853ce858be12083fc61c Mon Sep 17 00:00:00 2001 From: Evan Bruhn Date: Sun, 8 May 2022 21:22:58 +1000 Subject: [PATCH] Block requests to Logi API if token refresh fails with 4xx exception (#12) --- logi_circle/__init__.py | 13 ++++++++++--- logi_circle/auth.py | 39 ++++++++++++++++++++++++++++----------- logi_circle/exception.py | 4 ++++ pytest.ini | 2 ++ tests/test_auth.py | 23 ++++++++++++++++++++++- tests/test_init.py | 22 +++++++++++++++++++++- 6 files changed, 87 insertions(+), 16 deletions(-) create mode 100644 pytest.ini diff --git a/logi_circle/__init__.py b/logi_circle/__init__.py index 455a9bf..c924327 100644 --- a/logi_circle/__init__.py +++ b/logi_circle/__init__.py @@ -14,7 +14,7 @@ from .auth import AuthProvider from .camera import Camera from .subscription import Subscription -from .exception import NotAuthorized, AuthorizationFailed +from .exception import NotAuthorized, AuthorizationFailed, SessionInvalidated from .utils import _get_ids_for_cameras _LOGGER = logging.getLogger(__name__) @@ -122,6 +122,14 @@ def subscriptions(self): """Returns all WS subscriptions.""" return self._subscriptions + def _check_readiness(self): + """Checks that this library is ready to submit requests to the Logi Circle API""" + if not self.auth_provider.authorized: + raise NotAuthorized('No access token available for this client ID') + + if self.auth_provider.invalid: + raise SessionInvalidated('Logi API session invalidated due to 4xx exception refreshing token') + async def _fetch(self, url, method='GET', @@ -134,8 +142,7 @@ async def _fetch(self, """Query data from the Logi Circle API.""" # pylint: disable=too-many-locals - if not self.auth_provider.authorized: - raise NotAuthorized('No access token available for this client ID') + self._check_readiness() base_headers = { 'X-API-Key': self.api_key, diff --git a/logi_circle/auth.py b/logi_circle/auth.py index 8a3b7e6..7497f6f 100644 --- a/logi_circle/auth.py +++ b/logi_circle/auth.py @@ -8,7 +8,7 @@ import aiohttp from .const import AUTH_BASE, AUTH_ENDPOINT, TOKEN_ENDPOINT -from .exception import AuthorizationFailed, NotAuthorized +from .exception import AuthorizationFailed, NotAuthorized, SessionInvalidated _LOGGER = logging.getLogger(__name__) @@ -24,6 +24,7 @@ def __init__(self, client_id, client_secret, redirect_uri, scopes, cache_file, l self.cache_file = cache_file self.logi = logi_base self.tokens = self._read_token() + self.invalid = False self.session = None @property @@ -101,21 +102,37 @@ async def close(self): async def _authenticate(self, payload): """Request or refresh the access token with Logi Circle""" + if self.invalid: + raise SessionInvalidated('Logi API session invalidated due to 4xx exception refreshing token') session = await self.get_session() async with session.post(AUTH_BASE + TOKEN_ENDPOINT, data=payload) as req: - response = await req.json() - - if req.status >= 400: + try: + response = await req.json() + + if req.status >= 400: + self.logi.is_connected = False + if req.status >= 400 and req.status < 500: + self.invalid = True + + error_message = response.get( + "error_description", "Non-OK code %s returned" % (req.status)) + raise AuthorizationFailed(error_message) + + # Authorization succeeded. Persist the refresh and access tokens. + self.logi.is_connected = True + self.tokens[self.client_id] = response + self._save_token() + except aiohttp.ContentTypeError: + response = await req.text() self.logi.is_connected = False - error_message = response.get( - "error_description", "Non-OK code %s returned" % (req.status)) - raise AuthorizationFailed(error_message) + if req.status >= 400 and req.status < 500: + self.invalid = True - # Authorization succeeded. Persist the refresh and access tokens. - self.logi.is_connected = True - self.tokens[self.client_id] = response - self._save_token() + if req.status >= 400: + raise AuthorizationFailed("Non-OK code %s returned: %s" % (req.status, response)) + else: + raise AuthorizationFailed("Unexpected content type from Logi API: %s" % (response)) async def get_session(self): """Returns a aiohttp session, creating one if it doesn't already exist.""" diff --git a/logi_circle/exception.py b/logi_circle/exception.py index 943c5ba..ca17ee0 100644 --- a/logi_circle/exception.py +++ b/logi_circle/exception.py @@ -5,6 +5,10 @@ class AuthorizationFailed(Exception): """When authorization fails for any reason.""" +class SessionInvalidated(Exception): + """When authorization is attempted on an invalidated session.""" + + class NotAuthorized(Exception): """When supplied client ID has not been authorized.""" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..74c5ad3 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode=strict \ No newline at end of file diff --git a/tests/test_auth.py b/tests/test_auth.py index b471c60..7d59c1f 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -5,7 +5,7 @@ import aresponses from tests.test_base import LogiUnitTestBase from logi_circle.const import AUTH_HOST, TOKEN_ENDPOINT, DEFAULT_SCOPES -from logi_circle.exception import NotAuthorized, AuthorizationFailed +from logi_circle.exception import NotAuthorized, AuthorizationFailed, SessionInvalidated class TestAuth(LogiUnitTestBase): @@ -99,6 +99,27 @@ async def run_test(): self.loop.run_until_complete(run_test()) + def test_session_invalidation(self): + """Test session invalidation.""" + logi = self.logi + + async def run_test(): + async with aresponses.ResponsesMockServer(loop=self.loop) as arsps: + arsps.add(AUTH_HOST, TOKEN_ENDPOINT, 'post', + aresponses.Response(status=401, + text=self.fixtures['failed_authorization'], + headers={'content-type': 'application/json'})) + + # Mock authorization, and verify AuthProvider state + with self.assertRaises(AuthorizationFailed): + await logi.authorize('letmein') + + # Attempt authorisation again + with self.assertRaises(SessionInvalidated): + await logi.authorize('letmein') + + self.loop.run_until_complete(run_test()) + def test_token_persistence(self): """Test that token is loaded from the cache file implicitly.""" diff --git a/tests/test_init.py b/tests/test_init.py index d39ebda..051d01a 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -6,7 +6,7 @@ from tests.test_base import LogiUnitTestBase from logi_circle import LogiCircle from logi_circle.const import AUTH_HOST, TOKEN_ENDPOINT, API_HOST, ACCESSORIES_ENDPOINT, DEFAULT_FFMPEG_BIN -from logi_circle.exception import NotAuthorized, AuthorizationFailed +from logi_circle.exception import NotAuthorized, AuthorizationFailed, SessionInvalidated class TestAuth(LogiUnitTestBase): @@ -101,6 +101,26 @@ async def run_test(): await logi._fetch(url='/api') self.loop.run_until_complete(run_test()) + + def test_fetch_guard_invalid_session(self): + """Fetch should bail out if session was invalidated""" + logi = self.logi + logi.auth_provider = self.get_authorized_auth_provider() + + async def run_test(): + async with aresponses.ResponsesMockServer(loop=self.loop) as arsps: + arsps.add(API_HOST, '/api', 'get', + aresponses.Response(status=401)) + arsps.add(AUTH_HOST, TOKEN_ENDPOINT, 'post', + aresponses.Response(status=429, + text='too many requests')) + + with self.assertRaises(AuthorizationFailed): + await logi._fetch(url='/api') + with self.assertRaises(SessionInvalidated): + await logi._fetch(url='/api') + + self.loop.run_until_complete(run_test()) def test_fetch_raw(self): """Fetch should return ClientResponse object if raw parameter set"""