Skip to content

Commit

Permalink
Add async_track_state_added_domain for tracking when states are added…
Browse files Browse the repository at this point in the history
… to a domain (#38776)

* Fire event_state_added when a state is added after start

* async_track_state_added_domain

* test

* naming

* coverage
  • Loading branch information
bdraco committed Aug 12, 2020
1 parent 716fa63 commit 45526f4
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 12 deletions.
91 changes: 79 additions & 12 deletions homeassistant/helpers/event.py
Expand Up @@ -17,7 +17,14 @@
SUN_EVENT_SUNRISE,
SUN_EVENT_SUNSET,
)
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, State, callback
from homeassistant.core import (
CALLBACK_TYPE,
Event,
HomeAssistant,
State,
callback,
split_entity_id,
)
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
from homeassistant.helpers.sun import get_astral_event_next
from homeassistant.helpers.template import Template
Expand All @@ -28,6 +35,9 @@
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener"

TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks"
TRACK_STATE_ADDED_DOMAIN_LISTENER = "track_state_added_domain_listener"

TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener"

Expand Down Expand Up @@ -191,7 +201,7 @@ def _async_state_change_dispatcher(event: Event) -> None:
@callback
def remove_listener() -> None:
"""Remove state change listener."""
_async_remove_entity_listeners(
_async_remove_indexed_listeners(
hass,
TRACK_STATE_CHANGE_CALLBACKS,
TRACK_STATE_CHANGE_LISTENER,
Expand All @@ -203,23 +213,23 @@ def remove_listener() -> None:


@callback
def _async_remove_entity_listeners(
def _async_remove_indexed_listeners(
hass: HomeAssistant,
storage_key: str,
data_key: str,
listener_key: str,
entity_ids: Iterable[str],
storage_keys: Iterable[str],
action: Callable[[Event], Any],
) -> None:
"""Remove a listener."""

entity_callbacks = hass.data[storage_key]
callbacks = hass.data[data_key]

for entity_id in entity_ids:
entity_callbacks[entity_id].remove(action)
if len(entity_callbacks[entity_id]) == 0:
del entity_callbacks[entity_id]
for storage_key in storage_keys:
callbacks[storage_key].remove(action)
if len(callbacks[storage_key]) == 0:
del callbacks[storage_key]

if not entity_callbacks:
if not callbacks:
hass.data[listener_key]()
del hass.data[listener_key]

Expand Down Expand Up @@ -271,7 +281,7 @@ def _async_entity_registry_updated_dispatcher(event: Event) -> None:
@callback
def remove_listener() -> None:
"""Remove state change listener."""
_async_remove_entity_listeners(
_async_remove_indexed_listeners(
hass,
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS,
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER,
Expand All @@ -282,6 +292,63 @@ def remove_listener() -> None:
return remove_listener


@bind_hass
def async_track_state_added_domain(
hass: HomeAssistant,
domains: Union[str, Iterable[str]],
action: Callable[[Event], Any],
) -> Callable[[], None]:
"""Track state change events when an entity is added to domains."""

domain_callbacks = hass.data.setdefault(TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {})

if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data:

@callback
def _async_state_change_dispatcher(event: Event) -> None:
"""Dispatch state changes by entity_id."""
if event.data.get("old_state") is not None:
return

domain = split_entity_id(event.data["entity_id"])[0]

if domain not in domain_callbacks:
return

for action in domain_callbacks[domain][:]:
try:
hass.async_run_job(action, event)
except Exception: # pylint: disable=broad-except
_LOGGER.exception(
"Error while processing state added for %s", domain
)

hass.data[TRACK_STATE_ADDED_DOMAIN_LISTENER] = hass.bus.async_listen(
EVENT_STATE_CHANGED, _async_state_change_dispatcher
)

if isinstance(domains, str):
domains = [domains]

domains = [domains.lower() for domains in domains]

for domain in domains:
domain_callbacks.setdefault(domain, []).append(action)

@callback
def remove_listener() -> None:
"""Remove state change listener."""
_async_remove_indexed_listeners(
hass,
TRACK_STATE_ADDED_DOMAIN_CALLBACKS,
TRACK_STATE_ADDED_DOMAIN_LISTENER,
domains,
action,
)

return remove_listener


@callback
@bind_hass
def async_track_template(
Expand Down
83 changes: 83 additions & 0 deletions tests/helpers/test_event.py
Expand Up @@ -16,6 +16,7 @@
async_track_point_in_time,
async_track_point_in_utc_time,
async_track_same_state,
async_track_state_added_domain,
async_track_state_change,
async_track_state_change_event,
async_track_sunrise,
Expand Down Expand Up @@ -341,6 +342,88 @@ def callback_that_throws(event):
unsub_throws()


async def test_async_track_state_added_domain(hass):
"""Test async_track_state_added_domain."""
single_entity_id_tracker = []
multiple_entity_id_tracker = []

@ha.callback
def single_run_callback(event):
old_state = event.data.get("old_state")
new_state = event.data.get("new_state")

single_entity_id_tracker.append((old_state, new_state))

@ha.callback
def multiple_run_callback(event):
old_state = event.data.get("old_state")
new_state = event.data.get("new_state")

multiple_entity_id_tracker.append((old_state, new_state))

@ha.callback
def callback_that_throws(event):
raise ValueError

unsub_single = async_track_state_added_domain(hass, "light", single_run_callback)
unsub_multi = async_track_state_added_domain(
hass, ["light", "switch"], multiple_run_callback
)
unsub_throws = async_track_state_added_domain(
hass, ["light", "switch"], callback_that_throws
)

# Adding state to state machine
hass.states.async_set("light.Bowl", "on")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert single_entity_id_tracker[-1][0] is None
assert single_entity_id_tracker[-1][1] is not None
assert len(multiple_entity_id_tracker) == 1
assert multiple_entity_id_tracker[-1][0] is None
assert multiple_entity_id_tracker[-1][1] is not None

# Set same state should not trigger a state change/listener
hass.states.async_set("light.Bowl", "on")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 1

# State change off -> on - nothing added so no trigger
hass.states.async_set("light.Bowl", "off")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 1

# State change off -> off - nothing added so no trigger
hass.states.async_set("light.Bowl", "off", {"some_attr": 1})
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 1

# Removing state does not trigger
hass.states.async_remove("light.bowl")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 1

# Set state for different entity id
hass.states.async_set("switch.kitchen", "on")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 2

unsub_single()
# Ensure unsubing the listener works
hass.states.async_set("light.new", "off")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 3

unsub_multi()
unsub_throws()


async def test_track_template(hass):
"""Test tracking template."""
specific_runs = []
Expand Down

0 comments on commit 45526f4

Please sign in to comment.