Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add target to service call API #45898

Merged
merged 6 commits into from Feb 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion homeassistant/components/api/__init__.py
Expand Up @@ -377,7 +377,7 @@ async def post(self, request, domain, service):
with AsyncTrackStates(hass) as changed_states:
try:
await hass.services.async_call(
domain, service, data, True, self.context(request)
domain, service, data, blocking=True, context=self.context(request)
)
except (vol.Invalid, ServiceNotFound) as ex:
raise HTTPBadRequest() from ex
Expand Down
2 changes: 2 additions & 0 deletions homeassistant/components/websocket_api/commands.py
Expand Up @@ -121,6 +121,7 @@ def handle_unsubscribe_events(hass, connection, msg):
vol.Required("type"): "call_service",
vol.Required("domain"): str,
vol.Required("service"): str,
vol.Optional("target"): cv.ENTITY_SERVICE_FIELDS,
vol.Optional("service_data"): dict,
}
)
Expand All @@ -139,6 +140,7 @@ async def handle_call_service(hass, connection, msg):
msg.get("service_data"),
blocking,
context,
target=msg.get("target"),
)
connection.send_message(
messages.result_message(msg["id"], {"context": context})
Expand Down
9 changes: 8 additions & 1 deletion homeassistant/core.py
Expand Up @@ -1358,14 +1358,17 @@ def call(
blocking: bool = False,
context: Optional[Context] = None,
limit: Optional[float] = SERVICE_CALL_LIMIT,
target: Optional[Dict] = None,
) -> Optional[bool]:
"""
Call a service.

See description of async_call for details.
"""
return asyncio.run_coroutine_threadsafe(
self.async_call(domain, service, service_data, blocking, context, limit),
self.async_call(
domain, service, service_data, blocking, context, limit, target
),
self._hass.loop,
).result()

Expand All @@ -1377,6 +1380,7 @@ async def async_call(
blocking: bool = False,
context: Optional[Context] = None,
limit: Optional[float] = SERVICE_CALL_LIMIT,
target: Optional[Dict] = None,
) -> Optional[bool]:
"""
Call a service.
Expand Down Expand Up @@ -1404,6 +1408,9 @@ async def async_call(
except KeyError:
raise ServiceNotFound(domain, service) from None

if target:
service_data.update(target)

if handler.schema:
try:
processed_data = handler.schema(service_data)
Expand Down
12 changes: 5 additions & 7 deletions homeassistant/helpers/script.py
Expand Up @@ -433,14 +433,14 @@ async def _async_call_service_step(self):
self._script.last_action = self._action.get(CONF_ALIAS, "call service")
self._log("Executing step %s", self._script.last_action)

domain, service_name, service_data = service.async_prepare_call_from_config(
params = service.async_prepare_call_from_config(
self._hass, self._action, self._variables
)

running_script = (
domain == "automation"
and service_name == "trigger"
or domain in ("python_script", "script")
params["domain"] == "automation"
and params["service_name"] == "trigger"
or params["domain"] in ("python_script", "script")
)
# If this might start a script then disable the call timeout.
# Otherwise use the normal service call limit.
Expand All @@ -451,9 +451,7 @@ async def _async_call_service_step(self):

service_task = self._hass.async_create_task(
self._hass.services.async_call(
domain,
service_name,
service_data,
**params,
blocking=True,
context=self._context,
limit=limit,
Expand Down
32 changes: 25 additions & 7 deletions homeassistant/helpers/service.py
Expand Up @@ -14,6 +14,7 @@
Optional,
Set,
Tuple,
TypedDict,
Union,
cast,
)
Expand Down Expand Up @@ -70,6 +71,15 @@
SERVICE_DESCRIPTION_CACHE = "service_description_cache"


class ServiceParams(TypedDict):
"""Type for service call parameters."""

domain: str
service: str
service_data: Dict[str, Any]
target: Optional[Dict]


@dataclasses.dataclass
class SelectedEntities:
"""Class to hold the selected entities."""
Expand Down Expand Up @@ -136,7 +146,7 @@ async def async_call_from_config(
raise
_LOGGER.error(ex)
else:
await hass.services.async_call(*params, blocking, context)
await hass.services.async_call(**params, blocking=blocking, context=context)


@ha.callback
Expand All @@ -146,7 +156,7 @@ def async_prepare_call_from_config(
config: ConfigType,
variables: TemplateVarsType = None,
validate_config: bool = False,
) -> Tuple[str, str, Dict[str, Any]]:
) -> ServiceParams:
"""Prepare to call a service based on a config hash."""
if validate_config:
try:
Expand Down Expand Up @@ -177,10 +187,9 @@ def async_prepare_call_from_config(

domain, service = domain_service.split(".", 1)

service_data = {}
target = config.get(CONF_TARGET)

if CONF_TARGET in config:
service_data.update(config[CONF_TARGET])
service_data = {}

for conf in [CONF_SERVICE_DATA, CONF_SERVICE_DATA_TEMPLATE]:
if conf not in config:
Expand All @@ -192,9 +201,17 @@ def async_prepare_call_from_config(
raise HomeAssistantError(f"Error rendering data template: {ex}") from ex

if CONF_SERVICE_ENTITY_ID in config:
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
if target:
target[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
else:
target = {ATTR_ENTITY_ID: config[CONF_SERVICE_ENTITY_ID]}

return domain, service, service_data
return {
"domain": domain,
"service": service,
"service_data": service_data,
"target": target,
}


@bind_hass
Expand Down Expand Up @@ -431,6 +448,7 @@ async def async_get_all_descriptions(

description = descriptions_cache[cache_key] = {
"description": yaml_description.get("description", ""),
"target": yaml_description.get("target"),
"fields": yaml_description.get("fields", {}),
}

Expand Down
41 changes: 41 additions & 0 deletions tests/components/websocket_api/test_commands.py
Expand Up @@ -52,6 +52,47 @@ def service_call(call):
assert call.data == {"hello": "world"}


async def test_call_service_target(hass, websocket_client):
"""Test call service command with target."""
calls = []

@callback
def service_call(call):
calls.append(call)

hass.services.async_register("domain_test", "test_service", service_call)

await websocket_client.send_json(
{
"id": 5,
"type": "call_service",
"domain": "domain_test",
"service": "test_service",
"service_data": {"hello": "world"},
"target": {
"entity_id": ["entity.one", "entity.two"],
"device_id": "deviceid",
},
}
)

msg = await websocket_client.receive_json()
assert msg["id"] == 5
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]

assert len(calls) == 1
call = calls[0]

assert call.domain == "domain_test"
assert call.service == "test_service"
assert call.data == {
"hello": "world",
"entity_id": ["entity.one", "entity.two"],
"device_id": ["deviceid"],
}


async def test_call_service_not_found(hass, websocket_client):
"""Test call service command."""
await websocket_client.send_json(
Expand Down