Skip to content

Commit

Permalink
Re-add event listeners after Z-Wave server disconnection (#94383)
Browse files Browse the repository at this point in the history
* Re-add event listeners after Z-Wave server disconnection

* switch order

* Add tests
  • Loading branch information
raman325 committed Jun 11, 2023
1 parent eab0249 commit 41d8ba3
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 42 deletions.
3 changes: 3 additions & 0 deletions homeassistant/components/zwave_js/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ async def handle_ha_shutdown(event: Event) -> None:
LOGGER.info("Connection to Zwave JS Server initialized")

assert client.driver
async_dispatcher_send(
hass, f"{DOMAIN}_{client.driver.controller.home_id}_connected_to_server"
)

await driver_events.setup(client.driver)

Expand Down
67 changes: 45 additions & 22 deletions homeassistant/components/zwave_js/triggers/event.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
"""Offer Z-Wave JS event listening automation trigger."""
from __future__ import annotations

from collections.abc import Callable
import functools

from pydantic import ValidationError
import voluptuous as vol
from zwave_js_server.client import Client
from zwave_js_server.model.controller import CONTROLLER_EVENT_MODEL_MAP
from zwave_js_server.model.driver import DRIVER_EVENT_MODEL_MAP
from zwave_js_server.model.driver import DRIVER_EVENT_MODEL_MAP, Driver
from zwave_js_server.model.node import NODE_EVENT_MODEL_MAP

from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType

Expand Down Expand Up @@ -150,7 +152,7 @@ async def async_attach_trigger(
event_name = config[ATTR_EVENT]
event_data_filter = config.get(ATTR_EVENT_DATA, {})

unsubs = []
unsubs: list[Callable] = []
job = HassJob(action)

trigger_data = trigger_info["trigger_data"]
Expand Down Expand Up @@ -199,31 +201,52 @@ def async_on_event(event_data: dict, device: dr.DeviceEntry | None = None) -> No

hass.async_run_hass_job(job, {"trigger": payload})

if not nodes:
entry_id = config[ATTR_CONFIG_ENTRY_ID]
client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT]
assert client.driver
if event_source == "controller":
unsubs.append(client.driver.controller.on(event_name, async_on_event))
else:
unsubs.append(client.driver.on(event_name, async_on_event))

for node in nodes:
driver = node.client.driver
assert driver is not None # The node comes from the driver.
device_identifier = get_device_id(driver, node)
device = dev_reg.async_get_device({device_identifier})
assert device
# We need to store the device for the callback
unsubs.append(
node.on(event_name, functools.partial(async_on_event, device=device))
)

@callback
def async_remove() -> None:
"""Remove state listeners async."""
for unsub in unsubs:
unsub()
unsubs.clear()

@callback
def _create_zwave_listeners() -> None:
"""Create Z-Wave JS listeners."""
async_remove()
# Nodes list can come from different drivers and we will need to listen to
# server connections for all of them.
drivers: set[Driver] = set()
if not nodes:
entry_id = config[ATTR_CONFIG_ENTRY_ID]
client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT]
driver = client.driver
assert driver
drivers.add(driver)
if event_source == "controller":
unsubs.append(driver.controller.on(event_name, async_on_event))
else:
unsubs.append(driver.on(event_name, async_on_event))

for node in nodes:
driver = node.client.driver
assert driver is not None # The node comes from the driver.
drivers.add(driver)
device_identifier = get_device_id(driver, node)
device = dev_reg.async_get_device({device_identifier})
assert device
# We need to store the device for the callback
unsubs.append(
node.on(event_name, functools.partial(async_on_event, device=device))
)

for driver in drivers:
unsubs.append(
async_dispatcher_connect(
hass,
f"{DOMAIN}_{driver.controller.home_id}_connected_to_server",
_create_zwave_listeners,
)
)

_create_zwave_listeners()

return async_remove
61 changes: 41 additions & 20 deletions homeassistant/components/zwave_js/triggers/value_updated.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
"""Offer Z-Wave JS value updated listening automation trigger."""
from __future__ import annotations

from collections.abc import Callable
import functools

import voluptuous as vol
from zwave_js_server.const import CommandClass
from zwave_js_server.model.driver import Driver
from zwave_js_server.model.value import Value, get_value_id_str

from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM, MATCH_ALL
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType

Expand Down Expand Up @@ -99,7 +102,7 @@ async def async_attach_trigger(
property_ = config[ATTR_PROPERTY]
endpoint = config.get(ATTR_ENDPOINT)
property_key = config.get(ATTR_PROPERTY_KEY)
unsubs = []
unsubs: list[Callable] = []
job = HassJob(action)

trigger_data = trigger_info["trigger_data"]
Expand Down Expand Up @@ -153,34 +156,52 @@ def async_on_value_updated(
ATTR_PREVIOUS_VALUE_RAW: prev_value_raw,
ATTR_CURRENT_VALUE: curr_value,
ATTR_CURRENT_VALUE_RAW: curr_value_raw,
"description": f"Z-Wave value {value_id} updated on {device_name}",
"description": f"Z-Wave value {value.value_id} updated on {device_name}",
}

hass.async_run_hass_job(job, {"trigger": payload})

for node in nodes:
driver = node.client.driver
assert driver is not None # The node comes from the driver.
device_identifier = get_device_id(driver, node)
device = dev_reg.async_get_device({device_identifier})
assert device
value_id = get_value_id_str(
node, command_class, property_, endpoint, property_key
)
value = node.values[value_id]
# We need to store the current value and device for the callback
unsubs.append(
node.on(
"value updated",
functools.partial(async_on_value_updated, value, device),
)
)

@callback
def async_remove() -> None:
"""Remove state listeners async."""
for unsub in unsubs:
unsub()
unsubs.clear()

def _create_zwave_listeners() -> None:
"""Create Z-Wave JS listeners."""
async_remove()
# Nodes list can come from different drivers and we will need to listen to
# server connections for all of them.
drivers: set[Driver] = set()
for node in nodes:
driver = node.client.driver
assert driver is not None # The node comes from the driver.
drivers.add(driver)
device_identifier = get_device_id(driver, node)
device = dev_reg.async_get_device({device_identifier})
assert device
value_id = get_value_id_str(
node, command_class, property_, endpoint, property_key
)
value = node.values[value_id]
# We need to store the current value and device for the callback
unsubs.append(
node.on(
"value updated",
functools.partial(async_on_value_updated, value, device),
)
)

for driver in drivers:
unsubs.append(
async_dispatcher_connect(
hass,
f"{DOMAIN}_{driver.controller.home_id}_connected_to_server",
_create_zwave_listeners,
)
)

_create_zwave_listeners()

return async_remove
98 changes: 98 additions & 0 deletions tests/components/zwave_js/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,3 +1109,101 @@ def test_get_trigger_platform_failure() -> None:
"""Test _get_trigger_platform."""
with pytest.raises(ValueError):
_get_trigger_platform({CONF_PLATFORM: "zwave_js.invalid"})


async def test_server_reconnect_event(
hass: HomeAssistant, client, lock_schlage_be469, integration
) -> None:
"""Test that when we reconnect to server, event triggers reattach."""
trigger_type = f"{DOMAIN}.event"
node: Node = lock_schlage_be469
dev_reg = async_get_dev_reg(hass)
device = dev_reg.async_get_device(
{get_device_id(client.driver, lock_schlage_be469)}
)
assert device

event_name = "interview stage completed"

original_len = len(node._listeners.get(event_name, []))

assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger": {
"platform": trigger_type,
"entity_id": SCHLAGE_BE469_LOCK_ENTITY,
"event_source": "node",
"event": event_name,
},
"action": {
"event": "blah",
},
},
]
},
)

assert len(node._listeners.get(event_name, [])) == original_len + 1
old_listener = node._listeners.get(event_name, [])[original_len]

await hass.config_entries.async_reload(integration.entry_id)
await hass.async_block_till_done()

# Make sure there is still a listener added for the trigger
assert len(node._listeners.get(event_name, [])) == original_len + 1

# Make sure the old listener was removed
assert old_listener not in node._listeners.get(event_name, [])


async def test_server_reconnect_value_updated(
hass: HomeAssistant, client, lock_schlage_be469, integration
) -> None:
"""Test that when we reconnect to server, value_updated triggers reattach."""
trigger_type = f"{DOMAIN}.value_updated"
node: Node = lock_schlage_be469
dev_reg = async_get_dev_reg(hass)
device = dev_reg.async_get_device(
{get_device_id(client.driver, lock_schlage_be469)}
)
assert device

event_name = "value updated"

original_len = len(node._listeners.get(event_name, []))

assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger": {
"platform": trigger_type,
"entity_id": SCHLAGE_BE469_LOCK_ENTITY,
"command_class": CommandClass.DOOR_LOCK.value,
"property": "latchStatus",
},
"action": {
"event": "no_value_filter",
},
},
]
},
)

assert len(node._listeners.get(event_name, [])) == original_len + 1
old_listener = node._listeners.get(event_name, [])[original_len]

await hass.config_entries.async_reload(integration.entry_id)
await hass.async_block_till_done()

# Make sure there is still a listener added for the trigger
assert len(node._listeners.get(event_name, [])) == original_len + 1

# Make sure the old listener was removed
assert old_listener not in node._listeners.get(event_name, [])

0 comments on commit 41d8ba3

Please sign in to comment.