Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add async_track_state_added_domain for tracking when states are added to a domain #38776

Merged
merged 5 commits into from Aug 12, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
89 changes: 78 additions & 11 deletions homeassistant/helpers/event.py
Expand Up @@ -17,7 +17,14 @@
SUN_EVENT_SUNRISE,
SUN_EVENT_SUNSET,
)
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, State, callback
from homeassistant.core import (
CALLBACK_TYPE,
Event,
HomeAssistant,
State,
callback,
split_entity_id,
)
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
from homeassistant.helpers.sun import get_astral_event_next
from homeassistant.helpers.template import Template
Expand All @@ -28,6 +35,9 @@
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener"

TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks"
TRACK_STATE_ADDED_DOMAIN_LISTENER = "track_state_added_domain_listener"

TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener"

Expand Down Expand Up @@ -191,7 +201,7 @@ def _async_state_change_dispatcher(event: Event) -> None:
@callback
def remove_listener() -> None:
"""Remove state change listener."""
_async_remove_entity_listeners(
_async_remove_indexed_listeners(
hass,
TRACK_STATE_CHANGE_CALLBACKS,
TRACK_STATE_CHANGE_LISTENER,
Expand All @@ -203,23 +213,23 @@ def remove_listener() -> None:


@callback
def _async_remove_entity_listeners(
def _async_remove_indexed_listeners(
hass: HomeAssistant,
storage_key: str,
listener_key: str,
entity_ids: Iterable[str],
storage_keys: Iterable[str],
action: Callable[[Event], Any],
) -> None:
"""Remove a listener."""

entity_callbacks = hass.data[storage_key]
callbacks = hass.data[storage_key]

for entity_id in entity_ids:
entity_callbacks[entity_id].remove(action)
if len(entity_callbacks[entity_id]) == 0:
del entity_callbacks[entity_id]
for storage_key in storage_keys:
callbacks[storage_key].remove(action)
if len(callbacks[storage_key]) == 0:
del callbacks[storage_key]

if not entity_callbacks:
if not callbacks:
hass.data[listener_key]()
del hass.data[listener_key]

Expand Down Expand Up @@ -271,7 +281,7 @@ def _async_entity_registry_updated_dispatcher(event: Event) -> None:
@callback
def remove_listener() -> None:
"""Remove state change listener."""
_async_remove_entity_listeners(
_async_remove_indexed_listeners(
hass,
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS,
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER,
Expand All @@ -282,6 +292,63 @@ def remove_listener() -> None:
return remove_listener


@bind_hass
def async_track_state_added_domain(
hass: HomeAssistant,
domains: Union[str, Iterable[str]],
action: Callable[[Event], Any],
) -> Callable[[], None]:
"""Track state change events when an entity is added to domains."""

domain_callbacks = hass.data.setdefault(TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {})

if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data:

@callback
def _async_state_change_dispatcher(event: Event) -> None:
"""Dispatch state changes by entity_id."""
if event.data.get("old_state") is not None:
return

domain = split_entity_id(event.data["entity_id"])[0]

if domain not in domain_callbacks:
return

for action in domain_callbacks[domain][:]:
try:
hass.async_run_job(action, event)
except Exception: # pylint: disable=broad-except
_LOGGER.exception(
"Error while processing state added for %s", domain
)

hass.data[TRACK_STATE_ADDED_DOMAIN_LISTENER] = hass.bus.async_listen(
EVENT_STATE_CHANGED, _async_state_change_dispatcher
)

if isinstance(domains, str):
domains = [domains]

domains = [domains.lower() for domains in domains]

for domain in domains:
domain_callbacks.setdefault(domain, []).append(action)

@callback
def remove_listener() -> None:
"""Remove state change listener."""
_async_remove_indexed_listeners(
hass,
TRACK_STATE_ADDED_DOMAIN_CALLBACKS,
TRACK_STATE_ADDED_DOMAIN_LISTENER,
domains,
action,
)

return remove_listener


@callback
@bind_hass
def async_track_template(
Expand Down
83 changes: 83 additions & 0 deletions tests/helpers/test_event.py
Expand Up @@ -16,6 +16,7 @@
async_track_point_in_time,
async_track_point_in_utc_time,
async_track_same_state,
async_track_state_added_domain,
async_track_state_change,
async_track_state_change_event,
async_track_sunrise,
Expand Down Expand Up @@ -341,6 +342,88 @@ def callback_that_throws(event):
unsub_throws()


async def test_async_track_state_added_domain(hass):
"""Test async_track_state_added_domain."""
single_entity_id_tracker = []
multiple_entity_id_tracker = []

@ha.callback
def single_run_callback(event):
old_state = event.data.get("old_state")
new_state = event.data.get("new_state")

single_entity_id_tracker.append((old_state, new_state))

@ha.callback
def multiple_run_callback(event):
old_state = event.data.get("old_state")
new_state = event.data.get("new_state")

multiple_entity_id_tracker.append((old_state, new_state))

@ha.callback
def callback_that_throws(event):
raise ValueError

unsub_single = async_track_state_added_domain(hass, ["light"], single_run_callback)
unsub_multi = async_track_state_added_domain(
hass, ["light", "switch"], multiple_run_callback
)
unsub_throws = async_track_state_added_domain(
hass, ["light", "switch"], callback_that_throws
)

# Adding state to state machine
hass.states.async_set("light.Bowl", "on")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert single_entity_id_tracker[-1][0] is None
assert single_entity_id_tracker[-1][1] is not None
assert len(multiple_entity_id_tracker) == 1
assert multiple_entity_id_tracker[-1][0] is None
assert multiple_entity_id_tracker[-1][1] is not None

# Set same state should not trigger a state change/listener
hass.states.async_set("light.Bowl", "on")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 1

# State change off -> on - nothing added so no trigger
hass.states.async_set("light.Bowl", "off")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 1

# State change off -> off - nothing added so no trigger
hass.states.async_set("light.Bowl", "off", {"some_attr": 1})
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 1

# Removing state does not trigger
hass.states.async_remove("light.bowl")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 1

# Set state for different entity id
hass.states.async_set("switch.kitchen", "on")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 2

unsub_single()
# Ensure unsubing the listener works
hass.states.async_set("light.new", "off")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 3

unsub_multi()
unsub_throws()


async def test_track_template(hass):
"""Test tracking template."""
specific_runs = []
Expand Down