Skip to content

Commit

Permalink
Add target to service call API (#45898)
Browse files Browse the repository at this point in the history
* Add target to service call API

* Fix _async_call_service_step

* CONF_SERVICE_ENTITY_ID overrules target

* Move merging up before processing schema

* Restore services.yaml

* Add test
  • Loading branch information
bramkragten committed Feb 10, 2021
1 parent 7d2d98f commit 4b493c5
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 16 deletions.
2 changes: 1 addition & 1 deletion homeassistant/components/api/__init__.py
Expand Up @@ -378,7 +378,7 @@ async def post(self, request, domain, service):
with AsyncTrackStates(hass) as changed_states:
try:
await hass.services.async_call(
domain, service, data, True, self.context(request)
domain, service, data, blocking=True, context=self.context(request)
)
except (vol.Invalid, ServiceNotFound) as ex:
raise HTTPBadRequest() from ex
Expand Down
2 changes: 2 additions & 0 deletions homeassistant/components/websocket_api/commands.py
Expand Up @@ -121,6 +121,7 @@ def handle_unsubscribe_events(hass, connection, msg):
vol.Required("type"): "call_service",
vol.Required("domain"): str,
vol.Required("service"): str,
vol.Optional("target"): cv.ENTITY_SERVICE_FIELDS,
vol.Optional("service_data"): dict,
}
)
Expand All @@ -139,6 +140,7 @@ async def handle_call_service(hass, connection, msg):
msg.get("service_data"),
blocking,
context,
target=msg.get("target"),
)
connection.send_message(
messages.result_message(msg["id"], {"context": context})
Expand Down
9 changes: 8 additions & 1 deletion homeassistant/core.py
Expand Up @@ -1358,14 +1358,17 @@ def call(
blocking: bool = False,
context: Optional[Context] = None,
limit: Optional[float] = SERVICE_CALL_LIMIT,
target: Optional[Dict] = None,
) -> Optional[bool]:
"""
Call a service.
See description of async_call for details.
"""
return asyncio.run_coroutine_threadsafe(
self.async_call(domain, service, service_data, blocking, context, limit),
self.async_call(
domain, service, service_data, blocking, context, limit, target
),
self._hass.loop,
).result()

Expand All @@ -1377,6 +1380,7 @@ async def async_call(
blocking: bool = False,
context: Optional[Context] = None,
limit: Optional[float] = SERVICE_CALL_LIMIT,
target: Optional[Dict] = None,
) -> Optional[bool]:
"""
Call a service.
Expand Down Expand Up @@ -1404,6 +1408,9 @@ async def async_call(
except KeyError:
raise ServiceNotFound(domain, service) from None

if target:
service_data.update(target)

if handler.schema:
try:
processed_data = handler.schema(service_data)
Expand Down
12 changes: 5 additions & 7 deletions homeassistant/helpers/script.py
Expand Up @@ -433,14 +433,14 @@ async def _async_call_service_step(self):
self._script.last_action = self._action.get(CONF_ALIAS, "call service")
self._log("Executing step %s", self._script.last_action)

domain, service_name, service_data = service.async_prepare_call_from_config(
params = service.async_prepare_call_from_config(
self._hass, self._action, self._variables
)

running_script = (
domain == "automation"
and service_name == "trigger"
or domain in ("python_script", "script")
params["domain"] == "automation"
and params["service_name"] == "trigger"
or params["domain"] in ("python_script", "script")
)
# If this might start a script then disable the call timeout.
# Otherwise use the normal service call limit.
Expand All @@ -451,9 +451,7 @@ async def _async_call_service_step(self):

service_task = self._hass.async_create_task(
self._hass.services.async_call(
domain,
service_name,
service_data,
**params,
blocking=True,
context=self._context,
limit=limit,
Expand Down
32 changes: 25 additions & 7 deletions homeassistant/helpers/service.py
Expand Up @@ -14,6 +14,7 @@
Optional,
Set,
Tuple,
TypedDict,
Union,
cast,
)
Expand Down Expand Up @@ -70,6 +71,15 @@
SERVICE_DESCRIPTION_CACHE = "service_description_cache"


class ServiceParams(TypedDict):
"""Type for service call parameters."""

domain: str
service: str
service_data: Dict[str, Any]
target: Optional[Dict]


@dataclasses.dataclass
class SelectedEntities:
"""Class to hold the selected entities."""
Expand Down Expand Up @@ -136,7 +146,7 @@ async def async_call_from_config(
raise
_LOGGER.error(ex)
else:
await hass.services.async_call(*params, blocking, context)
await hass.services.async_call(**params, blocking=blocking, context=context)


@ha.callback
Expand All @@ -146,7 +156,7 @@ def async_prepare_call_from_config(
config: ConfigType,
variables: TemplateVarsType = None,
validate_config: bool = False,
) -> Tuple[str, str, Dict[str, Any]]:
) -> ServiceParams:
"""Prepare to call a service based on a config hash."""
if validate_config:
try:
Expand Down Expand Up @@ -177,10 +187,9 @@ def async_prepare_call_from_config(

domain, service = domain_service.split(".", 1)

service_data = {}
target = config.get(CONF_TARGET)

if CONF_TARGET in config:
service_data.update(config[CONF_TARGET])
service_data = {}

for conf in [CONF_SERVICE_DATA, CONF_SERVICE_DATA_TEMPLATE]:
if conf not in config:
Expand All @@ -192,9 +201,17 @@ def async_prepare_call_from_config(
raise HomeAssistantError(f"Error rendering data template: {ex}") from ex

if CONF_SERVICE_ENTITY_ID in config:
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
if target:
target[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
else:
target = {ATTR_ENTITY_ID: config[CONF_SERVICE_ENTITY_ID]}

return domain, service, service_data
return {
"domain": domain,
"service": service,
"service_data": service_data,
"target": target,
}


@bind_hass
Expand Down Expand Up @@ -431,6 +448,7 @@ async def async_get_all_descriptions(

description = descriptions_cache[cache_key] = {
"description": yaml_description.get("description", ""),
"target": yaml_description.get("target"),
"fields": yaml_description.get("fields", {}),
}

Expand Down
41 changes: 41 additions & 0 deletions tests/components/websocket_api/test_commands.py
Expand Up @@ -52,6 +52,47 @@ def service_call(call):
assert call.data == {"hello": "world"}


async def test_call_service_target(hass, websocket_client):
"""Test call service command with target."""
calls = []

@callback
def service_call(call):
calls.append(call)

hass.services.async_register("domain_test", "test_service", service_call)

await websocket_client.send_json(
{
"id": 5,
"type": "call_service",
"domain": "domain_test",
"service": "test_service",
"service_data": {"hello": "world"},
"target": {
"entity_id": ["entity.one", "entity.two"],
"device_id": "deviceid",
},
}
)

msg = await websocket_client.receive_json()
assert msg["id"] == 5
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]

assert len(calls) == 1
call = calls[0]

assert call.domain == "domain_test"
assert call.service == "test_service"
assert call.data == {
"hello": "world",
"entity_id": ["entity.one", "entity.two"],
"device_id": ["deviceid"],
}


async def test_call_service_not_found(hass, websocket_client):
"""Test call service command."""
await websocket_client.send_json(
Expand Down

0 comments on commit 4b493c5

Please sign in to comment.