From 0ca577d7655d27b5aeb285ad29084edf73e33c55 Mon Sep 17 00:00:00 2001 From: Evan Bruhn Date: Sun, 8 May 2022 22:27:51 +1000 Subject: [PATCH] Synchronize requests to Logi authorization API (#14) --- logi_circle/auth.py | 62 +++++++++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/logi_circle/auth.py b/logi_circle/auth.py index 7497f6f..6892c4a 100644 --- a/logi_circle/auth.py +++ b/logi_circle/auth.py @@ -6,6 +6,7 @@ import pickle from urllib.parse import urlencode import aiohttp +import asyncio from .const import AUTH_BASE, AUTH_ENDPOINT, TOKEN_ENDPOINT from .exception import AuthorizationFailed, NotAuthorized, SessionInvalidated @@ -26,6 +27,7 @@ def __init__(self, client_id, client_secret, redirect_uri, scopes, cache_file, l self.tokens = self._read_token() self.invalid = False self.session = None + self._lock = asyncio.Lock() @property def authorized(self): @@ -85,6 +87,8 @@ async def refresh(self): "client_id": self.client_id, "client_secret": self.client_secret} + _LOGGER.debug("Refreshing access token for client %s", self.client_id) + await self._authenticate(refresh_payload) async def close(self): @@ -105,34 +109,44 @@ async def _authenticate(self, payload): if self.invalid: raise SessionInvalidated('Logi API session invalidated due to 4xx exception refreshing token') - session = await self.get_session() - async with session.post(AUTH_BASE + TOKEN_ENDPOINT, data=payload) as req: - try: - response = await req.json() - - if req.status >= 400: + if self._lock.locked(): + async with self._lock: + _LOGGER.debug("Concurrent request to authenticate client ID %s ignored", self.client_id) + return + + async with self._lock: + _LOGGER.debug("Authenticating client ID %s", self.client_id) + + session = await self.get_session() + async with session.post(AUTH_BASE + TOKEN_ENDPOINT, data=payload) as req: + try: + response = await req.json() + + if req.status >= 400: + self.logi.is_connected = False + if req.status >= 400 and req.status < 500: + self.invalid = True + + error_message = response.get( + "error_description", "Non-OK code %s returned" % (req.status)) + raise AuthorizationFailed(error_message) + + # Authorization succeeded. Persist the refresh and access tokens. + _LOGGER.debug("Successfully authenticated client ID %s", self.client_id) + self.logi.is_connected = True + self.invalid = False + self.tokens[self.client_id] = response + self._save_token() + except aiohttp.ContentTypeError: + response = await req.text() self.logi.is_connected = False if req.status >= 400 and req.status < 500: self.invalid = True - error_message = response.get( - "error_description", "Non-OK code %s returned" % (req.status)) - raise AuthorizationFailed(error_message) - - # Authorization succeeded. Persist the refresh and access tokens. - self.logi.is_connected = True - self.tokens[self.client_id] = response - self._save_token() - except aiohttp.ContentTypeError: - response = await req.text() - self.logi.is_connected = False - if req.status >= 400 and req.status < 500: - self.invalid = True - - if req.status >= 400: - raise AuthorizationFailed("Non-OK code %s returned: %s" % (req.status, response)) - else: - raise AuthorizationFailed("Unexpected content type from Logi API: %s" % (response)) + if req.status >= 400: + raise AuthorizationFailed("Non-OK code %s returned: %s" % (req.status, response)) + else: + raise AuthorizationFailed("Unexpected content type from Logi API: %s" % (response)) async def get_session(self): """Returns a aiohttp session, creating one if it doesn't already exist."""