Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-add event listeners after Z-Wave server disconnection #94383

Merged
merged 3 commits into from
Jun 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions homeassistant/components/zwave_js/__init__.py
Expand Up @@ -215,6 +215,9 @@ async def handle_ha_shutdown(event: Event) -> None:
LOGGER.info("Connection to Zwave JS Server initialized")

assert client.driver
async_dispatcher_send(
hass, f"{DOMAIN}_{client.driver.controller.home_id}_connected_to_server"
)

await driver_events.setup(client.driver)

Expand Down
67 changes: 45 additions & 22 deletions homeassistant/components/zwave_js/triggers/event.py
@@ -1,18 +1,20 @@
"""Offer Z-Wave JS event listening automation trigger."""
from __future__ import annotations

from collections.abc import Callable
import functools

from pydantic import ValidationError
import voluptuous as vol
from zwave_js_server.client import Client
from zwave_js_server.model.controller import CONTROLLER_EVENT_MODEL_MAP
from zwave_js_server.model.driver import DRIVER_EVENT_MODEL_MAP
from zwave_js_server.model.driver import DRIVER_EVENT_MODEL_MAP, Driver
from zwave_js_server.model.node import NODE_EVENT_MODEL_MAP

from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType

Expand Down Expand Up @@ -150,7 +152,7 @@ async def async_attach_trigger(
event_name = config[ATTR_EVENT]
event_data_filter = config.get(ATTR_EVENT_DATA, {})

unsubs = []
unsubs: list[Callable] = []
job = HassJob(action)

trigger_data = trigger_info["trigger_data"]
Expand Down Expand Up @@ -199,31 +201,52 @@ def async_on_event(event_data: dict, device: dr.DeviceEntry | None = None) -> No

hass.async_run_hass_job(job, {"trigger": payload})

if not nodes:
entry_id = config[ATTR_CONFIG_ENTRY_ID]
client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT]
assert client.driver
if event_source == "controller":
unsubs.append(client.driver.controller.on(event_name, async_on_event))
else:
unsubs.append(client.driver.on(event_name, async_on_event))

for node in nodes:
driver = node.client.driver
assert driver is not None # The node comes from the driver.
device_identifier = get_device_id(driver, node)
device = dev_reg.async_get_device({device_identifier})
assert device
# We need to store the device for the callback
unsubs.append(
node.on(event_name, functools.partial(async_on_event, device=device))
)

@callback
def async_remove() -> None:
"""Remove state listeners async."""
for unsub in unsubs:
unsub()
unsubs.clear()

@callback
def _create_zwave_listeners() -> None:
"""Create Z-Wave JS listeners."""
async_remove()
# Nodes list can come from different drivers and we will need to listen to
# server connections for all of them.
drivers: set[Driver] = set()
if not nodes:
entry_id = config[ATTR_CONFIG_ENTRY_ID]
client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT]
driver = client.driver
assert driver
drivers.add(driver)
if event_source == "controller":
unsubs.append(driver.controller.on(event_name, async_on_event))
else:
unsubs.append(driver.on(event_name, async_on_event))

for node in nodes:
driver = node.client.driver
raman325 marked this conversation as resolved.
Show resolved Hide resolved
assert driver is not None # The node comes from the driver.
drivers.add(driver)
device_identifier = get_device_id(driver, node)
device = dev_reg.async_get_device({device_identifier})
assert device
# We need to store the device for the callback
unsubs.append(
node.on(event_name, functools.partial(async_on_event, device=device))
)

for driver in drivers:
unsubs.append(
async_dispatcher_connect(
hass,
f"{DOMAIN}_{driver.controller.home_id}_connected_to_server",
_create_zwave_listeners,
)
)

_create_zwave_listeners()

return async_remove
61 changes: 41 additions & 20 deletions homeassistant/components/zwave_js/triggers/value_updated.py
@@ -1,15 +1,18 @@
"""Offer Z-Wave JS value updated listening automation trigger."""
from __future__ import annotations

from collections.abc import Callable
import functools

import voluptuous as vol
from zwave_js_server.const import CommandClass
from zwave_js_server.model.driver import Driver
from zwave_js_server.model.value import Value, get_value_id_str

from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM, MATCH_ALL
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType

Expand Down Expand Up @@ -99,7 +102,7 @@ async def async_attach_trigger(
property_ = config[ATTR_PROPERTY]
endpoint = config.get(ATTR_ENDPOINT)
property_key = config.get(ATTR_PROPERTY_KEY)
unsubs = []
unsubs: list[Callable] = []
job = HassJob(action)

trigger_data = trigger_info["trigger_data"]
Expand Down Expand Up @@ -153,34 +156,52 @@ def async_on_value_updated(
ATTR_PREVIOUS_VALUE_RAW: prev_value_raw,
ATTR_CURRENT_VALUE: curr_value,
ATTR_CURRENT_VALUE_RAW: curr_value_raw,
"description": f"Z-Wave value {value_id} updated on {device_name}",
"description": f"Z-Wave value {value.value_id} updated on {device_name}",
}

hass.async_run_hass_job(job, {"trigger": payload})

for node in nodes:
driver = node.client.driver
assert driver is not None # The node comes from the driver.
device_identifier = get_device_id(driver, node)
device = dev_reg.async_get_device({device_identifier})
assert device
value_id = get_value_id_str(
node, command_class, property_, endpoint, property_key
)
value = node.values[value_id]
# We need to store the current value and device for the callback
unsubs.append(
node.on(
"value updated",
functools.partial(async_on_value_updated, value, device),
)
)

@callback
def async_remove() -> None:
"""Remove state listeners async."""
for unsub in unsubs:
unsub()
unsubs.clear()

def _create_zwave_listeners() -> None:
"""Create Z-Wave JS listeners."""
async_remove()
# Nodes list can come from different drivers and we will need to listen to
# server connections for all of them.
drivers: set[Driver] = set()
for node in nodes:
driver = node.client.driver
assert driver is not None # The node comes from the driver.
drivers.add(driver)
device_identifier = get_device_id(driver, node)
device = dev_reg.async_get_device({device_identifier})
assert device
value_id = get_value_id_str(
node, command_class, property_, endpoint, property_key
)
value = node.values[value_id]
# We need to store the current value and device for the callback
unsubs.append(
node.on(
"value updated",
functools.partial(async_on_value_updated, value, device),
)
)

for driver in drivers:
unsubs.append(
async_dispatcher_connect(
hass,
f"{DOMAIN}_{driver.controller.home_id}_connected_to_server",
_create_zwave_listeners,
)
)

_create_zwave_listeners()

return async_remove
98 changes: 98 additions & 0 deletions tests/components/zwave_js/test_trigger.py
Expand Up @@ -1109,3 +1109,101 @@ def test_get_trigger_platform_failure() -> None:
"""Test _get_trigger_platform."""
with pytest.raises(ValueError):
_get_trigger_platform({CONF_PLATFORM: "zwave_js.invalid"})


async def test_server_reconnect_event(
hass: HomeAssistant, client, lock_schlage_be469, integration
) -> None:
"""Test that when we reconnect to server, event triggers reattach."""
trigger_type = f"{DOMAIN}.event"
node: Node = lock_schlage_be469
dev_reg = async_get_dev_reg(hass)
device = dev_reg.async_get_device(
{get_device_id(client.driver, lock_schlage_be469)}
)
assert device

event_name = "interview stage completed"

original_len = len(node._listeners.get(event_name, []))

assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger": {
"platform": trigger_type,
"entity_id": SCHLAGE_BE469_LOCK_ENTITY,
"event_source": "node",
"event": event_name,
},
"action": {
"event": "blah",
},
},
]
},
)

assert len(node._listeners.get(event_name, [])) == original_len + 1
old_listener = node._listeners.get(event_name, [])[original_len]

await hass.config_entries.async_reload(integration.entry_id)
await hass.async_block_till_done()

# Make sure there is still a listener added for the trigger
assert len(node._listeners.get(event_name, [])) == original_len + 1

# Make sure the old listener was removed
assert old_listener not in node._listeners.get(event_name, [])


async def test_server_reconnect_value_updated(
hass: HomeAssistant, client, lock_schlage_be469, integration
) -> None:
"""Test that when we reconnect to server, value_updated triggers reattach."""
trigger_type = f"{DOMAIN}.value_updated"
node: Node = lock_schlage_be469
dev_reg = async_get_dev_reg(hass)
device = dev_reg.async_get_device(
{get_device_id(client.driver, lock_schlage_be469)}
)
assert device

event_name = "value updated"

original_len = len(node._listeners.get(event_name, []))

assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger": {
"platform": trigger_type,
"entity_id": SCHLAGE_BE469_LOCK_ENTITY,
"command_class": CommandClass.DOOR_LOCK.value,
"property": "latchStatus",
},
"action": {
"event": "no_value_filter",
},
},
]
},
)

assert len(node._listeners.get(event_name, [])) == original_len + 1
old_listener = node._listeners.get(event_name, [])[original_len]

await hass.config_entries.async_reload(integration.entry_id)
await hass.async_block_till_done()

# Make sure there is still a listener added for the trigger
assert len(node._listeners.get(event_name, [])) == original_len + 1

# Make sure the old listener was removed
assert old_listener not in node._listeners.get(event_name, [])