Permalink
Find file Copy path
296 lines (232 sloc) 9.58 KB
"""Service calling related helpers."""
import asyncio
import logging
from os import path
import voluptuous as vol
from homeassistant.auth.permissions.const import POLICY_CONTROL
from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL
import homeassistant.core as ha
from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser
from homeassistant.helpers import template
from homeassistant.loader import get_component, bind_hass
from homeassistant.util.yaml import load_yaml
import homeassistant.helpers.config_validation as cv
from homeassistant.util.async_ import run_coroutine_threadsafe
_LOGGER = logging.getLogger(__name__)

# Keys understood in a service-call config (see async_call_from_config).
CONF_SERVICE = "service"
CONF_SERVICE_TEMPLATE = "service_template"
CONF_SERVICE_ENTITY_ID = "entity_id"
CONF_SERVICE_DATA = "data"
CONF_SERVICE_DATA_TEMPLATE = "data_template"

# hass.data key under which parsed services.yaml descriptions are cached.
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
@bind_hass
def call_from_config(hass, config, blocking=False, variables=None,
                     validate_config=True):
    """Call a service based on a config hash, blocking until it is done.

    Synchronous wrapper: schedules ``async_call_from_config`` on
    ``hass.loop`` via ``run_coroutine_threadsafe`` and waits for its result.
    NOTE(review): because it blocks on ``.result()``, this presumably must
    not be called from the event loop thread itself — confirm.
    """
    coro = async_call_from_config(
        hass, config,
        blocking=blocking,
        variables=variables,
        validate_config=validate_config,
    )
    future = run_coroutine_threadsafe(coro, hass.loop)
    future.result()
@bind_hass
async def async_call_from_config(hass, config, blocking=False, variables=None,
                                 validate_config=True, context=None):
    """Call a service based on a config hash.

    The service name comes from ``service`` or, failing that, from the
    rendered ``service_template``. The payload is built in this order:
    static ``data``, then the rendered ``data_template`` on top, then
    ``entity_id`` (if given) last. The result is sent to
    ``hass.services.async_call``.

    Error handling:
    - invalid config (when ``validate_config``) is logged and the call is
      abandoned;
    - a service-name template that fails to render, or renders an invalid
      service, is re-raised when ``blocking`` is true, otherwise logged and
      the call abandoned;
    - a data template that fails to render is always logged and the call
      abandoned.
    """
    if validate_config:
        try:
            config = cv.SERVICE_SCHEMA(config)
        except vol.Invalid as ex:
            _LOGGER.error("Invalid config for calling service: %s", ex)
            return

    if CONF_SERVICE not in config:
        try:
            service_template = config[CONF_SERVICE_TEMPLATE]
            service_template.hass = hass
            domain_service = service_template.async_render(variables)
            # Validate that the rendered text looks like "domain.service".
            domain_service = cv.service(domain_service)
        except TemplateError as ex:
            if blocking:
                raise
            _LOGGER.error('Error rendering service name template: %s', ex)
            return
        except vol.Invalid:
            if blocking:
                raise
            # domain_service still holds the raw rendered (invalid) text here.
            _LOGGER.error('Template rendered invalid service: %s',
                          domain_service)
            return
    else:
        domain_service = config[CONF_SERVICE]

    domain, service_name = domain_service.split('.', 1)

    service_data = dict(config.get(CONF_SERVICE_DATA, {}))

    if CONF_SERVICE_DATA_TEMPLATE in config:
        data_template = config[CONF_SERVICE_DATA_TEMPLATE]
        try:
            template.attach(hass, data_template)
            rendered = template.render_complex(data_template, variables)
            service_data.update(rendered)
        except TemplateError as ex:
            _LOGGER.error('Error rendering data template: %s', ex)
            return

    if CONF_SERVICE_ENTITY_ID in config:
        service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]

    await hass.services.async_call(
        domain, service_name, service_data,
        blocking=blocking, context=context)
@bind_hass
def extract_entity_ids(hass, service_call, expand_group=True):
    """Extract a list of entity ids from a service call.

    Will convert group entity ids to the entity ids it represents.

    Async friendly.

    :param hass: Home Assistant instance. Its group component is only
        looked up when ``expand_group`` is true.
    :param service_call: call whose ``data[ATTR_ENTITY_ID]`` may be a single
        entity id string or a list of them.
    :param expand_group: expand group entity ids into their member ids.
    :returns: a new list of entity ids. It is empty when the call has no
        data or no ``entity_id`` key.
    """
    if not (service_call.data and ATTR_ENTITY_ID in service_call.data):
        return []

    # Entity ID attr can be a list or a string; normalise to a list.
    service_ent_id = service_call.data[ATTR_ENTITY_ID]
    if isinstance(service_ent_id, str):
        service_ent_id = [service_ent_id]

    if not expand_group:
        # Copy so callers cannot mutate the list stored in the call data.
        return list(service_ent_id)

    # Only look up the group component when expansion is requested. The
    # access goes through hass.components (presumably importing the
    # component on first use), which is wasted work otherwise.
    group = hass.components.group
    return list(group.expand_entity_ids(service_ent_id))
@bind_hass
async def async_get_all_descriptions(hass):
    """Return descriptions (i.e. user documentation) for all service calls.

    Descriptions are read from the ``services.yaml`` that sits next to each
    domain's component module. They are cached in
    ``hass.data[SERVICE_DESCRIPTION_CACHE]`` under ``"<domain>.<service>"``
    keys, so each file is only parsed while some of its services are
    uncached.

    :returns: ``{domain: {service: {'description': str, 'fields': dict}}}``
        for every currently registered service. Services without a
        services.yaml entry get an empty description and no fields.
    """
    description_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
    format_cache_key = '{}.{}'.format

    def domain_yaml_file(domain):
        """Return the services.yaml location for a domain."""
        if domain == ha.DOMAIN:
            from homeassistant import components
            component_path = path.dirname(components.__file__)
        else:
            component_path = path.dirname(get_component(hass, domain).__file__)
        return path.join(component_path, 'services.yaml')

    def load_services_files(yaml_files):
        """Load and parse services.yaml files; a missing file counts as empty."""
        loaded = {}
        for yaml_file in yaml_files:
            try:
                loaded[yaml_file] = load_yaml(yaml_file)
            except FileNotFoundError:
                loaded[yaml_file] = {}
        return loaded

    services = hass.services.async_services()

    # Collect the yaml files of every domain with at least one uncached
    # service; one uncached service is enough to need the domain's file.
    missing = set()
    for domain in services:
        for service in services[domain]:
            if format_cache_key(domain, service) not in description_cache:
                missing.add(domain_yaml_file(domain))
                break

    loaded = {}
    if missing:
        # File parsing is blocking I/O, so it is handed to hass.async_add_job
        # (presumably run in the executor).
        loaded = await hass.async_add_job(load_services_files, missing)

    # Any domain whose module lives directly in the components package
    # resolves to the same file as ha.DOMAIN. That shared file is sectioned
    # by domain.
    catch_all_yaml_file = domain_yaml_file(ha.DOMAIN)
    descriptions = {}
    for domain in services:
        descriptions[domain] = {}
        yaml_file = domain_yaml_file(domain)
        for service in services[domain]:
            cache_key = format_cache_key(domain, service)
            description = description_cache.get(cache_key)

            if description is None:
                # Bare keys with no body (e.g. "my_service:") parse as None;
                # treat them as "no documentation" instead of crashing.
                yaml_services = loaded[yaml_file] or {}
                if yaml_file == catch_all_yaml_file:
                    yaml_services = yaml_services.get(domain) or {}
                yaml_description = yaml_services.get(service) or {}

                description = description_cache[cache_key] = {
                    'description': yaml_description.get('description') or '',
                    'fields': yaml_description.get('fields') or {},
                }

            descriptions[domain][service] = description

    return descriptions
@bind_hass
async def entity_service_call(hass, platforms, func, call):
    """Handle an entity service call.

    Calls all platforms simultaneously.

    :param hass: Home Assistant instance.
    :param platforms: platforms whose ``entities`` mapping is searched for
        target entities.
    :param func: either the name of an entity coroutine method, called with
        the call data (minus ``entity_id``) as keyword arguments, or a
        coroutine function called as ``func(entity, call)``.
    :param call: the service call being handled.
    :raises UnknownUser: the call's context names a user that does not
        exist.
    :raises Unauthorized: the user explicitly targeted an entity they are
        not allowed to control.

    Any exception raised while handling a platform's entities is re-raised
    after every platform has finished.
    """
    if call.context.user_id:
        user = await hass.auth.async_get_user(call.context.user_id)
        if user is None:
            raise UnknownUser(context=call.context)
        entity_perms = user.permissions.check_entity
    else:
        entity_perms = None

    # Are we trying to target all entities
    if ATTR_ENTITY_ID in call.data:
        target_all_entities = call.data[ATTR_ENTITY_ID] == ENTITY_MATCH_ALL
    else:
        _LOGGER.warning(
            'Not passing an entity ID to a service to target all entities is '
            'deprecated. Use instead: entity_id: "%s"', ENTITY_MATCH_ALL)
        target_all_entities = True

    if not target_all_entities:
        # A set of entities we're trying to target.
        entity_ids = set(
            extract_entity_ids(hass, call, True))

    # If the service function is a string, we'll pass it the service call data
    if isinstance(func, str):
        data = {key: val for key, val in call.data.items()
                if key != ATTR_ENTITY_ID}
    # If the service function is not a string, we pass the service call
    else:
        data = call

    # Check the permissions
    # A list with for each platform in platforms a list of entities to call
    # the service on.
    platforms_entities = []

    if entity_perms is None:
        for platform in platforms:
            if target_all_entities:
                platforms_entities.append(list(platform.entities.values()))
            else:
                platforms_entities.append([
                    entity for entity in platform.entities.values()
                    if entity.entity_id in entity_ids
                ])

    elif target_all_entities:
        # If we target all entities, we will select all entities the user
        # is allowed to control.
        for platform in platforms:
            platforms_entities.append([
                entity for entity in platform.entities.values()
                if entity_perms(entity.entity_id, POLICY_CONTROL)])

    else:
        # Explicitly targeted entities must all be controllable; one denial
        # aborts the whole call before anything runs.
        for platform in platforms:
            platform_entities = []
            for entity in platform.entities.values():
                if entity.entity_id not in entity_ids:
                    continue

                if not entity_perms(entity.entity_id, POLICY_CONTROL):
                    raise Unauthorized(
                        context=call.context,
                        entity_id=entity.entity_id,
                        permission=POLICY_CONTROL
                    )

                platform_entities.append(entity)

            platforms_entities.append(platform_entities)

    # Wrap each platform's handler in a Task: asyncio.wait() deprecates bare
    # coroutines since Python 3.8 and rejects them from 3.11. Platforms with
    # no selected entities would be a no-op, so they are skipped.
    tasks = [
        asyncio.ensure_future(_handle_service_platform_call(
            func, data, entities, call.context))
        for entities in platforms_entities
        if entities
    ]

    if not tasks:
        return

    await asyncio.wait(tasks)

    # asyncio.wait() only waits; fetch each result so handler errors reach
    # the caller instead of being dropped as "never retrieved".
    for task in tasks:
        task.result()
async def _handle_service_platform_call(func, data, entities, context):
"""Handle a function call."""
tasks = []
for entity in entities:
if not entity.available:
continue
entity.async_set_context(context)
if isinstance(func, str):
await getattr(entity, func)(**data)
else:
await func(entity, data)
if entity.should_poll:
tasks.append(entity.async_update_ha_state(True))
if tasks:
await asyncio.wait(tasks)