diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index 0011c98ce730..e69dec37df28 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -118,6 +118,10 @@ async def async_get_user(self, user_id: str) -> Optional[models.User]: """Retrieve a user.""" return await self._store.async_get_user(user_id) + async def async_get_group(self, group_id: str) -> Optional[models.Group]: + """Retrieve all groups.""" + return await self._store.async_get_group(group_id) + async def async_get_user_by_credentials( self, credentials: models.Credentials) -> Optional[models.User]: """Get a user by credential, return None if not found.""" diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index ab233489db0e..867d5357a583 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -45,6 +45,14 @@ async def async_get_groups(self) -> List[models.Group]: return list(self._groups.values()) + async def async_get_group(self, group_id: str) -> Optional[models.Group]: + """Retrieve all users.""" + if self._groups is None: + await self._async_load() + assert self._groups is not None + + return self._groups.get(group_id) + async def async_get_users(self) -> List[models.User]: """Retrieve all users.""" if self._users is None: diff --git a/homeassistant/exceptions.py b/homeassistant/exceptions.py index 11aa1848529c..0613b7cb10c5 100644 --- a/homeassistant/exceptions.py +++ b/homeassistant/exceptions.py @@ -1,24 +1,24 @@ """The exceptions used by Home Assistant.""" +from typing import Optional, Tuple, TYPE_CHECKING import jinja2 +# pylint: disable=using-constant-test +if TYPE_CHECKING: + # pylint: disable=unused-import + from .core import Context # noqa + class HomeAssistantError(Exception): """General Home Assistant exception occurred.""" - pass - class InvalidEntityFormatError(HomeAssistantError): """When an invalid formatted entity is encountered.""" - pass - class NoEntitySpecifiedError(HomeAssistantError): """When no entity is specified.""" - pass - class TemplateError(HomeAssistantError): """Error during template rendering.""" @@ -32,16 +32,29 @@ def __init__(self, exception: jinja2.TemplateError) -> None: class PlatformNotReady(HomeAssistantError): """Error to indicate that platform is not ready.""" - pass - class ConfigEntryNotReady(HomeAssistantError): """Error to indicate that config entry is not ready.""" - pass - class InvalidStateError(HomeAssistantError): """When an invalid state is encountered.""" - pass + +class Unauthorized(HomeAssistantError): + """When an action is unauthorized.""" + + def __init__(self, context: Optional['Context'] = None, + user_id: Optional[str] = None, + entity_id: Optional[str] = None, + permission: Optional[Tuple[str]] = None) -> None: + """Unauthorized error.""" + super().__init__(self.__class__.__name__) + self.context = context + self.user_id = user_id + self.entity_id = entity_id + self.permission = permission + + +class UnknownUser(Unauthorized): + """When call is made with user ID that doesn't exist.""" diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 0f394a6f153f..5e0d9c7e88af 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -5,9 +5,10 @@ import voluptuous as vol +from homeassistant.auth.permissions.const import POLICY_CONTROL from homeassistant.const import ATTR_ENTITY_ID import homeassistant.core as ha -from homeassistant.exceptions import TemplateError +from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser from homeassistant.helpers import template from homeassistant.loader import get_component, bind_hass from homeassistant.util.yaml import load_yaml @@ -187,23 +188,75 @@ async def entity_service_call(hass, platforms, func, call): Calls all platforms simultaneously. """ - tasks = [] - all_entities = ATTR_ENTITY_ID not in call.data - if not all_entities: + if call.context.user_id: + user = await hass.auth.async_get_user(call.context.user_id) + if user is None: + raise UnknownUser(context=call.context) + perms = user.permissions + else: + perms = None + + # Are we trying to target all entities + target_all_entities = ATTR_ENTITY_ID not in call.data + + if not target_all_entities: + # A set of entities we're trying to target. entity_ids = set( extract_entity_ids(hass, call, True)) + # If the service function is a string, we'll pass it the service call data if isinstance(func, str): data = {key: val for key, val in call.data.items() if key != ATTR_ENTITY_ID} + # If the service function is not a string, we pass the service call else: data = call + # Check the permissions + + # A list with for each platform in platforms a list of entities to call + # the service on. + platforms_entities = [] + + if perms is None: + for platform in platforms: + if target_all_entities: + platforms_entities.append(list(platform.entities.values())) + else: + platforms_entities.append([ + entity for entity in platform.entities.values() + if entity.entity_id in entity_ids + ]) + + elif target_all_entities: + # If we target all entities, we will select all entities the user + # is allowed to control. + for platform in platforms: + platforms_entities.append([ + entity for entity in platform.entities.values() + if perms.check_entity(entity.entity_id, POLICY_CONTROL)]) + + else: + for platform in platforms: + platform_entities = [] + for entity in platform.entities.values(): + if entity.entity_id not in entity_ids: + continue + + if not perms.check_entity(entity.entity_id, POLICY_CONTROL): + raise Unauthorized( + context=call.context, + entity_id=entity.entity_id, + permission=POLICY_CONTROL + ) + + platform_entities.append(entity) + + platforms_entities.append(platform_entities) + tasks = [ - _handle_service_platform_call(func, data, [ - entity for entity in platform.entities.values() - if all_entities or entity.entity_id in entity_ids - ], call.context) for platform in platforms + _handle_service_platform_call(func, data, entities, call.context) + for platform, entities in zip(platforms, platforms_entities) ] if tasks: diff --git a/tests/components/conftest.py b/tests/components/conftest.py index 252d0b1d872f..2568a1092448 100644 --- a/tests/components/conftest.py +++ b/tests/components/conftest.py @@ -3,6 +3,7 @@ import pytest +from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY from homeassistant.setup import async_setup_component from homeassistant.components.websocket_api.http import URL from homeassistant.components.websocket_api.auth import ( @@ -77,3 +78,19 @@ def hass_access_token(hass): refresh_token = hass.loop.run_until_complete( hass.auth.async_create_refresh_token(user, CLIENT_ID)) yield hass.auth.async_create_access_token(refresh_token) + + +@pytest.fixture +def hass_admin_user(hass): + """Return a Home Assistant admin user.""" + admin_group = hass.loop.run_until_complete(hass.auth.async_get_group( + GROUP_ID_ADMIN)) + return MockUser(groups=[admin_group]).add_to_hass(hass) + + +@pytest.fixture +def hass_read_only_user(hass): + """Return a Home Assistant read only user.""" + read_only_group = hass.loop.run_until_complete(hass.auth.async_get_group( + GROUP_ID_READ_ONLY)) + return MockUser(groups=[read_only_group]).add_to_hass(hass) diff --git a/tests/components/counter/test_init.py b/tests/components/counter/test_init.py index 78ca72dd1e4a..c8411bf2fdec 100644 --- a/tests/components/counter/test_init.py +++ b/tests/components/counter/test_init.py @@ -234,7 +234,7 @@ def test_no_initial_state_and_no_restore_state(hass): assert int(state.state) == 0 -async def test_counter_context(hass): +async def test_counter_context(hass, hass_admin_user): """Test that counter context works.""" assert await async_setup_component(hass, 'counter', { 'counter': { @@ -247,9 +247,9 @@ async def test_counter_context(hass): await hass.services.async_call('counter', 'increment', { 'entity_id': state.entity_id, - }, True, Context(user_id='abcd')) + }, True, Context(user_id=hass_admin_user.id)) state2 = hass.states.get('counter.test') assert state2 is not None assert state.state != state2.state - assert state2.context.user_id == 'abcd' + assert state2.context.user_id == hass_admin_user.id diff --git a/tests/components/light/test_init.py b/tests/components/light/test_init.py index a04fb853996b..09474a5ad064 100644 --- a/tests/components/light/test_init.py +++ b/tests/components/light/test_init.py @@ -476,7 +476,7 @@ async def test_intent_set_color_and_brightness(hass): assert call.data.get(light.ATTR_BRIGHTNESS_PCT) == 20 -async def test_light_context(hass): +async def test_light_context(hass, hass_admin_user): """Test that light context works.""" assert await async_setup_component(hass, 'light', { 'light': { @@ -489,9 +489,9 @@ async def test_light_context(hass): await hass.services.async_call('light', 'toggle', { 'entity_id': state.entity_id, - }, True, core.Context(user_id='abcd')) + }, True, core.Context(user_id=hass_admin_user.id)) state2 = hass.states.get('light.ceiling') assert state2 is not None assert state.state != state2.state - assert state2.context.user_id == 'abcd' + assert state2.context.user_id == hass_admin_user.id diff --git a/tests/components/switch/test_init.py b/tests/components/switch/test_init.py index 1a51457df962..d39c5a24ddc5 100644 --- a/tests/components/switch/test_init.py +++ b/tests/components/switch/test_init.py @@ -91,7 +91,7 @@ def test_setup_two_platforms(self): ) -async def test_switch_context(hass): +async def test_switch_context(hass, hass_admin_user): """Test that switch context works.""" assert await async_setup_component(hass, 'switch', { 'switch': { @@ -104,9 +104,9 @@ async def test_switch_context(hass): await hass.services.async_call('switch', 'toggle', { 'entity_id': state.entity_id, - }, True, core.Context(user_id='abcd')) + }, True, core.Context(user_id=hass_admin_user.id)) state2 = hass.states.get('switch.ac') assert state2 is not None assert state.state != state2.state - assert state2.context.user_id == 'abcd' + assert state2.context.user_id == hass_admin_user.id diff --git a/tests/components/test_input_boolean.py b/tests/components/test_input_boolean.py index a77e0a8c0102..019318c2693f 100644 --- a/tests/components/test_input_boolean.py +++ b/tests/components/test_input_boolean.py @@ -147,7 +147,7 @@ def test_initial_state_overrules_restore_state(hass): assert state.state == 'on' -async def test_input_boolean_context(hass): +async def test_input_boolean_context(hass, hass_admin_user): """Test that input_boolean context works.""" assert await async_setup_component(hass, 'input_boolean', { 'input_boolean': { @@ -160,9 +160,9 @@ async def test_input_boolean_context(hass): await hass.services.async_call('input_boolean', 'turn_off', { 'entity_id': state.entity_id, - }, True, Context(user_id='abcd')) + }, True, Context(user_id=hass_admin_user.id)) state2 = hass.states.get('input_boolean.ac') assert state2 is not None assert state.state != state2.state - assert state2.context.user_id == 'abcd' + assert state2.context.user_id == hass_admin_user.id diff --git a/tests/components/test_input_datetime.py b/tests/components/test_input_datetime.py index 9649531a8a1d..a61cefe34f2f 100644 --- a/tests/components/test_input_datetime.py +++ b/tests/components/test_input_datetime.py @@ -195,7 +195,7 @@ def test_restore_state(hass): assert state_bogus.state == str(initial) -async def test_input_datetime_context(hass): +async def test_input_datetime_context(hass, hass_admin_user): """Test that input_datetime context works.""" assert await async_setup_component(hass, 'input_datetime', { 'input_datetime': { @@ -211,9 +211,9 @@ async def test_input_datetime_context(hass): await hass.services.async_call('input_datetime', 'set_datetime', { 'entity_id': state.entity_id, 'date': '2018-01-02' - }, True, Context(user_id='abcd')) + }, True, Context(user_id=hass_admin_user.id)) state2 = hass.states.get('input_datetime.only_date') assert state2 is not None assert state.state != state2.state - assert state2.context.user_id == 'abcd' + assert state2.context.user_id == hass_admin_user.id diff --git a/tests/components/test_input_number.py b/tests/components/test_input_number.py index 354c67b4d1b6..70dfeec2e7fb 100644 --- a/tests/components/test_input_number.py +++ b/tests/components/test_input_number.py @@ -266,7 +266,7 @@ def test_no_initial_state_and_no_restore_state(hass): assert float(state.state) == 0 -async def test_input_number_context(hass): +async def test_input_number_context(hass, hass_admin_user): """Test that input_number context works.""" assert await async_setup_component(hass, 'input_number', { 'input_number': { @@ -282,9 +282,9 @@ async def test_input_number_context(hass): await hass.services.async_call('input_number', 'increment', { 'entity_id': state.entity_id, - }, True, Context(user_id='abcd')) + }, True, Context(user_id=hass_admin_user.id)) state2 = hass.states.get('input_number.b1') assert state2 is not None assert state.state != state2.state - assert state2.context.user_id == 'abcd' + assert state2.context.user_id == hass_admin_user.id diff --git a/tests/components/test_input_select.py b/tests/components/test_input_select.py index f37566ffd737..528560edc049 100644 --- a/tests/components/test_input_select.py +++ b/tests/components/test_input_select.py @@ -302,7 +302,7 @@ def test_initial_state_overrules_restore_state(hass): assert state.state == 'middle option' -async def test_input_select_context(hass): +async def test_input_select_context(hass, hass_admin_user): """Test that input_select context works.""" assert await async_setup_component(hass, 'input_select', { 'input_select': { @@ -321,9 +321,9 @@ async def test_input_select_context(hass): await hass.services.async_call('input_select', 'select_next', { 'entity_id': state.entity_id, - }, True, Context(user_id='abcd')) + }, True, Context(user_id=hass_admin_user.id)) state2 = hass.states.get('input_select.s1') assert state2 is not None assert state.state != state2.state - assert state2.context.user_id == 'abcd' + assert state2.context.user_id == hass_admin_user.id diff --git a/tests/components/test_input_text.py b/tests/components/test_input_text.py index 7e8cec6ff803..f0dec42ccea1 100644 --- a/tests/components/test_input_text.py +++ b/tests/components/test_input_text.py @@ -184,7 +184,7 @@ def test_no_initial_state_and_no_restore_state(hass): assert str(state.state) == 'unknown' -async def test_input_text_context(hass): +async def test_input_text_context(hass, hass_admin_user): """Test that input_text context works.""" assert await async_setup_component(hass, 'input_text', { 'input_text': { @@ -200,9 +200,9 @@ async def test_input_text_context(hass): await hass.services.async_call('input_text', 'set_value', { 'entity_id': state.entity_id, 'value': 'new_value', - }, True, Context(user_id='abcd')) + }, True, Context(user_id=hass_admin_user.id)) state2 = hass.states.get('input_text.t1') assert state2 is not None assert state.state != state2.state - assert state2.context.user_id == 'abcd' + assert state2.context.user_id == hass_admin_user.id diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 71775574c280..a4e9a5719434 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -1,18 +1,49 @@ """Test service helpers.""" import asyncio +from collections import OrderedDict from copy import deepcopy import unittest -from unittest.mock import patch +from unittest.mock import Mock, patch + +import pytest # To prevent circular import when running just this file import homeassistant.components # noqa -from homeassistant import core as ha, loader +from homeassistant import core as ha, loader, exceptions from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ENTITY_ID from homeassistant.helpers import service, template from homeassistant.setup import async_setup_component import homeassistant.helpers.config_validation as cv - -from tests.common import get_test_home_assistant, mock_service +from homeassistant.auth.permissions import PolicyPermissions + +from tests.common import get_test_home_assistant, mock_service, mock_coro + + +@pytest.fixture +def mock_service_platform_call(): + """Mock service platform call.""" + with patch('homeassistant.helpers.service._handle_service_platform_call', + side_effect=lambda *args: mock_coro()) as mock_call: + yield mock_call + + +@pytest.fixture +def mock_entities(): + """Return mock entities in an ordered dict.""" + kitchen = Mock( + entity_id='light.kitchen', + available=True, + should_poll=False, + ) + living_room = Mock( + entity_id='light.living_room', + available=True, + should_poll=False, + ) + entities = OrderedDict() + entities[kitchen.entity_id] = kitchen + entities[living_room.entity_id] = living_room + return entities class TestServiceHelpers(unittest.TestCase): @@ -179,3 +210,99 @@ def test_async_get_all_descriptions(hass): assert 'description' in descriptions[logger.DOMAIN]['set_level'] assert 'fields' in descriptions[logger.DOMAIN]['set_level'] + + +async def test_call_context_user_not_exist(hass): + """Check we don't allow deleted users to do things.""" + with pytest.raises(exceptions.UnknownUser) as err: + await service.entity_service_call(hass, [], Mock(), ha.ServiceCall( + 'test_domain', 'test_service', context=ha.Context( + user_id='non-existing'))) + + assert err.value.context.user_id == 'non-existing' + + +async def test_call_context_target_all(hass, mock_service_platform_call, + mock_entities): + """Check we only target allowed entities if targetting all.""" + with patch('homeassistant.auth.AuthManager.async_get_user', + return_value=mock_coro(Mock(permissions=PolicyPermissions({ + 'entities': { + 'entity_ids': { + 'light.kitchen': True + } + } + })))): + await service.entity_service_call(hass, [ + Mock(entities=mock_entities) + ], Mock(), ha.ServiceCall('test_domain', 'test_service', + context=ha.Context(user_id='mock-id'))) + + assert len(mock_service_platform_call.mock_calls) == 1 + entities = mock_service_platform_call.mock_calls[0][1][2] + assert entities == [mock_entities['light.kitchen']] + + +async def test_call_context_target_specific(hass, mock_service_platform_call, + mock_entities): + """Check targeting specific entities.""" + with patch('homeassistant.auth.AuthManager.async_get_user', + return_value=mock_coro(Mock(permissions=PolicyPermissions({ + 'entities': { + 'entity_ids': { + 'light.kitchen': True + } + } + })))): + await service.entity_service_call(hass, [ + Mock(entities=mock_entities) + ], Mock(), ha.ServiceCall('test_domain', 'test_service', { + 'entity_id': 'light.kitchen' + }, context=ha.Context(user_id='mock-id'))) + + assert len(mock_service_platform_call.mock_calls) == 1 + entities = mock_service_platform_call.mock_calls[0][1][2] + assert entities == [mock_entities['light.kitchen']] + + +async def test_call_context_target_specific_no_auth( + hass, mock_service_platform_call, mock_entities): + """Check targeting specific entities without auth.""" + with pytest.raises(exceptions.Unauthorized) as err: + with patch('homeassistant.auth.AuthManager.async_get_user', + return_value=mock_coro(Mock( + permissions=PolicyPermissions({})))): + await service.entity_service_call(hass, [ + Mock(entities=mock_entities) + ], Mock(), ha.ServiceCall('test_domain', 'test_service', { + 'entity_id': 'light.kitchen' + }, context=ha.Context(user_id='mock-id'))) + + assert err.value.context.user_id == 'mock-id' + assert err.value.entity_id == 'light.kitchen' + + +async def test_call_no_context_target_all(hass, mock_service_platform_call, + mock_entities): + """Check we target all if no user context given.""" + await service.entity_service_call(hass, [ + Mock(entities=mock_entities) + ], Mock(), ha.ServiceCall('test_domain', 'test_service')) + + assert len(mock_service_platform_call.mock_calls) == 1 + entities = mock_service_platform_call.mock_calls[0][1][2] + assert entities == list(mock_entities.values()) + + +async def test_call_no_context_target_specific( + hass, mock_service_platform_call, mock_entities): + """Check we can target specified entities.""" + await service.entity_service_call(hass, [ + Mock(entities=mock_entities) + ], Mock(), ha.ServiceCall('test_domain', 'test_service', { + 'entity_id': ['light.kitchen', 'light.non-existing'] + })) + + assert len(mock_service_platform_call.mock_calls) == 1 + entities = mock_service_platform_call.mock_calls[0][1][2] + assert entities == [mock_entities['light.kitchen']]