From 3a6950d1866cec8396eca985d8235ecd887a242f Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 21 Aug 2019 15:11:59 -0700 Subject: [PATCH 1/9] Reload config entry when disabled_by updated in entity registry --- homeassistant/config_entries.py | 103 +++++++++++++++++++++-- homeassistant/helpers/entity_registry.py | 2 +- tests/helpers/test_entity_registry.py | 1 + tests/test_config_entries.py | 69 +++++++++++++++ 4 files changed, 165 insertions(+), 10 deletions(-) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 2e1fbea14d1392..e301ee388acaee 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -3,13 +3,7 @@ import logging import functools import uuid -from typing import ( - Any, - Callable, - List, - Optional, - Set, # noqa pylint: disable=unused-import -) +from typing import Any, Callable, List, Optional, Set import weakref import attr @@ -19,6 +13,7 @@ from homeassistant.exceptions import HomeAssistantError, ConfigEntryNotReady from homeassistant.setup import async_setup_component, async_process_deps_reqs from homeassistant.util.decorator import Registry +from homeassistant.helpers import entity_registry # mypy: allow-untyped-defs @@ -161,8 +156,6 @@ async def async_setup( try: component = integration.get_component() - if self.domain == integration.domain: - integration.get_platform("config_flow") except ImportError as err: _LOGGER.error( "Error importing integration %s to set up %s config entry: %s", @@ -174,6 +167,20 @@ async def async_setup( self.state = ENTRY_STATE_SETUP_ERROR return + if self.domain == integration.domain: + try: + integration.get_platform("config_flow") + except ImportError as err: + _LOGGER.error( + "Error importing platform config_flow from integration %s to set up %s config entry: %s", + integration.domain, + self.domain, + err, + ) + if self.domain == integration.domain: + self.state = ENTRY_STATE_SETUP_ERROR + return + # Perform migration if integration.domain == self.domain: if not await self.async_migrate(hass): @@ -383,6 +390,7 @@ def __init__(self, hass: HomeAssistant, hass_config: dict) -> None: self._hass_config = hass_config self._entries = [] # type: List[ConfigEntry] self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) + EntityRegistryDisabledHandler(hass) @callback def async_domains(self) -> List[str]: @@ -757,3 +765,80 @@ def update(self, *, disable_new_entities): def as_dict(self): """Return dictionary version of this config entrys system options.""" return {"disable_new_entities": self.disable_new_entities} + + +class EntityRegistryDisabledHandler: + """Handler to handle when entities related to config entries updating disabled_by.""" + + RELOAD_AFTER_UPDATE_DELAY = 30 + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the handler.""" + self.hass = hass + self.registry: Optional[entity_registry.EntityRegistry] = None + self.changed: Set[str] = set() + self._remove_call_later: Optional[Callable[[], None]] = None + + hass.bus.async_listen( + entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._handle_entry_updated + ) + + async def _handle_entry_updated(self, event): + """Handle entity registry entry update.""" + if ( + event.data["action"] != "update" + or "disabled_by" not in event.data["changes"] + ): + return + + if self.registry is None: + self.registry = await entity_registry.async_get_registry(self.hass) + + entity_entry = self.registry.async_get(event.data["entity_id"]) + + if entity_entry is None or entity_entry.config_entry_id is None: + return + + config_entry = self.hass.config_entries.async_get_entry( + entity_entry.config_entry_id + ) + + if config_entry.entry_id not in self.changed and await support_entry_unload( + self.hass, config_entry.domain + ): + self.changed.add(config_entry.entry_id) + + if not self.changed: + return + + # We are going to delay reloading on *every* entity registry change so that + # if a user is happily clicking along, it will only reload at the end. + + if self._remove_call_later: + self._remove_call_later() + + self._remove_call_later = self.hass.helpers.event.async_call_later( + self.RELOAD_AFTER_UPDATE_DELAY, self._handle_reload + ) + + async def _handle_reload(self, _now): + """Handle a reload.""" + self._remove_call_later = None + to_reload = self.changed + self.changed = set() + + _LOGGER.info( + "Reloading config entries because disabled_by changed in entity registry: %s", + ", ".join(self.changed), + ) + + await asyncio.gather( + *[self.hass.config_entries.async_reload(entry_id) for entry_id in to_reload] + ) + + +async def support_entry_unload(hass, domain) -> bool: + """Test if a domain supports entry unloading.""" + integration = await loader.async_get_integration(hass, domain) + component = integration.get_component() + return hasattr(component, "async_unload_entry") diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 3d84313a5c650d..7d81f62fa1c051 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -302,7 +302,7 @@ def _async_update_entity( self.async_schedule_save() - data = {"action": "update", "entity_id": entity_id} + data = {"action": "update", "entity_id": entity_id, "changes": list(changes)} if old.entity_id != entity_id: data["old_entity_id"] = old.entity_id diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index aee6b6f19a3965..9debbdbcba7cda 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -219,6 +219,7 @@ async def test_updating_config_entry_id(hass, registry, update_events): assert update_events[0]["entity_id"] == entry.entity_id assert update_events[1]["action"] == "update" assert update_events[1]["entity_id"] == entry.entity_id + assert update_events[1]["changes"] == ["config_entry_id"] async def test_removing_config_entry_id(hass, registry, update_events): diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index ca6872a7a2cc1e..f8e8bdd048ba51 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -20,6 +20,7 @@ MockEntity, mock_integration, mock_entity_platform, + mock_registry, ) @@ -925,3 +926,71 @@ async def test_init_custom_integration(hass): return_value=mock_coro(integration), ): await hass.config_entries.flow.async_init("bla") + + +async def test_support_entry_unload(hass): + """Test unloading entry.""" + assert await config_entries.support_entry_unload(hass, "light") + assert not await config_entries.support_entry_unload(hass, "auth") + + +async def test_reload_entry_entity_registry_ignores_no_entry(hass): + """Test reloading entry in entity registry skips if no config entry linked.""" + handler = config_entries.EntityRegistryDisabledHandler(hass) + registry = mock_registry(hass) + + # Test we ignore entities without config entry + entry = registry.async_get_or_create("light", "hue", "123") + registry.async_update_entity(entry.entity_id, disabled_by="user") + await hass.async_block_till_done() + assert not handler.changed + assert handler._remove_call_later is None + + +async def test_reload_entry_entity_registry_works(hass): + """Test we schedule an entry to be reloaded if disabled_by is updated.""" + handler = config_entries.EntityRegistryDisabledHandler(hass) + registry = mock_registry(hass) + + config_entry = MockConfigEntry( + domain="comp", state=config_entries.ENTRY_STATE_LOADED + ) + config_entry.add_to_hass(hass) + mock_setup_entry = MagicMock(return_value=mock_coro(True)) + mock_unload_entry = MagicMock(return_value=mock_coro(True)) + mock_integration( + hass, + MockModule( + "comp", + async_setup_entry=mock_setup_entry, + async_unload_entry=mock_unload_entry, + ), + ) + mock_entity_platform(hass, "config_flow.comp", None) + + # Only changing disabled_by should update trigger + entity_entry = registry.async_get_or_create( + "light", "hue", "123", config_entry=config_entry + ) + registry.async_update_entity(entity_entry.entity_id, name="yo") + await hass.async_block_till_done() + assert not handler.changed + assert handler._remove_call_later is None + + # Changed disabled_by, check unloading. + registry.async_update_entity(entity_entry.entity_id, disabled_by="user") + await hass.async_block_till_done() + assert handler.changed == {config_entry.entry_id} + assert handler._remove_call_later is not None + + async_fire_time_changed( + hass, + dt.utcnow() + + timedelta( + seconds=config_entries.EntityRegistryDisabledHandler.RELOAD_AFTER_UPDATE_DELAY + + 1 + ), + ) + await hass.async_block_till_done() + + assert len(mock_unload_entry.mock_calls) == 1 From d239aea4bb0e8d664310318b999059aeaf196c79 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 21 Aug 2019 16:17:30 -0700 Subject: [PATCH 2/9] Add types --- homeassistant/config_entries.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index e301ee388acaee..d384e5a6a3444f 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -837,7 +837,7 @@ async def _handle_reload(self, _now): ) -async def support_entry_unload(hass, domain) -> bool: +async def support_entry_unload(hass: HomeAssistant, domain: str) -> bool: """Test if a domain supports entry unloading.""" integration = await loader.async_get_integration(hass, domain) component = integration.get_component() From 0d3b2a5f9961eeaa1835cd037db144b18d771379 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 22 Aug 2019 09:20:56 -0700 Subject: [PATCH 3/9] Remove entities that get disabled --- homeassistant/config_entries.py | 10 +++++++++- homeassistant/helpers/entity.py | 4 ++++ tests/helpers/test_entity.py | 25 +++++++++++++++++++++++++ tests/test_config_entries.py | 8 +++++++- 4 files changed, 45 insertions(+), 2 deletions(-) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index d384e5a6a3444f..7098e9c934326f 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -796,7 +796,15 @@ async def _handle_entry_updated(self, event): entity_entry = self.registry.async_get(event.data["entity_id"]) - if entity_entry is None or entity_entry.config_entry_id is None: + if ( + # Stop if no entry found + entity_entry is None + # Stop if entry not connected to config entry + or entity_entry.config_entry_id is None + # Stop if the entry got disabled. In that case the entity handles it + # themselves. + or entity_entry.disabled_by + ): return config_entry = self.hass.config_entries.async_get_entry( diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 7de41415f080ec..bd96e1bafdb5f4 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -503,6 +503,10 @@ async def _async_registry_updated(self, event): old = self.registry_entry self.registry_entry = ent_reg.async_get(data["entity_id"]) + if self.registry_entry.disabled_by is not None: + await self.async_remove() + return + if self.registry_entry.entity_id == old.entity_id: self.async_write_ha_state() return diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index 94650592d8e1bb..011c30acacfc0d 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -526,3 +526,28 @@ async def test_warn_disabled(hass, caplog): ent.async_write_ha_state() assert hass.states.get("hello.world") is None assert caplog.text == "" + + +async def test_disabled_in_entity_registry(hass): + """Test entity is removed if we disable entity registry entry.""" + + registry = mock_registry(hass, {"hello.world": entry}) + + ent = entity.Entity() + ent.hass = hass + ent.entity_id = "hello.world" + ent.registry_entry = entry + + await ent.async_internal_added_to_hass() + ent.async_write_ha_state() + assert hass.states.get("hello.world") is not None + + entry2 = registry.async_update_entity("hello.world", disabled_by=None) + await hass.async_block_till_done() + assert hass.states.get("hello.world") is not None + assert entry2 != entry + assert ent.registry_entry == entry2 + + registry.async_update_entity("hello.world", disabled_by="user") + await hass.async_block_till_done() + assert hass.states.get("hello.world") is None diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index f8e8bdd048ba51..28b369d4bb0313 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -977,9 +977,15 @@ async def test_reload_entry_entity_registry_works(hass): assert not handler.changed assert handler._remove_call_later is None - # Changed disabled_by, check unloading. + # Disable entity, we should not do anything, only act when enabled. registry.async_update_entity(entity_entry.entity_id, disabled_by="user") await hass.async_block_till_done() + assert not handler.changed + assert handler._remove_call_later is None + + # Enable entity, check we are reloading config entry. + registry.async_update_entity(entity_entry.entity_id, disabled_by=None) + await hass.async_block_till_done() assert handler.changed == {config_entry.entry_id} assert handler._remove_call_later is not None From 2192c60a852258519911e9614e46da08e4908238 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 22 Aug 2019 09:25:50 -0700 Subject: [PATCH 4/9] Remove unnecessary domain checks. --- homeassistant/config_entries.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 7098e9c934326f..8262915c7b3783 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -177,12 +177,10 @@ async def async_setup( self.domain, err, ) - if self.domain == integration.domain: - self.state = ENTRY_STATE_SETUP_ERROR + self.state = ENTRY_STATE_SETUP_ERROR return - # Perform migration - if integration.domain == self.domain: + # Perform migration if not await self.async_migrate(hass): self.state = ENTRY_STATE_MIGRATION_ERROR return From d21b37d88a13199622ee84cce19a5fbd74640395 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 22 Aug 2019 09:29:45 -0700 Subject: [PATCH 5/9] Attach handler in async_setup --- homeassistant/config_entries.py | 7 +++++-- tests/test_config_entries.py | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 8262915c7b3783..c9efa17f4a17cd 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -388,7 +388,7 @@ def __init__(self, hass: HomeAssistant, hass_config: dict) -> None: self._hass_config = hass_config self._entries = [] # type: List[ConfigEntry] self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) - EntityRegistryDisabledHandler(hass) + EntityRegistryDisabledHandler(hass).async_setup() @callback def async_domains(self) -> List[str]: @@ -777,7 +777,10 @@ def __init__(self, hass: HomeAssistant) -> None: self.changed: Set[str] = set() self._remove_call_later: Optional[Callable[[], None]] = None - hass.bus.async_listen( + @callback + def async_setup(self): + """Set up the disable handler.""" + self.hass.bus.async_listen( entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._handle_entry_updated ) diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 28b369d4bb0313..d9dd614c9a5e4a 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -950,6 +950,7 @@ async def test_reload_entry_entity_registry_ignores_no_entry(hass): async def test_reload_entry_entity_registry_works(hass): """Test we schedule an entry to be reloaded if disabled_by is updated.""" handler = config_entries.EntityRegistryDisabledHandler(hass) + handler.async_setup() registry = mock_registry(hass) config_entry = MockConfigEntry( From a66bbc43fd3d66a53f8b2177a45ce3382a61e9bb Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 22 Aug 2019 09:31:48 -0700 Subject: [PATCH 6/9] Remove unused var --- homeassistant/helpers/entity.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index bd96e1bafdb5f4..d88eca0892c2e9 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -99,9 +99,6 @@ class Entity: # If we reported if this entity was slow _slow_reported = False - # If we reported this entity is updated while disabled - _disabled_reported = False - # Protect for multiple updates _update_staged = False From 3dd1f723bd9e863a37b41991229dd73ee6d214c7 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 22 Aug 2019 09:55:56 -0700 Subject: [PATCH 7/9] Type --- homeassistant/config_entries.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index c9efa17f4a17cd..c2da37943c1abb 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -778,7 +778,7 @@ def __init__(self, hass: HomeAssistant) -> None: self._remove_call_later: Optional[Callable[[], None]] = None @callback - def async_setup(self): + def async_setup(self) -> None: """Set up the disable handler.""" self.hass.bus.async_listen( entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._handle_entry_updated From af7cb1b994595b61808bd371e9c9bd6b0efbd341 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 22 Aug 2019 11:10:35 -0700 Subject: [PATCH 8/9] Fix test --- tests/components/config/test_entity_registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/components/config/test_entity_registry.py b/tests/components/config/test_entity_registry.py index 64328a0c8c5647..9472d8882540c9 100644 --- a/tests/components/config/test_entity_registry.py +++ b/tests/components/config/test_entity_registry.py @@ -163,6 +163,7 @@ async def test_update_entity(hass, client): msg = await client.receive_json() + assert hass.states.get("test_domain.world") is None assert registry.entities["test_domain.world"].disabled_by == "user" # UPDATE DISABLED_BY TO NONE From 8af8b55a2fbdd3bec9e984d7c08fccee2a2f85ec Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 22 Aug 2019 14:20:36 -0700 Subject: [PATCH 9/9] Fix tests --- homeassistant/helpers/entity.py | 3 +++ tests/helpers/test_entity.py | 16 +++++++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index d88eca0892c2e9..bd96e1bafdb5f4 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -99,6 +99,9 @@ class Entity: # If we reported if this entity was slow _slow_reported = False + # If we reported this entity is updated while disabled + _disabled_reported = False + # Protect for multiple updates _update_staged = False diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index 011c30acacfc0d..3c89a5c65379d6 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -530,24 +530,30 @@ async def test_warn_disabled(hass, caplog): async def test_disabled_in_entity_registry(hass): """Test entity is removed if we disable entity registry entry.""" - + entry = entity_registry.RegistryEntry( + entity_id="hello.world", + unique_id="test-unique-id", + platform="test-platform", + disabled_by="user", + ) registry = mock_registry(hass, {"hello.world": entry}) ent = entity.Entity() ent.hass = hass ent.entity_id = "hello.world" ent.registry_entry = entry + ent.platform = MagicMock(platform_name="test-platform") await ent.async_internal_added_to_hass() ent.async_write_ha_state() - assert hass.states.get("hello.world") is not None + assert hass.states.get("hello.world") is None entry2 = registry.async_update_entity("hello.world", disabled_by=None) await hass.async_block_till_done() - assert hass.states.get("hello.world") is not None assert entry2 != entry assert ent.registry_entry == entry2 - registry.async_update_entity("hello.world", disabled_by="user") + entry3 = registry.async_update_entity("hello.world", disabled_by="user") await hass.async_block_till_done() - assert hass.states.get("hello.world") is None + assert entry3 != entry2 + assert ent.registry_entry == entry3