Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support shorthand templates in condition actions #61177

Merged
merged 3 commits into from Dec 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 47 additions & 3 deletions homeassistant/helpers/config_validation.py
Expand Up @@ -860,7 +860,10 @@ def removed(


def key_value_schemas(
key: str, value_schemas: dict[Hashable, vol.Schema]
key: str,
value_schemas: dict[Hashable, vol.Schema],
default_schema: vol.Schema | None = None,
default_description: str | None = None,
) -> Callable[[Any], dict[Hashable, Any]]:
"""Create a validator that validates based on a value for specific key.

Expand All @@ -876,8 +879,15 @@ def key_value_validator(value: Any) -> dict[Hashable, Any]:
if isinstance(key_value, Hashable) and key_value in value_schemas:
return cast(Dict[Hashable, Any], value_schemas[key_value](value))

if default_schema:
with contextlib.suppress(vol.Invalid):
return cast(Dict[Hashable, Any], default_schema(value))
Comment on lines +882 to +884
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will hide error messages from the default validator, not sure if that's acceptable?


alternatives = ", ".join(str(key) for key in value_schemas)
if default_description:
alternatives += ", " + default_description
raise vol.Invalid(
f"Unexpected value for {key}: '{key_value}'. Expected {', '.join(str(key) for key in value_schemas)}"
f"Unexpected value for {key}: '{key_value}'. Expected {alternatives}"
)

return key_value_validator
Expand Down Expand Up @@ -1207,6 +1217,40 @@ def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name
)
)


dynamic_template_condition_action = vol.All(
vol.Schema(
{**CONDITION_BASE_SCHEMA, vol.Required(CONF_CONDITION): dynamic_template}
),
lambda config: {
**config,
CONF_VALUE_TEMPLATE: config[CONF_CONDITION],
CONF_CONDITION: "template",
},
)


CONDITION_ACTION_SCHEMA: vol.Schema = vol.Schema(
key_value_schemas(
CONF_CONDITION,
{
"and": AND_CONDITION_SCHEMA,
"device": DEVICE_CONDITION_SCHEMA,
"not": NOT_CONDITION_SCHEMA,
"numeric_state": NUMERIC_STATE_CONDITION_SCHEMA,
"or": OR_CONDITION_SCHEMA,
"state": STATE_CONDITION_SCHEMA,
"sun": SUN_CONDITION_SCHEMA,
"template": TEMPLATE_CONDITION_SCHEMA,
"time": TIME_CONDITION_SCHEMA,
"trigger": TRIGGER_CONDITION_SCHEMA,
"zone": ZONE_CONDITION_SCHEMA,
},
dynamic_template_condition_action,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the broken test this breaks the validation message.

"a valid template",
)
)

TRIGGER_BASE_SCHEMA = vol.Schema(
{vol.Required(CONF_PLATFORM): str, vol.Optional(CONF_ID): str}
)
Expand Down Expand Up @@ -1352,7 +1396,7 @@ def determine_script_action(action: dict[str, Any]) -> str:
SCRIPT_ACTION_DELAY: _SCRIPT_DELAY_SCHEMA,
SCRIPT_ACTION_WAIT_TEMPLATE: _SCRIPT_WAIT_TEMPLATE_SCHEMA,
SCRIPT_ACTION_FIRE_EVENT: EVENT_SCHEMA,
SCRIPT_ACTION_CHECK_CONDITION: CONDITION_SCHEMA,
SCRIPT_ACTION_CHECK_CONDITION: CONDITION_ACTION_SCHEMA,
SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA,
SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA,
SCRIPT_ACTION_REPEAT: _SCRIPT_REPEAT_SCHEMA,
Expand Down
39 changes: 39 additions & 0 deletions tests/helpers/test_config_validation.py
Expand Up @@ -1184,6 +1184,45 @@ def test_key_value_schemas():
schema({"mode": mode, "data": data})


def test_key_value_schemas_with_default():
"""Test key value schemas."""
schema = vol.Schema(
cv.key_value_schemas(
"mode",
{
"number": vol.Schema({"mode": "number", "data": int}),
"string": vol.Schema({"mode": "string", "data": str}),
},
vol.Schema({"mode": cv.dynamic_template}),
"a cool template",
)
)

with pytest.raises(vol.Invalid) as excinfo:
schema(True)
assert str(excinfo.value) == "Expected a dictionary"

for mode in None, {"a": "dict"}, "invalid":
with pytest.raises(vol.Invalid) as excinfo:
schema({"mode": mode})
assert (
str(excinfo.value)
== f"Unexpected value for mode: '{mode}'. Expected number, string, a cool template"
)

with pytest.raises(vol.Invalid) as excinfo:
schema({"mode": "number", "data": "string-value"})
assert str(excinfo.value) == "expected int for dictionary value @ data['data']"

with pytest.raises(vol.Invalid) as excinfo:
schema({"mode": "string", "data": 1})
assert str(excinfo.value) == "expected str for dictionary value @ data['data']"

for mode, data in (("number", 1), ("string", "hello")):
schema({"mode": mode, "data": data})
schema({"mode": "{{ 1 + 1}}"})


def test_script(caplog):
"""Test script validation is user friendly."""
for data, msg in (
Expand Down
55 changes: 55 additions & 0 deletions tests/helpers/test_script.py
Expand Up @@ -1501,6 +1501,61 @@ async def test_condition_basic(hass, caplog):
)


async def test_shorthand_template_condition(hass, caplog):
"""Test if we can use shorthand template conditions in a script."""
event = "test_event"
events = async_capture_events(hass, event)
alias = "condition step"
sequence = cv.SCRIPT_SCHEMA(
[
{"event": event},
{
"alias": alias,
"condition": "{{ states.test.entity.state == 'hello' }}",
},
{"event": event},
]
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")

hass.states.async_set("test.entity", "hello")
await script_obj.async_run(context=Context())
await hass.async_block_till_done()

assert f"Test condition {alias}: True" in caplog.text
caplog.clear()
assert len(events) == 2

assert_action_trace(
{
"0": [{"result": {"event": "test_event", "event_data": {}}}],
"1": [{"result": {"entities": ["test.entity"], "result": True}}],
"2": [{"result": {"event": "test_event", "event_data": {}}}],
}
)

hass.states.async_set("test.entity", "goodbye")

await script_obj.async_run(context=Context())
await hass.async_block_till_done()

assert f"Test condition {alias}: False" in caplog.text
assert len(events) == 3

assert_action_trace(
{
"0": [{"result": {"event": "test_event", "event_data": {}}}],
"1": [
{
"error_type": script._StopScript,
"result": {"entities": ["test.entity"], "result": False},
}
],
},
expected_script_execution="aborted",
)


async def test_condition_validation(hass, caplog):
"""Test if we can use conditions which validate late in a script."""
registry = er.async_get(hass)
Expand Down