Skip to content

Commit

Permalink
Block requests to Logi API if token refresh fails with 4xx exception (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
evanjd committed May 8, 2022
1 parent 842a899 commit ce48e3f
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 16 deletions.
13 changes: 10 additions & 3 deletions logi_circle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .auth import AuthProvider
from .camera import Camera
from .subscription import Subscription
from .exception import NotAuthorized, AuthorizationFailed
from .exception import NotAuthorized, AuthorizationFailed, SessionInvalidated
from .utils import _get_ids_for_cameras

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -122,6 +122,14 @@ def subscriptions(self):
"""Returns all WS subscriptions."""
return self._subscriptions

def _check_readiness(self):
"""Checks that this library is ready to submit requests to the Logi Circle API"""
if not self.auth_provider.authorized:
raise NotAuthorized('No access token available for this client ID')

if self.auth_provider.invalid:
raise SessionInvalidated('Logi API session invalidated due to 4xx exception refreshing token')

async def _fetch(self,
url,
method='GET',
Expand All @@ -134,8 +142,7 @@ async def _fetch(self,
"""Query data from the Logi Circle API."""
# pylint: disable=too-many-locals

if not self.auth_provider.authorized:
raise NotAuthorized('No access token available for this client ID')
self._check_readiness()

base_headers = {
'X-API-Key': self.api_key,
Expand Down
39 changes: 28 additions & 11 deletions logi_circle/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import aiohttp

from .const import AUTH_BASE, AUTH_ENDPOINT, TOKEN_ENDPOINT
from .exception import AuthorizationFailed, NotAuthorized
from .exception import AuthorizationFailed, NotAuthorized, SessionInvalidated

_LOGGER = logging.getLogger(__name__)

Expand All @@ -24,6 +24,7 @@ def __init__(self, client_id, client_secret, redirect_uri, scopes, cache_file, l
self.cache_file = cache_file
self.logi = logi_base
self.tokens = self._read_token()
self.invalid = False
self.session = None

@property
Expand Down Expand Up @@ -101,21 +102,37 @@ async def close(self):

async def _authenticate(self, payload):
"""Request or refresh the access token with Logi Circle"""
if self.invalid:
raise SessionInvalidated('Logi API session invalidated due to 4xx exception refreshing token')

session = await self.get_session()
async with session.post(AUTH_BASE + TOKEN_ENDPOINT, data=payload) as req:
response = await req.json()

if req.status >= 400:
try:
response = await req.json()

if req.status >= 400:
self.logi.is_connected = False
if req.status >= 400 and req.status < 500:
self.invalid = True

error_message = response.get(
"error_description", "Non-OK code %s returned" % (req.status))
raise AuthorizationFailed(error_message)

# Authorization succeeded. Persist the refresh and access tokens.
self.logi.is_connected = True
self.tokens[self.client_id] = response
self._save_token()
except aiohttp.ContentTypeError:
response = await req.text()
self.logi.is_connected = False
error_message = response.get(
"error_description", "Non-OK code %s returned" % (req.status))
raise AuthorizationFailed(error_message)
if req.status >= 400 and req.status < 500:
self.invalid = True

# Authorization succeeded. Persist the refresh and access tokens.
self.logi.is_connected = True
self.tokens[self.client_id] = response
self._save_token()
if req.status >= 400:
raise AuthorizationFailed("Non-OK code %s returned: %s" % (req.status, response))
else:
raise AuthorizationFailed("Unexpected content type from Logi API: %s" % (response))

async def get_session(self):
"""Returns a aiohttp session, creating one if it doesn't already exist."""
Expand Down
4 changes: 4 additions & 0 deletions logi_circle/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ class AuthorizationFailed(Exception):
"""When authorization fails for any reason."""


class SessionInvalidated(Exception):
"""When authorization is attempted on an invalidated session."""


class NotAuthorized(Exception):
"""When supplied client ID has not been authorized."""

Expand Down
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
asyncio_mode=strict
23 changes: 22 additions & 1 deletion tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import aresponses
from tests.test_base import LogiUnitTestBase
from logi_circle.const import AUTH_HOST, TOKEN_ENDPOINT, DEFAULT_SCOPES
from logi_circle.exception import NotAuthorized, AuthorizationFailed
from logi_circle.exception import NotAuthorized, AuthorizationFailed, SessionInvalidated


class TestAuth(LogiUnitTestBase):
Expand Down Expand Up @@ -99,6 +99,27 @@ async def run_test():

self.loop.run_until_complete(run_test())

def test_session_invalidation(self):
"""Test session invalidation."""
logi = self.logi

async def run_test():
async with aresponses.ResponsesMockServer(loop=self.loop) as arsps:
arsps.add(AUTH_HOST, TOKEN_ENDPOINT, 'post',
aresponses.Response(status=401,
text=self.fixtures['failed_authorization'],
headers={'content-type': 'application/json'}))

# Mock authorization, and verify AuthProvider state
with self.assertRaises(AuthorizationFailed):
await logi.authorize('letmein')

# Attempt authorisation again
with self.assertRaises(SessionInvalidated):
await logi.authorize('letmein')

self.loop.run_until_complete(run_test())

def test_token_persistence(self):
"""Test that token is loaded from the cache file implicitly."""

Expand Down
22 changes: 21 additions & 1 deletion tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tests.test_base import LogiUnitTestBase
from logi_circle import LogiCircle
from logi_circle.const import AUTH_HOST, TOKEN_ENDPOINT, API_HOST, ACCESSORIES_ENDPOINT, DEFAULT_FFMPEG_BIN
from logi_circle.exception import NotAuthorized, AuthorizationFailed
from logi_circle.exception import NotAuthorized, AuthorizationFailed, SessionInvalidated


class TestAuth(LogiUnitTestBase):
Expand Down Expand Up @@ -101,6 +101,26 @@ async def run_test():
await logi._fetch(url='/api')

self.loop.run_until_complete(run_test())

def test_fetch_guard_invalid_session(self):
"""Fetch should bail out if session was invalidated"""
logi = self.logi
logi.auth_provider = self.get_authorized_auth_provider()

async def run_test():
async with aresponses.ResponsesMockServer(loop=self.loop) as arsps:
arsps.add(API_HOST, '/api', 'get',
aresponses.Response(status=401))
arsps.add(AUTH_HOST, TOKEN_ENDPOINT, 'post',
aresponses.Response(status=429,
text='too many requests'))

with self.assertRaises(AuthorizationFailed):
await logi._fetch(url='/api')
with self.assertRaises(SessionInvalidated):
await logi._fetch(url='/api')

self.loop.run_until_complete(run_test())

def test_fetch_raw(self):
"""Fetch should return ClientResponse object if raw parameter set"""
Expand Down

0 comments on commit ce48e3f

Please sign in to comment.