Skip to content

Commit

Permalink
Better connection error handling for remote loader (#2313)
Browse files Browse the repository at this point in the history
* fix: better connection error handling for remote loader

* fix: typing
  • Loading branch information
bramstroker committed Jun 23, 2024
1 parent dad68da commit 583f0ff
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 15 deletions.
37 changes: 24 additions & 13 deletions custom_components/powercalc/power_profile/loader/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
import logging
import os
import time
from collections.abc import Callable, Coroutine
from functools import partial
from typing import Any, cast

import aiohttp
from aiohttp import ClientError
from homeassistant.core import HomeAssistant
from homeassistant.helpers.storage import STORAGE_DIR

Expand Down Expand Up @@ -56,12 +59,19 @@ def _load_local_library_json() -> dict[str, Any]:
with open(get_library_json_path()) as f:
return cast(dict[str, Any], json.load(f))

_LOGGER.debug("Loading library.json from github")
async with aiohttp.ClientSession() as session, session.get(ENDPOINT_LIBRARY) as resp:
if resp.status != 200:
_LOGGER.error("Failed to download library.json from github, falling back to local copy")
return await self.hass.async_add_executor_job(_load_local_library_json) # type: ignore
return cast(dict[str, Any], await resp.json())
async def _download_remote_library_json() -> dict[str, Any] | None:
"""Download library.json from github"""
_LOGGER.debug("Loading library.json from github")
async with aiohttp.ClientSession() as session, session.get(ENDPOINT_LIBRARY) as resp:
if resp.status != 200:
raise ProfileDownloadError("Failed to download library.json, unexpected status code")
return cast(dict[str, Any], await resp.json())

try:
return cast(dict[str, Any], await self.download_with_retry(_download_remote_library_json))
except ProfileDownloadError:
_LOGGER.debug("Failed to download library.json, falling back to local copy")
return await self.hass.async_add_executor_job(_load_local_library_json) # type: ignore

async def get_manufacturer_listing(self, device_type: DeviceType | None) -> set[str]:
"""Get listing of available manufacturers."""
Expand Down Expand Up @@ -102,7 +112,8 @@ async def load_model(self, manufacturer: str, model: str, force_update: bool = F

if needs_update or force_update:
try:
await self.download_with_retry(manufacturer, model, storage_path)
callback = partial(self.download_profile, manufacturer, model, storage_path)
await self.download_with_retry(callback)
except ProfileDownloadError as e:
if not path_exists:
raise e
Expand Down Expand Up @@ -159,23 +170,23 @@ def _get_remote_modification_time(model_info: dict) -> float:
remote_modification_time = datetime.datetime.fromisoformat(remote_modification_time).timestamp()
return remote_modification_time # type: ignore

async def download_with_retry(self, manufacturer: str, model: str, storage_path: str) -> None:
async def download_with_retry(self, callback: Callable[[], Coroutine[Any, Any, None | dict[str, Any]]]) -> None | dict[str, Any]:
"""Download a file from a remote endpoint with retries"""
max_retries = 3
retry_count = 0

while retry_count < max_retries:
try:
await self.download_profile(manufacturer, model, storage_path)
break # Break out of the loop if download is successful
except ProfileDownloadError as e:
return await callback()
except (ClientError, ProfileDownloadError) as e:
_LOGGER.error(e, exc_info=e)
retry_count += 1
if retry_count == max_retries:
raise ProfileDownloadError(f"Failed to download profile even after {max_retries} retries, falling back to local profile") from e
raise ProfileDownloadError(f"Failed to download even after {max_retries} retries, falling back to local copy") from e

await asyncio.sleep(self.retry_timeout)
_LOGGER.warning("Failed to download profile, retrying... (Attempt %d of %d)", retry_count, max_retries)
_LOGGER.warning("Failed to download, retrying... (Attempt %d of %d)", retry_count, max_retries)
return None

async def download_profile(self, manufacturer: str, model: str, storage_path: str) -> None:
"""
Expand Down
36 changes: 34 additions & 2 deletions tests/power_profile/loader/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import shutil
import time
from functools import partial

import pytest
from aiohttp import ClientError
Expand Down Expand Up @@ -208,7 +209,8 @@ async def test_eventual_success_after_download_retry(mock_aioresponse: aiorespon
mock_aioresponse.get(remote_file["url"], status=500)
mock_aioresponse.get(remote_file["url"], status=200)

await remote_loader.download_with_retry(manufacturer, model, storage_path)
callback = partial(remote_loader.download_profile, manufacturer, model, storage_path)
await remote_loader.download_with_retry(callback)

assert os.path.exists(storage_path)

Expand Down Expand Up @@ -280,6 +282,10 @@ def _mock_library_json(profile_updated_at: str) -> None:


async def test_fallback_to_local_library(hass: HomeAssistant, mock_aioresponse: aioresponses, caplog: pytest.LogCaptureFixture) -> None:
"""
Test that the local library is used when the remote library is not available.
When unavailable, it should retry 3 times before falling back to the local library.
"""
caplog.set_level(logging.ERROR)
mock_aioresponse.get(
ENDPOINT_LIBRARY,
Expand All @@ -288,10 +294,36 @@ async def test_fallback_to_local_library(hass: HomeAssistant, mock_aioresponse:
)

loader = RemoteLoader(hass)
loader.retry_timeout = 0
await loader.initialize()

assert "signify" in loader.manufacturer_models
assert len(caplog.records) == 3


async def test_fallback_to_local_library_on_client_connection_error(
hass: HomeAssistant,
mock_aioresponse: aioresponses,
caplog: pytest.LogCaptureFixture,
) -> None:
"""
Test that the local library is used when powercalc.lauwbier.nl is not available.
See: https://github.com/bramstroker/homeassistant-powercalc/issues/2277
"""
caplog.set_level(logging.ERROR)
mock_aioresponse.get(
ENDPOINT_LIBRARY,
status=200,
repeat=True,
exception=ClientError("test"),
)

loader = RemoteLoader(hass)
loader.retry_timeout = 0
await loader.initialize()

assert "signify" in loader.manufacturer_models
assert len(caplog.records) == 1
assert len(caplog.records) == 3


async def test_fallback_to_local_profile(
Expand Down

0 comments on commit 583f0ff

Please sign in to comment.