Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve validation of device automations #102766

Merged
merged 3 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions homeassistant/components/device_automation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import voluptuous as vol

from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, Platform
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_ENTITY_ID, Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.typing import ConfigType

from . import DeviceAutomationType, async_get_device_automation_platform
Expand Down Expand Up @@ -55,31 +55,42 @@ async def async_validate_device_automation_config(
platform = await async_get_device_automation_platform(
hass, validated_config[CONF_DOMAIN], automation_type
)

# Make sure the referenced device and optional entity exist
device_registry = dr.async_get(hass)
if not (device := device_registry.async_get(validated_config[CONF_DEVICE_ID])):
# The device referenced by the device automation does not exist
raise InvalidDeviceAutomationConfig(
f"Unknown device '{validated_config[CONF_DEVICE_ID]}'"
)
if entity_id := validated_config.get(CONF_ENTITY_ID):
try:
er.async_validate_entity_id(er.async_get(hass), entity_id)
except vol.Invalid as err:
raise InvalidDeviceAutomationConfig(
f"Unknown entity '{entity_id}'"
) from err

if not hasattr(platform, DYNAMIC_VALIDATOR[automation_type]):
# Pass the unvalidated config to avoid mutating the raw config twice
return cast(
ConfigType, getattr(platform, STATIC_VALIDATOR[automation_type])(config)
)

# Bypass checks for entity platforms
# Devices are not linked to config entries from entity platform domains, skip
# the checks below which look for a config entry matching the device automation
# domain
if (
automation_type == DeviceAutomationType.ACTION
and validated_config[CONF_DOMAIN] in ENTITY_PLATFORMS
):
# Pass the unvalidated config to avoid mutating the raw config twice
return cast(
ConfigType,
await getattr(platform, DYNAMIC_VALIDATOR[automation_type])(hass, config),
)

# Only call the dynamic validator if the referenced device exists and the relevant
# config entry is loaded
registry = dr.async_get(hass)
if not (device := registry.async_get(validated_config[CONF_DEVICE_ID])):
# The device referenced by the device automation does not exist
raise InvalidDeviceAutomationConfig(
f"Unknown device '{validated_config[CONF_DEVICE_ID]}'"
)

# Find a config entry with the same domain as the device automation
device_config_entry = None
for entry_id in device.config_entries:
if (
Expand All @@ -91,7 +102,7 @@ async def async_validate_device_automation_config(
break

if not device_config_entry:
# The config entry referenced by the device automation does not exist
# There's no config entry with the same domain as the device automation
raise InvalidDeviceAutomationConfig(
f"Device '{validated_config[CONF_DEVICE_ID]}' has no config entry from "
f"domain '{validated_config[CONF_DOMAIN]}'"
Expand Down
146 changes: 108 additions & 38 deletions tests/components/device_automation/test_init.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""The test for light device automation."""
from unittest.mock import AsyncMock, Mock, patch

import attr
import pytest
from pytest_unordered import unordered
import voluptuous as vol
Expand Down Expand Up @@ -31,6 +32,13 @@
from tests.typing import WebSocketGenerator


@attr.s(frozen=True)
class MockDeviceEntry(dr.DeviceEntry):
"""Device Registry Entry with fixed UUID."""

id: str = attr.ib(default="very_unique")


@pytest.fixture(autouse=True, name="stub_blueprint_populate")
def stub_blueprint_populate_autouse(stub_blueprint_populate: None) -> None:
"""Stub copying the blueprints to the config folder."""
Expand Down Expand Up @@ -1240,17 +1248,56 @@ async def test_automation_with_integration_without_device_trigger(
)


BAD_AUTOMATIONS = [
(
{"device_id": "very_unique", "domain": "light"},
"required key not provided @ data['entity_id']",
),
(
{"device_id": "wrong", "domain": "light"},
"Unknown device 'wrong'",
),
(
{"device_id": "wrong"},
"required key not provided @ data{path}['domain']",
),
(
{"device_id": "wrong", "domain": "light"},
"Unknown device 'wrong'",
),
(
{"device_id": "very_unique", "domain": "light"},
"required key not provided @ data['entity_id']",
),
(
{"device_id": "very_unique", "domain": "light", "entity_id": "wrong"},
"Unknown entity 'wrong'",
),
]

BAD_TRIGGERS = BAD_CONDITIONS = BAD_AUTOMATIONS + [
(
{"domain": "light"},
"required key not provided @ data{path}['device_id']",
)
]


@patch("homeassistant.helpers.device_registry.DeviceEntry", MockDeviceEntry)
@pytest.mark.parametrize(("action", "expected_error"), BAD_AUTOMATIONS)
async def test_automation_with_bad_action(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
action: dict[str, str],
expected_error: str,
) -> None:
"""Test automation with bad device action."""
config_entry = MockConfigEntry(domain="fake_integration", data={})
config_entry.state = config_entries.ConfigEntryState.LOADED
config_entry.add_to_hass(hass)
device_entry = device_registry.async_get_or_create(
device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
Expand All @@ -1262,25 +1309,29 @@ async def test_automation_with_bad_action(
automation.DOMAIN: {
"alias": "hello",
"trigger": {"platform": "event", "event_type": "test_event1"},
"action": {"device_id": device_entry.id, "domain": "light"},
"action": action,
}
},
)

assert "required key not provided" in caplog.text
assert expected_error.format(path="['action'][0]") in caplog.text


@patch("homeassistant.helpers.device_registry.DeviceEntry", MockDeviceEntry)
@pytest.mark.parametrize(("condition", "expected_error"), BAD_CONDITIONS)
async def test_automation_with_bad_condition_action(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
condition: dict[str, str],
expected_error: str,
) -> None:
"""Test automation with bad device action."""
config_entry = MockConfigEntry(domain="fake_integration", data={})
config_entry.state = config_entries.ConfigEntryState.LOADED
config_entry.add_to_hass(hass)
device_entry = device_registry.async_get_or_create(
device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
Expand All @@ -1292,56 +1343,46 @@ async def test_automation_with_bad_condition_action(
automation.DOMAIN: {
"alias": "hello",
"trigger": {"platform": "event", "event_type": "test_event1"},
"action": {
"condition": "device",
"device_id": device_entry.id,
"domain": "light",
},
"action": {"condition": "device"} | condition,
}
},
)

assert "required key not provided" in caplog.text
assert expected_error.format(path="['action'][0]") in caplog.text


async def test_automation_with_bad_condition_missing_domain(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
@patch("homeassistant.helpers.device_registry.DeviceEntry", MockDeviceEntry)
@pytest.mark.parametrize(("condition", "expected_error"), BAD_CONDITIONS)
async def test_automation_with_bad_condition(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
device_registry: dr.DeviceRegistry,
condition: dict[str, str],
expected_error: str,
) -> None:
"""Test automation with bad device condition."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"alias": "hello",
"trigger": {"platform": "event", "event_type": "test_event1"},
"condition": {"condition": "device", "device_id": "hello.device"},
"action": {"service": "test.automation", "entity_id": "hello.world"},
}
},
config_entry = MockConfigEntry(domain="fake_integration", data={})
config_entry.state = config_entries.ConfigEntryState.LOADED
config_entry.add_to_hass(hass)
device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)

assert "required key not provided @ data['condition'][0]['domain']" in caplog.text


async def test_automation_with_bad_condition(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test automation with bad device condition."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"alias": "hello",
"trigger": {"platform": "event", "event_type": "test_event1"},
"condition": {"condition": "device", "domain": "light"},
"condition": {"condition": "device"} | condition,
"action": {"service": "test.automation", "entity_id": "hello.world"},
}
},
)

assert "required key not provided" in caplog.text
assert expected_error.format(path="['condition'][0]") in caplog.text


@pytest.fixture
Expand Down Expand Up @@ -1475,10 +1516,24 @@ async def test_automation_with_sub_condition(
)


@patch("homeassistant.helpers.device_registry.DeviceEntry", MockDeviceEntry)
@pytest.mark.parametrize(("condition", "expected_error"), BAD_CONDITIONS)
async def test_automation_with_bad_sub_condition(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
device_registry: dr.DeviceRegistry,
condition: dict[str, str],
expected_error: str,
) -> None:
"""Test automation with bad device condition under and/or conditions."""
config_entry = MockConfigEntry(domain="fake_integration", data={})
config_entry.state = config_entries.ConfigEntryState.LOADED
config_entry.add_to_hass(hass)
device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)

assert await async_setup_component(
hass,
automation.DOMAIN,
Expand All @@ -1488,33 +1543,48 @@ async def test_automation_with_bad_sub_condition(
"trigger": {"platform": "event", "event_type": "test_event1"},
"condition": {
"condition": "and",
"conditions": [{"condition": "device", "domain": "light"}],
"conditions": [{"condition": "device"} | condition],
},
"action": {"service": "test.automation", "entity_id": "hello.world"},
}
},
)

assert "required key not provided" in caplog.text
path = "['condition'][0]['conditions'][0]"
assert expected_error.format(path=path) in caplog.text


@patch("homeassistant.helpers.device_registry.DeviceEntry", MockDeviceEntry)
@pytest.mark.parametrize(("trigger", "expected_error"), BAD_TRIGGERS)
async def test_automation_with_bad_trigger(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
device_registry: dr.DeviceRegistry,
trigger: dict[str, str],
expected_error: str,
) -> None:
"""Test automation with bad device trigger."""
config_entry = MockConfigEntry(domain="fake_integration", data={})
config_entry.state = config_entries.ConfigEntryState.LOADED
config_entry.add_to_hass(hass)
device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)

assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"alias": "hello",
"trigger": {"platform": "device", "domain": "light"},
"trigger": {"platform": "device"} | trigger,
"action": {"service": "test.automation", "entity_id": "hello.world"},
}
},
)

assert "required key not provided" in caplog.text
assert expected_error.format(path="") in caplog.text


async def test_websocket_device_not_found(
Expand Down
Loading