Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add permissions check in service helper #18596

Merged
merged 6 commits into from Nov 21, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 16 additions & 11 deletions homeassistant/exceptions.py
Expand Up @@ -5,20 +5,14 @@
class HomeAssistantError(Exception):
"""General Home Assistant exception occurred."""

pass


class InvalidEntityFormatError(HomeAssistantError):
"""When an invalid formatted entity is encountered."""

pass


class NoEntitySpecifiedError(HomeAssistantError):
"""When no entity is specified."""

pass


class TemplateError(HomeAssistantError):
"""Error during template rendering."""
Expand All @@ -32,16 +26,27 @@ def __init__(self, exception: jinja2.TemplateError) -> None:
class PlatformNotReady(HomeAssistantError):
"""Error to indicate that platform is not ready."""

pass


class ConfigEntryNotReady(HomeAssistantError):
"""Error to indicate that config entry is not ready."""

pass


class InvalidStateError(HomeAssistantError):
"""When an invalid state is encountered."""

pass

class Unauthorized(HomeAssistantError):
"""When an action is unauthorized."""

def __init__(self, context=None, user_id=None, entity_id=None,
permission=None):
"""Unauthorized error."""
super().__init__(self.__class__.__name__)
self.context = context
self.user_id = user_id
self.entity_id = entity_id
self.permission = permission


class UnknownUser(Unauthorized):
"""When call is made with user ID that doesn't exist."""
69 changes: 61 additions & 8 deletions homeassistant/helpers/service.py
Expand Up @@ -5,9 +5,10 @@

import voluptuous as vol

from homeassistant.auth.permissions.const import POLICY_CONTROL
from homeassistant.const import ATTR_ENTITY_ID
import homeassistant.core as ha
from homeassistant.exceptions import TemplateError
from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser
from homeassistant.helpers import template
from homeassistant.loader import get_component, bind_hass
from homeassistant.util.yaml import load_yaml
Expand Down Expand Up @@ -187,23 +188,75 @@ async def entity_service_call(hass, platforms, func, call):

Calls all platforms simultaneously.
"""
tasks = []
all_entities = ATTR_ENTITY_ID not in call.data
if not all_entities:
if call.context.user_id:
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(context=call.context)
perms = user.permissions
else:
perms = None

# Are we trying to target all entities
target_all_entities = ATTR_ENTITY_ID not in call.data

if not target_all_entities:
# A set of entities we're trying to target.
entity_ids = set(
extract_entity_ids(hass, call, True))

# If the service function is a string, we'll pass it the service call data
if isinstance(func, str):
data = {key: val for key, val in call.data.items()
if key != ATTR_ENTITY_ID}
# If the service function is not a string, we pass the service call
else:
data = call

# Check the permissions

# A list with for each platform in platforms a list of entities to call
# the service on.
platforms_entities = []

if perms is None:
for platform in platforms:
if target_all_entities:
platforms_entities.append(list(platform.entities.values()))
else:
platforms_entities.append([
entity for entity in platform.entities.values()
if entity.entity_id in entity_ids
])

elif target_all_entities:
# If we target all entities, we will select all entities the user
# is allowed to control.
for platform in platforms:
platforms_entities.append([
entity for entity in platform.entities.values()
if perms.check_entity(entity.entity_id, POLICY_CONTROL)])

else:
for platform in platforms:
platform_entities = []
for entity in platform.entities.values():
if entity.entity_id not in entity_ids:
continue

if not perms.check_entity(entity.entity_id, POLICY_CONTROL):
raise Unauthorized(
context=call.context,
entity_id=entity.entity_id,
permission=POLICY_CONTROL
)

platform_entities.append(entity)

platforms_entities.append(platform_entities)

tasks = [
_handle_service_platform_call(func, data, [
entity for entity in platform.entities.values()
if all_entities or entity.entity_id in entity_ids
], call.context) for platform in platforms
_handle_service_platform_call(func, data, entities, call.context)
for platform, entities in zip(platforms, platforms_entities)
]

if tasks:
Expand Down
135 changes: 131 additions & 4 deletions tests/helpers/test_service.py
@@ -1,18 +1,49 @@
"""Test service helpers."""
import asyncio
from collections import OrderedDict
from copy import deepcopy
import unittest
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest

# To prevent circular import when running just this file
import homeassistant.components # noqa
from homeassistant import core as ha, loader
from homeassistant import core as ha, loader, exceptions
from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ENTITY_ID
from homeassistant.helpers import service, template
from homeassistant.setup import async_setup_component
import homeassistant.helpers.config_validation as cv

from tests.common import get_test_home_assistant, mock_service
from homeassistant.auth.permissions import PolicyPermissions

from tests.common import get_test_home_assistant, mock_service, mock_coro


@pytest.fixture

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expected 2 blank lines, found 1

def mock_service_platform_call():
"""Mock service platform call."""
with patch('homeassistant.helpers.service._handle_service_platform_call',
side_effect=lambda *args: mock_coro()) as mock_call:
yield mock_call


@pytest.fixture
def mock_entities():
"""Return mock entities in an ordered dict."""
kitchen = Mock(
entity_id='light.kitchen',
available=True,
should_poll=False,
)
living_room = Mock(
entity_id='light.living_room',
available=True,
should_poll=False,
)
entities = OrderedDict()
entities[kitchen.entity_id] = kitchen
entities[living_room.entity_id] = living_room
return entities


class TestServiceHelpers(unittest.TestCase):
Expand Down Expand Up @@ -179,3 +210,99 @@ def test_async_get_all_descriptions(hass):

assert 'description' in descriptions[logger.DOMAIN]['set_level']
assert 'fields' in descriptions[logger.DOMAIN]['set_level']


async def test_call_context_user_not_exist(hass):
"""Check we don't allow deleted users to do things."""
with pytest.raises(exceptions.UnknownUser) as err:
await service.entity_service_call(hass, [], Mock(), ha.ServiceCall(
'test_domain', 'test_service', context=ha.Context(
user_id='non-existing')))

assert err.value.context.user_id == 'non-existing'


async def test_call_context_target_all(hass, mock_service_platform_call,
mock_entities):
"""Check we only target allowed entities if targetting all."""
with patch('homeassistant.auth.AuthManager.async_get_user',
return_value=mock_coro(Mock(permissions=PolicyPermissions({
'entities': {
'entity_ids': {
'light.kitchen': True
}
}
})))):
await service.entity_service_call(hass, [
Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service',
context=ha.Context(user_id='mock-id')))

assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2]
assert entities == [mock_entities['light.kitchen']]


async def test_call_context_target_specific(hass, mock_service_platform_call,
mock_entities):
"""Check targeting specific entities."""
with patch('homeassistant.auth.AuthManager.async_get_user',
return_value=mock_coro(Mock(permissions=PolicyPermissions({
'entities': {
'entity_ids': {
'light.kitchen': True
}
}
})))):
await service.entity_service_call(hass, [
Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service', {
'entity_id': 'light.kitchen'
}, context=ha.Context(user_id='mock-id')))

assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2]
assert entities == [mock_entities['light.kitchen']]


async def test_call_context_target_specific_no_auth(
hass, mock_service_platform_call, mock_entities):
"""Check targeting specific entities without auth."""
with pytest.raises(exceptions.Unauthorized) as err:
with patch('homeassistant.auth.AuthManager.async_get_user',
return_value=mock_coro(Mock(
permissions=PolicyPermissions({})))):
await service.entity_service_call(hass, [
Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service', {
'entity_id': 'light.kitchen'
}, context=ha.Context(user_id='mock-id')))

assert err.value.context.user_id == 'mock-id'
assert err.value.entity_id == 'light.kitchen'


async def test_call_no_context_target_all(hass, mock_service_platform_call,
mock_entities):
"""Check we target all if no user context given."""
await service.entity_service_call(hass, [
Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service'))

assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2]
assert entities == list(mock_entities.values())


async def test_call_no_context_target_specific(
hass, mock_service_platform_call, mock_entities):
"""Check we can target specified entities."""
await service.entity_service_call(hass, [
Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service', {
'entity_id': ['light.kitchen', 'light.non-existing']
}))

assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2]
assert entities == [mock_entities['light.kitchen']]