Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for area ID in zwave_js service calls #54940

Merged
merged 1 commit into from Aug 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
67 changes: 50 additions & 17 deletions homeassistant/components/zwave_js/helpers.py
Expand Up @@ -13,14 +13,7 @@
from homeassistant.const import __version__ as HA_VERSION
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.device_registry import (
DeviceRegistry,
async_get as async_get_dev_reg,
)
from homeassistant.helpers.entity_registry import (
EntityRegistry,
async_get as async_get_ent_reg,
)
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.typing import ConfigType

from .const import (
Expand Down Expand Up @@ -79,15 +72,15 @@ def get_home_and_node_id_from_device_id(device_id: tuple[str, ...]) -> list[str]

@callback
def async_get_node_from_device_id(
hass: HomeAssistant, device_id: str, dev_reg: DeviceRegistry | None = None
hass: HomeAssistant, device_id: str, dev_reg: dr.DeviceRegistry | None = None
) -> ZwaveNode:
"""
Get node from a device ID.

Raises ValueError if device is invalid or node can't be found.
"""
if not dev_reg:
dev_reg = async_get_dev_reg(hass)
dev_reg = dr.async_get(hass)
device_entry = dev_reg.async_get(device_id)

if not device_entry:
Expand Down Expand Up @@ -138,16 +131,16 @@ def async_get_node_from_device_id(
def async_get_node_from_entity_id(
hass: HomeAssistant,
entity_id: str,
ent_reg: EntityRegistry | None = None,
dev_reg: DeviceRegistry | None = None,
ent_reg: er.EntityRegistry | None = None,
dev_reg: dr.DeviceRegistry | None = None,
) -> ZwaveNode:
"""
Get node from an entity ID.

Raises ValueError if entity is invalid.
"""
if not ent_reg:
ent_reg = async_get_ent_reg(hass)
ent_reg = er.async_get(hass)
entity_entry = ent_reg.async_get(entity_id)

if entity_entry is None or entity_entry.platform != DOMAIN:
Expand All @@ -159,6 +152,46 @@ def async_get_node_from_entity_id(
return async_get_node_from_device_id(hass, entity_entry.device_id, dev_reg)


@callback
def async_get_nodes_from_area_id(
hass: HomeAssistant,
area_id: str,
ent_reg: er.EntityRegistry | None = None,
dev_reg: dr.DeviceRegistry | None = None,
) -> set[ZwaveNode]:
"""Get nodes for all Z-Wave JS devices and entities that are in an area."""
nodes: set[ZwaveNode] = set()
if ent_reg is None:
ent_reg = er.async_get(hass)
if dev_reg is None:
dev_reg = dr.async_get(hass)
# Add devices for all entities in an area that are Z-Wave JS entities
nodes.update(
{
async_get_node_from_device_id(hass, entity.device_id, dev_reg)
for entity in er.async_entries_for_area(ent_reg, area_id)
if entity.platform == DOMAIN and entity.device_id is not None
}
)
# Add devices in an area that are Z-Wave JS devices
for device in dr.async_entries_for_area(dev_reg, area_id):
if next(
(
config_entry_id
for config_entry_id in device.config_entries
if cast(
ConfigEntry,
hass.config_entries.async_get_entry(config_entry_id),
).domain
== DOMAIN
),
None,
):
nodes.add(async_get_node_from_device_id(hass, device.id, dev_reg))

return nodes


def get_zwave_value_from_config(node: ZwaveNode, config: ConfigType) -> ZwaveValue:
"""Get a Z-Wave JS Value from a config."""
endpoint = None
Expand All @@ -183,14 +216,14 @@ def get_zwave_value_from_config(node: ZwaveNode, config: ConfigType) -> ZwaveVal
def async_get_node_status_sensor_entity_id(
hass: HomeAssistant,
device_id: str,
ent_reg: EntityRegistry | None = None,
dev_reg: DeviceRegistry | None = None,
ent_reg: er.EntityRegistry | None = None,
dev_reg: dr.DeviceRegistry | None = None,
) -> str:
"""Get the node status sensor entity ID for a given Z-Wave JS device."""
if not ent_reg:
ent_reg = async_get_ent_reg(hass)
ent_reg = er.async_get(hass)
if not dev_reg:
dev_reg = async_get_dev_reg(hass)
dev_reg = dr.async_get(hass)
device = dev_reg.async_get(device_id)
if not device:
raise HomeAssistantError("Invalid Device ID provided")
Expand Down
62 changes: 52 additions & 10 deletions homeassistant/components/zwave_js/services.py
Expand Up @@ -18,15 +18,18 @@
)

from homeassistant.components.group import expand_entity_ids
from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID
from homeassistant.const import ATTR_AREA_ID, ATTR_DEVICE_ID, ATTR_ENTITY_ID
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.helpers import device_registry as dr, entity_registry as er
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.device_registry import DeviceRegistry
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.entity_registry import EntityRegistry

from . import const
from .helpers import async_get_node_from_device_id, async_get_node_from_entity_id
from .helpers import (
async_get_node_from_device_id,
async_get_node_from_entity_id,
async_get_nodes_from_area_id,
)

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -81,7 +84,10 @@ class ZWaveServices:
"""Class that holds our services (Zwave Commands) that should be published to hass."""

def __init__(
self, hass: HomeAssistant, ent_reg: EntityRegistry, dev_reg: DeviceRegistry
self,
hass: HomeAssistant,
ent_reg: er.EntityRegistry,
dev_reg: dr.DeviceRegistry,
) -> None:
"""Initialize with hass object."""
self._hass = hass
Expand All @@ -96,6 +102,7 @@ def async_register(self) -> None:
def get_nodes_from_service_data(val: dict[str, Any]) -> dict[str, Any]:
"""Get nodes set from service data."""
nodes: set[ZwaveNode] = set()
# Convert all entity IDs to nodes
for entity_id in expand_entity_ids(self._hass, val.pop(ATTR_ENTITY_ID, [])):
try:
nodes.add(
Expand All @@ -105,6 +112,16 @@ def get_nodes_from_service_data(val: dict[str, Any]) -> dict[str, Any]:
)
except ValueError as err:
const.LOGGER.warning(err.args[0])

# Convert all area IDs to nodes
for area_id in val.pop(ATTR_AREA_ID, []):
nodes.update(
async_get_nodes_from_area_id(
self._hass, area_id, self._ent_reg, self._dev_reg
)
)

# Convert all device IDs to nodes
for device_id in val.pop(ATTR_DEVICE_ID, []):
try:
nodes.add(
Expand Down Expand Up @@ -170,6 +187,9 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
schema=vol.Schema(
vol.All(
{
vol.Optional(ATTR_AREA_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_DEVICE_ID): vol.All(
cv.ensure_list, [cv.string]
),
Expand All @@ -184,7 +204,9 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
vol.Coerce(int), BITMASK_SCHEMA, cv.string
),
},
cv.has_at_least_one_key(ATTR_DEVICE_ID, ATTR_ENTITY_ID),
cv.has_at_least_one_key(
ATTR_DEVICE_ID, ATTR_ENTITY_ID, ATTR_AREA_ID
),
parameter_name_does_not_need_bitmask,
get_nodes_from_service_data,
),
Expand All @@ -198,6 +220,9 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
schema=vol.Schema(
vol.All(
{
vol.Optional(ATTR_AREA_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_DEVICE_ID): vol.All(
cv.ensure_list, [cv.string]
),
Expand All @@ -212,7 +237,9 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
},
),
},
cv.has_at_least_one_key(ATTR_DEVICE_ID, ATTR_ENTITY_ID),
cv.has_at_least_one_key(
ATTR_DEVICE_ID, ATTR_ENTITY_ID, ATTR_AREA_ID
),
get_nodes_from_service_data,
),
),
Expand Down Expand Up @@ -242,6 +269,9 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
schema=vol.Schema(
vol.All(
{
vol.Optional(ATTR_AREA_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_DEVICE_ID): vol.All(
cv.ensure_list, [cv.string]
),
Expand All @@ -258,7 +288,9 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
vol.Optional(const.ATTR_WAIT_FOR_RESULT): cv.boolean,
vol.Optional(const.ATTR_OPTIONS): {cv.string: VALUE_SCHEMA},
},
cv.has_at_least_one_key(ATTR_DEVICE_ID, ATTR_ENTITY_ID),
cv.has_at_least_one_key(
ATTR_DEVICE_ID, ATTR_ENTITY_ID, ATTR_AREA_ID
),
get_nodes_from_service_data,
),
),
Expand All @@ -271,6 +303,9 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
schema=vol.Schema(
vol.All(
{
vol.Optional(ATTR_AREA_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_DEVICE_ID): vol.All(
cv.ensure_list, [cv.string]
),
Expand All @@ -288,7 +323,9 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
vol.Optional(const.ATTR_OPTIONS): {cv.string: VALUE_SCHEMA},
},
vol.Any(
cv.has_at_least_one_key(ATTR_DEVICE_ID, ATTR_ENTITY_ID),
cv.has_at_least_one_key(
ATTR_DEVICE_ID, ATTR_ENTITY_ID, ATTR_AREA_ID
),
broadcast_command,
),
get_nodes_from_service_data,
Expand All @@ -304,12 +341,17 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
schema=vol.Schema(
vol.All(
{
vol.Optional(ATTR_AREA_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_DEVICE_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
},
cv.has_at_least_one_key(ATTR_DEVICE_ID, ATTR_ENTITY_ID),
cv.has_at_least_one_key(
ATTR_DEVICE_ID, ATTR_ENTITY_ID, ATTR_AREA_ID
),
get_nodes_from_service_data,
),
),
Expand Down