Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reload config entry when entity enabled in entity registry, remove entity if disabled. #26120

Merged
merged 9 commits into from Aug 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
116 changes: 105 additions & 11 deletions homeassistant/config_entries.py
Expand Up @@ -3,13 +3,7 @@
import logging
import functools
import uuid
from typing import (
Any,
Callable,
List,
Optional,
Set, # noqa pylint: disable=unused-import
)
from typing import Any, Callable, List, Optional, Set
import weakref

import attr
Expand All @@ -19,6 +13,7 @@
from homeassistant.exceptions import HomeAssistantError, ConfigEntryNotReady
from homeassistant.setup import async_setup_component, async_process_deps_reqs
from homeassistant.util.decorator import Registry
from homeassistant.helpers import entity_registry

# mypy: allow-untyped-defs

Expand Down Expand Up @@ -161,8 +156,6 @@ async def async_setup(

try:
component = integration.get_component()
if self.domain == integration.domain:
integration.get_platform("config_flow")
except ImportError as err:
_LOGGER.error(
"Error importing integration %s to set up %s config entry: %s",
Expand All @@ -174,8 +167,20 @@ async def async_setup(
self.state = ENTRY_STATE_SETUP_ERROR
return

# Perform migration
if integration.domain == self.domain:
if self.domain == integration.domain:
try:
integration.get_platform("config_flow")
except ImportError as err:
_LOGGER.error(
"Error importing platform config_flow from integration %s to set up %s config entry: %s",
integration.domain,
self.domain,
err,
)
self.state = ENTRY_STATE_SETUP_ERROR
return

# Perform migration
if not await self.async_migrate(hass):
self.state = ENTRY_STATE_MIGRATION_ERROR
return
Expand Down Expand Up @@ -383,6 +388,7 @@ def __init__(self, hass: HomeAssistant, hass_config: dict) -> None:
self._hass_config = hass_config
self._entries = [] # type: List[ConfigEntry]
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
EntityRegistryDisabledHandler(hass).async_setup()

@callback
def async_domains(self) -> List[str]:
Expand Down Expand Up @@ -757,3 +763,91 @@ def update(self, *, disable_new_entities):
def as_dict(self):
"""Return dictionary version of this config entrys system options."""
return {"disable_new_entities": self.disable_new_entities}


class EntityRegistryDisabledHandler:
"""Handler to handle when entities related to config entries updating disabled_by."""

RELOAD_AFTER_UPDATE_DELAY = 30

def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the handler."""
self.hass = hass
self.registry: Optional[entity_registry.EntityRegistry] = None
self.changed: Set[str] = set()
self._remove_call_later: Optional[Callable[[], None]] = None

@callback
def async_setup(self) -> None:
"""Set up the disable handler."""
self.hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._handle_entry_updated
)

async def _handle_entry_updated(self, event):
"""Handle entity registry entry update."""
if (
event.data["action"] != "update"
or "disabled_by" not in event.data["changes"]
):
return

if self.registry is None:
self.registry = await entity_registry.async_get_registry(self.hass)

entity_entry = self.registry.async_get(event.data["entity_id"])

if (
# Stop if no entry found
entity_entry is None
# Stop if entry not connected to config entry
or entity_entry.config_entry_id is None
# Stop if the entry got disabled. In that case the entity handles it
# themselves.
or entity_entry.disabled_by
):
return

config_entry = self.hass.config_entries.async_get_entry(
entity_entry.config_entry_id
)

if config_entry.entry_id not in self.changed and await support_entry_unload(
self.hass, config_entry.domain
):
self.changed.add(config_entry.entry_id)

if not self.changed:
return

# We are going to delay reloading on *every* entity registry change so that
# if a user is happily clicking along, it will only reload at the end.

if self._remove_call_later:
self._remove_call_later()

self._remove_call_later = self.hass.helpers.event.async_call_later(
self.RELOAD_AFTER_UPDATE_DELAY, self._handle_reload
)

async def _handle_reload(self, _now):
"""Handle a reload."""
self._remove_call_later = None
to_reload = self.changed
self.changed = set()

_LOGGER.info(
"Reloading config entries because disabled_by changed in entity registry: %s",
", ".join(self.changed),
)

await asyncio.gather(
*[self.hass.config_entries.async_reload(entry_id) for entry_id in to_reload]
)


async def support_entry_unload(hass: HomeAssistant, domain: str) -> bool:
"""Test if a domain supports entry unloading."""
integration = await loader.async_get_integration(hass, domain)
component = integration.get_component()
return hasattr(component, "async_unload_entry")
4 changes: 4 additions & 0 deletions homeassistant/helpers/entity.py
Expand Up @@ -503,6 +503,10 @@ async def _async_registry_updated(self, event):
old = self.registry_entry
self.registry_entry = ent_reg.async_get(data["entity_id"])

if self.registry_entry.disabled_by is not None:
await self.async_remove()
return

if self.registry_entry.entity_id == old.entity_id:
self.async_write_ha_state()
return
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/helpers/entity_registry.py
Expand Up @@ -302,7 +302,7 @@ def _async_update_entity(

self.async_schedule_save()

data = {"action": "update", "entity_id": entity_id}
data = {"action": "update", "entity_id": entity_id, "changes": list(changes)}

if old.entity_id != entity_id:
data["old_entity_id"] = old.entity_id
Expand Down
1 change: 1 addition & 0 deletions tests/components/config/test_entity_registry.py
Expand Up @@ -163,6 +163,7 @@ async def test_update_entity(hass, client):

msg = await client.receive_json()

assert hass.states.get("test_domain.world") is None
assert registry.entities["test_domain.world"].disabled_by == "user"

# UPDATE DISABLED_BY TO NONE
Expand Down
31 changes: 31 additions & 0 deletions tests/helpers/test_entity.py
Expand Up @@ -526,3 +526,34 @@ async def test_warn_disabled(hass, caplog):
ent.async_write_ha_state()
assert hass.states.get("hello.world") is None
assert caplog.text == ""


async def test_disabled_in_entity_registry(hass):
"""Test entity is removed if we disable entity registry entry."""
entry = entity_registry.RegistryEntry(
entity_id="hello.world",
unique_id="test-unique-id",
platform="test-platform",
disabled_by="user",
)
registry = mock_registry(hass, {"hello.world": entry})

ent = entity.Entity()
ent.hass = hass
ent.entity_id = "hello.world"
ent.registry_entry = entry
ent.platform = MagicMock(platform_name="test-platform")

await ent.async_internal_added_to_hass()
ent.async_write_ha_state()
assert hass.states.get("hello.world") is None

entry2 = registry.async_update_entity("hello.world", disabled_by=None)
await hass.async_block_till_done()
assert entry2 != entry
assert ent.registry_entry == entry2

entry3 = registry.async_update_entity("hello.world", disabled_by="user")
await hass.async_block_till_done()
assert entry3 != entry2
assert ent.registry_entry == entry3
1 change: 1 addition & 0 deletions tests/helpers/test_entity_registry.py
Expand Up @@ -219,6 +219,7 @@ async def test_updating_config_entry_id(hass, registry, update_events):
assert update_events[0]["entity_id"] == entry.entity_id
assert update_events[1]["action"] == "update"
assert update_events[1]["entity_id"] == entry.entity_id
assert update_events[1]["changes"] == ["config_entry_id"]


async def test_removing_config_entry_id(hass, registry, update_events):
Expand Down
76 changes: 76 additions & 0 deletions tests/test_config_entries.py
Expand Up @@ -20,6 +20,7 @@
MockEntity,
mock_integration,
mock_entity_platform,
mock_registry,
)


Expand Down Expand Up @@ -925,3 +926,78 @@ async def test_init_custom_integration(hass):
return_value=mock_coro(integration),
):
await hass.config_entries.flow.async_init("bla")


async def test_support_entry_unload(hass):
"""Test unloading entry."""
assert await config_entries.support_entry_unload(hass, "light")
assert not await config_entries.support_entry_unload(hass, "auth")


async def test_reload_entry_entity_registry_ignores_no_entry(hass):
"""Test reloading entry in entity registry skips if no config entry linked."""
handler = config_entries.EntityRegistryDisabledHandler(hass)
registry = mock_registry(hass)

# Test we ignore entities without config entry
entry = registry.async_get_or_create("light", "hue", "123")
registry.async_update_entity(entry.entity_id, disabled_by="user")
await hass.async_block_till_done()
assert not handler.changed
assert handler._remove_call_later is None


async def test_reload_entry_entity_registry_works(hass):
"""Test we schedule an entry to be reloaded if disabled_by is updated."""
handler = config_entries.EntityRegistryDisabledHandler(hass)
handler.async_setup()
registry = mock_registry(hass)

config_entry = MockConfigEntry(
domain="comp", state=config_entries.ENTRY_STATE_LOADED
)
config_entry.add_to_hass(hass)
mock_setup_entry = MagicMock(return_value=mock_coro(True))
mock_unload_entry = MagicMock(return_value=mock_coro(True))
mock_integration(
hass,
MockModule(
"comp",
async_setup_entry=mock_setup_entry,
async_unload_entry=mock_unload_entry,
),
)
mock_entity_platform(hass, "config_flow.comp", None)

# Only changing disabled_by should update trigger
entity_entry = registry.async_get_or_create(
"light", "hue", "123", config_entry=config_entry
)
registry.async_update_entity(entity_entry.entity_id, name="yo")
await hass.async_block_till_done()
assert not handler.changed
assert handler._remove_call_later is None

# Disable entity, we should not do anything, only act when enabled.
registry.async_update_entity(entity_entry.entity_id, disabled_by="user")
await hass.async_block_till_done()
assert not handler.changed
assert handler._remove_call_later is None

# Enable entity, check we are reloading config entry.
registry.async_update_entity(entity_entry.entity_id, disabled_by=None)
await hass.async_block_till_done()
assert handler.changed == {config_entry.entry_id}
assert handler._remove_call_later is not None

async_fire_time_changed(
hass,
dt.utcnow()
+ timedelta(
seconds=config_entries.EntityRegistryDisabledHandler.RELOAD_AFTER_UPDATE_DELAY
+ 1
),
)
await hass.async_block_till_done()

assert len(mock_unload_entry.mock_calls) == 1