Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate restore_state helper to use registry loading pattern #93773

Merged
merged 13 commits into from
May 31, 2023
2 changes: 2 additions & 0 deletions homeassistant/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
entity_registry,
issue_registry,
recorder,
restore_state,
template,
)
from .helpers.dispatcher import async_dispatcher_send
Expand Down Expand Up @@ -248,6 +249,7 @@ def _cache_uname_processor() -> None:
issue_registry.async_load(hass),
hass.async_add_executor_job(_cache_uname_processor),
template.async_load_custom_templates(hass),
restore_state.async_load(hass),
)


Expand Down
84 changes: 44 additions & 40 deletions homeassistant/helpers/restore_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from __future__ import annotations

from abc import ABC, abstractmethod
import asyncio
from datetime import datetime, timedelta
import logging
from typing import Any, cast
Expand All @@ -18,10 +17,9 @@
from .entity import Entity
from .event import async_track_time_interval
from .json import JSONEncoder
from .singleton import singleton
from .storage import Store

DATA_RESTORE_STATE_TASK = "restore_state_task"
DATA_RESTORE_STATE = "restore_state"

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -96,31 +94,25 @@ def from_dict(cls, json_dict: dict) -> Self:
)


async def async_load(hass: HomeAssistant) -> None:
"""Load the restore state task."""
hass.data[DATA_RESTORE_STATE] = await RestoreStateData.async_get_instance(hass)


@callback
def async_get(hass: HomeAssistant) -> RestoreStateData:
"""Get the restore state data helper."""
return cast(RestoreStateData, hass.data[DATA_RESTORE_STATE])


class RestoreStateData:
"""Helper class for managing the helper saved data."""

@staticmethod
@singleton(DATA_RESTORE_STATE_TASK)
async def async_get_instance(hass: HomeAssistant) -> RestoreStateData:
"""Get the singleton instance of this data helper."""
"""Get the instance of this data helper."""
data = RestoreStateData(hass)

try:
stored_states = await data.store.async_load()
except HomeAssistantError as exc:
_LOGGER.error("Error loading last states", exc_info=exc)
stored_states = None

if stored_states is None:
_LOGGER.debug("Not creating cache - no saved states found")
data.last_states = {}
else:
data.last_states = {
item["state"]["entity_id"]: StoredState.from_dict(item)
for item in stored_states
if valid_entity_id(item["state"]["entity_id"])
}
_LOGGER.debug("Created cache with %s", list(data.last_states))
await data.async_load()

async def hass_start(hass: HomeAssistant) -> None:
"""Start the restore state task."""
Expand All @@ -133,8 +125,7 @@ async def hass_start(hass: HomeAssistant) -> None:
@classmethod
async def async_save_persistent_states(cls, hass: HomeAssistant) -> None:
"""Dump states now."""
data = await cls.async_get_instance(hass)
await data.async_dump_states()
await async_get(hass).async_dump_states()

def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the restore state data class."""
Expand All @@ -145,6 +136,25 @@ def __init__(self, hass: HomeAssistant) -> None:
self.last_states: dict[str, StoredState] = {}
self.entities: dict[str, RestoreEntity] = {}

async def async_load(self) -> None:
"""Load the instance of this data helper."""
try:
stored_states = await self.store.async_load()
except HomeAssistantError as exc:
_LOGGER.error("Error loading last states", exc_info=exc)
stored_states = None

if stored_states is None:
_LOGGER.debug("Not creating cache - no saved states found")
self.last_states = {}
else:
self.last_states = {
item["state"]["entity_id"]: StoredState.from_dict(item)
for item in stored_states
if valid_entity_id(item["state"]["entity_id"])
}
_LOGGER.debug("Created cache with %s", list(self.last_states))

@callback
def async_get_stored_states(self) -> list[StoredState]:
"""Get the set of states which should be stored.
Expand Down Expand Up @@ -288,42 +298,36 @@ class RestoreEntity(Entity):

async def async_internal_added_to_hass(self) -> None:
"""Register this entity as a restorable entity."""
_, data = await asyncio.gather(
super().async_internal_added_to_hass(),
RestoreStateData.async_get_instance(self.hass),
)
data.async_restore_entity_added(self)
await super().async_internal_added_to_hass()
async_get(self.hass).async_restore_entity_added(self)

async def async_internal_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
_, data = await asyncio.gather(
super().async_internal_will_remove_from_hass(),
RestoreStateData.async_get_instance(self.hass),
async_get(self.hass).async_restore_entity_removed(
self.entity_id, self.extra_restore_state_data
)
data.async_restore_entity_removed(self.entity_id, self.extra_restore_state_data)
await super().async_internal_will_remove_from_hass()

async def _async_get_restored_data(self) -> StoredState | None:
@callback
def _async_get_restored_data(self) -> StoredState | None:
"""Get data stored for an entity, if any."""
if self.hass is None or self.entity_id is None:
# Return None if this entity isn't added to hass yet
_LOGGER.warning( # type: ignore[unreachable]
"Cannot get last state. Entity not added to hass"
)
return None
data = await RestoreStateData.async_get_instance(self.hass)
if self.entity_id not in data.last_states:
return None
return data.last_states[self.entity_id]
return async_get(self.hass).last_states.get(self.entity_id)

async def async_get_last_state(self) -> State | None:
"""Get the entity state from the previous run."""
if (stored_state := await self._async_get_restored_data()) is None:
if (stored_state := self._async_get_restored_data()) is None:
return None
return stored_state.state

async def async_get_last_extra_data(self) -> ExtraStoredData | None:
"""Get the entity specific state data from the previous run."""
if (stored_state := await self._async_get_restored_data()) is None:
if (stored_state := self._async_get_restored_data()) is None:
return None
return stored_state.extra_data

Expand Down
35 changes: 32 additions & 3 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
issue_registry as ir,
recorder as recorder_helper,
restore_state,
restore_state as rs,
storage,
)
from homeassistant.helpers.dispatcher import async_dispatcher_connect
Expand Down Expand Up @@ -251,12 +252,20 @@ def async_create_task(coroutine, name=None):
# Load the registries
entity.async_setup(hass)
if load_registries:
with patch("homeassistant.helpers.storage.Store.async_load", return_value=None):
with patch(
"homeassistant.helpers.storage.Store.async_load", return_value=None
), patch(
"homeassistant.helpers.restore_state.RestoreStateData.async_setup_dump",
return_value=None,
), patch(
"homeassistant.helpers.restore_state.start.async_at_start"
):
await asyncio.gather(
ar.async_load(hass),
dr.async_load(hass),
er.async_load(hass),
ir.async_load(hass),
rs.async_load(hass),
)
hass.data[bootstrap.DATA_REGISTRIES_LOADED] = None

Expand Down Expand Up @@ -1010,7 +1019,7 @@ def init_recorder_component(hass, add_config=None, db_url="sqlite://"):

def mock_restore_cache(hass: HomeAssistant, states: Sequence[State]) -> None:
"""Mock the DATA_RESTORE_CACHE."""
key = restore_state.DATA_RESTORE_STATE_TASK
key = restore_state.DATA_RESTORE_STATE
data = restore_state.RestoreStateData(hass)
now = date_util.utcnow()

Expand All @@ -1037,7 +1046,7 @@ def mock_restore_cache_with_extra_data(
hass: HomeAssistant, states: Sequence[tuple[State, Mapping[str, Any]]]
) -> None:
"""Mock the DATA_RESTORE_CACHE."""
key = restore_state.DATA_RESTORE_STATE_TASK
key = restore_state.DATA_RESTORE_STATE
data = restore_state.RestoreStateData(hass)
now = date_util.utcnow()

Expand All @@ -1060,6 +1069,26 @@ def mock_restore_cache_with_extra_data(
hass.data[key] = data


async def async_mock_restore_state_shutdown_restart(
hass: HomeAssistant,
) -> restore_state.RestoreStateData:
"""Mock shutting down and saving restore state and restoring."""
data = restore_state.async_get(hass)
await data.async_dump_states()
await async_mock_load_restore_state_from_storage(hass)
return data


async def async_mock_load_restore_state_from_storage(
hass: HomeAssistant,
) -> None:
"""Mock loading restore state from storage.

hass_storage must already be mocked.
"""
await restore_state.async_get(hass).async_load()


class MockEntity(entity.Entity):
"""Mock Entity class."""

Expand Down
7 changes: 5 additions & 2 deletions tests/components/number/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@
from homeassistant.setup import async_setup_component
from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM

from tests.common import mock_restore_cache_with_extra_data
from tests.common import (
async_mock_restore_state_shutdown_restart,
mock_restore_cache_with_extra_data,
)


class MockDefaultNumberEntity(NumberEntity):
Expand Down Expand Up @@ -635,7 +638,7 @@ async def test_restore_number_save_state(
await hass.async_block_till_done()

# Trigger saving state
await hass.async_stop()
await async_mock_restore_state_shutdown_restart(hass)

assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1
state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"]
Expand Down
7 changes: 5 additions & 2 deletions tests/components/sensor/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
from homeassistant.util import dt as dt_util
from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM

from tests.common import mock_restore_cache_with_extra_data
from tests.common import (
async_mock_restore_state_shutdown_restart,
mock_restore_cache_with_extra_data,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -397,7 +400,7 @@ async def test_restore_sensor_save_state(
await hass.async_block_till_done()

# Trigger saving state
await hass.async_stop()
await async_mock_restore_state_shutdown_restart(hass)

assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1
state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"]
Expand Down
7 changes: 5 additions & 2 deletions tests/components/text/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY
from homeassistant.setup import async_setup_component

from tests.common import mock_restore_cache_with_extra_data
from tests.common import (
async_mock_restore_state_shutdown_restart,
mock_restore_cache_with_extra_data,
)


class MockTextEntity(TextEntity):
Expand Down Expand Up @@ -141,7 +144,7 @@ async def test_restore_number_save_state(
await hass.async_block_till_done()

# Trigger saving state
await hass.async_stop()
await async_mock_restore_state_shutdown_restart(hass)

assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1
state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"]
Expand Down
34 changes: 9 additions & 25 deletions tests/components/timer/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,7 @@
from homeassistant.core import Context, CoreState, HomeAssistant, State
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.restore_state import (
DATA_RESTORE_STATE_TASK,
RestoreStateData,
StoredState,
)
from homeassistant.helpers.restore_state import StoredState, async_get
from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow

Expand Down Expand Up @@ -838,12 +834,9 @@ async def test_restore_idle(hass: HomeAssistant) -> None:
utc_now,
)

data = await RestoreStateData.async_get_instance(hass)
await hass.async_block_till_done()
data = async_get(hass)
await data.store.async_save([stored_state.as_dict()])

# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
await data.async_load()

entity = Timer.from_storage(
{
Expand Down Expand Up @@ -878,12 +871,9 @@ async def test_restore_paused(hass: HomeAssistant) -> None:
utc_now,
)

data = await RestoreStateData.async_get_instance(hass)
await hass.async_block_till_done()
data = async_get(hass)
await data.store.async_save([stored_state.as_dict()])

# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
await data.async_load()

entity = Timer.from_storage(
{
Expand Down Expand Up @@ -922,12 +912,9 @@ async def test_restore_active_resume(hass: HomeAssistant) -> None:
utc_now,
)

data = await RestoreStateData.async_get_instance(hass)
await hass.async_block_till_done()
data = async_get(hass)
await data.store.async_save([stored_state.as_dict()])

# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
await data.async_load()

entity = Timer.from_storage(
{
Expand Down Expand Up @@ -973,12 +960,9 @@ async def test_restore_active_finished_outside_grace(hass: HomeAssistant) -> Non
utc_now,
)

data = await RestoreStateData.async_get_instance(hass)
await hass.async_block_till_done()
data = async_get(hass)
await data.store.async_save([stored_state.as_dict()])

# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
await data.async_load()

entity = Timer.from_storage(
{
Expand Down