diff --git a/homeassistant/auth/models.py b/homeassistant/auth/models.py index cefaabe752140c..6a51f1b2f38432 100644 --- a/homeassistant/auth/models.py +++ b/homeassistant/auth/models.py @@ -8,6 +8,7 @@ from homeassistant.util import dt as dt_util from . import permissions as perm_mdl +from .const import GROUP_ID_ADMIN from .util import generate_secret TOKEN_TYPE_NORMAL = 'normal' @@ -69,6 +70,15 @@ def permissions(self) -> perm_mdl.AbstractPermissions: return self._permissions + @property + def is_admin(self) -> bool: + """Return if user is part of the admin group.""" + if self.is_owner: + return True + + return self.is_active and any( + gr.id == GROUP_ID_ADMIN for gr in self.groups) + @attr.s(slots=True) class RefreshToken: diff --git a/homeassistant/auth/permissions/__init__.py b/homeassistant/auth/permissions/__init__.py index fd3cf81f029589..12b6a2ae2c2b9a 100644 --- a/homeassistant/auth/permissions/__init__.py +++ b/homeassistant/auth/permissions/__init__.py @@ -22,13 +22,19 @@ class AbstractPermissions: """Default permissions class.""" + def entity_func(self) -> Callable[[str, Tuple[str, ...]], bool]: + """Return a function that can test entity access.""" + raise NotImplementedError + def check_entity(self, entity_id: str, key: str) -> bool: """Test if we can access entity.""" - raise NotImplementedError + return self.entity_func()(entity_id, (key,)) def filter_states(self, states: List[State]) -> List[State]: """Filter a list of states for what the user is allowed to see.""" - raise NotImplementedError + func = self.entity_func() + keys = ('read',) + return [entity for entity in states if func(entity.entity_id, keys)] class PolicyPermissions(AbstractPermissions): @@ -39,16 +45,9 @@ def __init__(self, policy: PolicyType) -> None: self._policy = policy self._compiled = {} # type: Dict[str, Callable[..., bool]] - def check_entity(self, entity_id: str, key: str) -> bool: - """Test if we can access entity.""" - func = self._policy_func(CAT_ENTITIES, compile_entities) - return func(entity_id, (key,)) - - def filter_states(self, states: List[State]) -> List[State]: - """Filter a list of states for what the user is allowed to see.""" - func = self._policy_func(CAT_ENTITIES, compile_entities) - keys = ('read',) - return [entity for entity in states if func(entity.entity_id, keys)] + def entity_func(self) -> Callable[[str, Tuple[str, ...]], bool]: + """Return a function that can test entity access.""" + return self._policy_func(CAT_ENTITIES, compile_entities) def _policy_func(self, category: str, compile_func: Callable[[CategoryType], Callable]) \ @@ -78,13 +77,9 @@ class _OwnerPermissions(AbstractPermissions): # pylint: disable=no-self-use - def check_entity(self, entity_id: str, key: str) -> bool: - """Test if we can access entity.""" - return True - - def filter_states(self, states: List[State]) -> List[State]: - """Filter a list of states for what the user is allowed to see.""" - return states + def entity_func(self) -> Callable[[str, Tuple[str, ...]], bool]: + """Return a function that can test entity access.""" + return lambda entity_id, keys: True OwnerPermissions = _OwnerPermissions() # pylint: disable=invalid-name diff --git a/homeassistant/components/api.py b/homeassistant/components/api.py index cbe404537ebd52..900bb92ccef2c9 100644 --- a/homeassistant/components/api.py +++ b/homeassistant/components/api.py @@ -20,7 +20,8 @@ URL_API_SERVICES, URL_API_STATES, URL_API_STATES_ENTITY, URL_API_STREAM, URL_API_TEMPLATE, __version__) import homeassistant.core as ha -from homeassistant.exceptions import TemplateError +from homeassistant.auth.permissions.const import POLICY_READ +from homeassistant.exceptions import TemplateError, Unauthorized from homeassistant.helpers import template from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.helpers.state import AsyncTrackStates @@ -81,6 +82,8 @@ class APIEventStream(HomeAssistantView): async def get(self, request): """Provide a streaming interface for the event bus.""" + if not request['hass_user'].is_admin: + raise Unauthorized() hass = request.app['hass'] stop_obj = object() to_write = asyncio.Queue(loop=hass.loop) @@ -185,7 +188,10 @@ class APIStatesView(HomeAssistantView): @ha.callback def get(self, request): """Get current states.""" - return self.json(request.app['hass'].states.async_all()) + user = request['hass_user'] + states = user.permissions.filter_states( + request.app['hass'].states.async_all()) + return self.json(states) class APIEntityStateView(HomeAssistantView): @@ -197,6 +203,10 @@ class APIEntityStateView(HomeAssistantView): @ha.callback def get(self, request, entity_id): """Retrieve state of entity.""" + user = request['hass_user'] + if not user.permissions.check_entity(entity_id, POLICY_READ): + raise Unauthorized(entity_id=entity_id) + state = request.app['hass'].states.get(entity_id) if state: return self.json(state) @@ -204,6 +214,8 @@ def get(self, request, entity_id): async def post(self, request, entity_id): """Update state of entity.""" + if not request['hass_user'].is_admin: + raise Unauthorized(entity_id=entity_id) hass = request.app['hass'] try: data = await request.json() @@ -236,6 +248,8 @@ async def post(self, request, entity_id): @ha.callback def delete(self, request, entity_id): """Remove entity.""" + if not request['hass_user'].is_admin: + raise Unauthorized(entity_id=entity_id) if request.app['hass'].states.async_remove(entity_id): return self.json_message("Entity removed.") return self.json_message("Entity not found.", HTTP_NOT_FOUND) @@ -261,6 +275,8 @@ class APIEventView(HomeAssistantView): async def post(self, request, event_type): """Fire events.""" + if not request['hass_user'].is_admin: + raise Unauthorized() body = await request.text() try: event_data = json.loads(body) if body else None @@ -346,6 +362,8 @@ class APITemplateView(HomeAssistantView): async def post(self, request): """Render a template.""" + if not request['hass_user'].is_admin: + raise Unauthorized() try: data = await request.json() tpl = template.Template(data['template'], request.app['hass']) @@ -363,6 +381,8 @@ class APIErrorLog(HomeAssistantView): async def get(self, request): """Retrieve API error log.""" + if not request['hass_user'].is_admin: + raise Unauthorized() return web.FileResponse(request.app['hass'].data[DATA_LOGGING]) diff --git a/homeassistant/components/http/view.py b/homeassistant/components/http/view.py index b3b2587fc458e8..30d4ed0ab8da73 100644 --- a/homeassistant/components/http/view.py +++ b/homeassistant/components/http/view.py @@ -14,6 +14,7 @@ from homeassistant.components.http.ban import process_success_login from homeassistant.core import Context, is_callback from homeassistant.const import CONTENT_TYPE_JSON +from homeassistant import exceptions from homeassistant.helpers.json import JSONEncoder from .const import KEY_AUTHENTICATED, KEY_REAL_IP @@ -107,10 +108,13 @@ async def handle(request): _LOGGER.info('Serving %s to %s (auth: %s)', request.path, request.get(KEY_REAL_IP), authenticated) - result = handler(request, **request.match_info) + try: + result = handler(request, **request.match_info) - if asyncio.iscoroutine(result): - result = await result + if asyncio.iscoroutine(result): + result = await result + except exceptions.Unauthorized: + raise HTTPUnauthorized() if isinstance(result, web.StreamResponse): # The method handler returned a ready-made Response, how nice of it diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 5e0d9c7e88afc3..91ada6b973148d 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -192,9 +192,9 @@ async def entity_service_call(hass, platforms, func, call): user = await hass.auth.async_get_user(call.context.user_id) if user is None: raise UnknownUser(context=call.context) - perms = user.permissions + entity_perms = user.permissions.entity_func() else: - perms = None + entity_perms = None # Are we trying to target all entities target_all_entities = ATTR_ENTITY_ID not in call.data @@ -218,7 +218,7 @@ async def entity_service_call(hass, platforms, func, call): # the service on. platforms_entities = [] - if perms is None: + if entity_perms is None: for platform in platforms: if target_all_entities: platforms_entities.append(list(platform.entities.values())) @@ -234,7 +234,7 @@ async def entity_service_call(hass, platforms, func, call): for platform in platforms: platforms_entities.append([ entity for entity in platform.entities.values() - if perms.check_entity(entity.entity_id, POLICY_CONTROL)]) + if entity_perms(entity.entity_id, POLICY_CONTROL)]) else: for platform in platforms: @@ -243,7 +243,7 @@ async def entity_service_call(hass, platforms, func, call): if entity.entity_id not in entity_ids: continue - if not perms.check_entity(entity.entity_id, POLICY_CONTROL): + if not entity_perms(entity.entity_id, POLICY_CONTROL): raise Unauthorized( context=call.context, entity_id=entity.entity_id, diff --git a/tests/common.py b/tests/common.py index c6a75fcb63d8ed..d5056e220f0156 100644 --- a/tests/common.py +++ b/tests/common.py @@ -14,7 +14,8 @@ from homeassistant import auth, core as ha, config_entries from homeassistant.auth import ( - models as auth_models, auth_store, providers as auth_providers) + models as auth_models, auth_store, providers as auth_providers, + permissions as auth_permissions) from homeassistant.auth.permissions import system_policies from homeassistant.setup import setup_component, async_setup_component from homeassistant.config import async_process_component_config @@ -400,6 +401,10 @@ def add_to_auth_manager(self, auth_mgr): auth_mgr._store._users[self.id] = self return self + def mock_policy(self, policy): + """Mock a policy for a user.""" + self._permissions = auth_permissions.PolicyPermissions(policy) + async def register_auth_provider(hass, config): """Register an auth provider.""" diff --git a/tests/components/conftest.py b/tests/components/conftest.py index 2568a1092448e3..97f2044baea84b 100644 --- a/tests/components/conftest.py +++ b/tests/components/conftest.py @@ -72,11 +72,10 @@ async def create_client(hass, access_token=None): @pytest.fixture -def hass_access_token(hass): +def hass_access_token(hass, hass_admin_user): """Return an access token to access Home Assistant.""" - user = MockUser().add_to_hass(hass) refresh_token = hass.loop.run_until_complete( - hass.auth.async_create_refresh_token(user, CLIENT_ID)) + hass.auth.async_create_refresh_token(hass_admin_user, CLIENT_ID)) yield hass.auth.async_create_access_token(refresh_token) diff --git a/tests/components/test_api.py b/tests/components/test_api.py index 6f6b4e93068e21..3ebfa05a3d39c7 100644 --- a/tests/components/test_api.py +++ b/tests/components/test_api.py @@ -16,10 +16,12 @@ @pytest.fixture -def mock_api_client(hass, aiohttp_client): - """Start the Hass HTTP component.""" +def mock_api_client(hass, aiohttp_client, hass_access_token): + """Start the Hass HTTP component and return admin API client.""" hass.loop.run_until_complete(async_setup_component(hass, 'api', {})) - return hass.loop.run_until_complete(aiohttp_client(hass.http.app)) + return hass.loop.run_until_complete(aiohttp_client(hass.http.app, headers={ + 'Authorization': 'Bearer {}'.format(hass_access_token) + })) @asyncio.coroutine @@ -405,7 +407,8 @@ def _listen_count(hass): return sum(hass.bus.async_listeners().values()) -async def test_api_error_log(hass, aiohttp_client): +async def test_api_error_log(hass, aiohttp_client, hass_access_token, + hass_admin_user): """Test if we can fetch the error log.""" hass.data[DATA_LOGGING] = '/some/path' await async_setup_component(hass, 'api', { @@ -416,7 +419,7 @@ async def test_api_error_log(hass, aiohttp_client): client = await aiohttp_client(hass.http.app) resp = await client.get(const.URL_API_ERROR_LOG) - # Verufy auth required + # Verify auth required assert resp.status == 401 with patch( @@ -424,7 +427,7 @@ async def test_api_error_log(hass, aiohttp_client): return_value=web.Response(status=200, text='Hello') ) as mock_file: resp = await client.get(const.URL_API_ERROR_LOG, headers={ - 'x-ha-access': 'yolo' + 'Authorization': 'Bearer {}'.format(hass_access_token) }) assert len(mock_file.mock_calls) == 1 @@ -432,6 +435,13 @@ async def test_api_error_log(hass, aiohttp_client): assert resp.status == 200 assert await resp.text() == 'Hello' + # Verify we require admin user + hass_admin_user.groups = [] + resp = await client.get(const.URL_API_ERROR_LOG, headers={ + 'Authorization': 'Bearer {}'.format(hass_access_token) + }) + assert resp.status == 401 + async def test_api_fire_event_context(hass, mock_api_client, hass_access_token): @@ -494,3 +504,67 @@ async def test_api_set_state_context(hass, mock_api_client, hass_access_token): state = hass.states.get('light.kitchen') assert state.context.user_id == refresh_token.user.id + + +async def test_event_stream_requires_admin(hass, mock_api_client, + hass_admin_user): + """Test user needs to be admin to access event stream.""" + hass_admin_user.groups = [] + resp = await mock_api_client.get('/api/stream') + assert resp.status == 401 + + +async def test_states_view_filters(hass, mock_api_client, hass_admin_user): + """Test filtering only visible states.""" + hass_admin_user.mock_policy({ + 'entities': { + 'entity_ids': { + 'test.entity': True + } + } + }) + hass.states.async_set('test.entity', 'hello') + hass.states.async_set('test.not_visible_entity', 'invisible') + resp = await mock_api_client.get(const.URL_API_STATES) + assert resp.status == 200 + json = await resp.json() + assert len(json) == 1 + assert json[0]['entity_id'] == 'test.entity' + + +async def test_get_entity_state_read_perm(hass, mock_api_client, + hass_admin_user): + """Test getting a state requires read permission.""" + hass_admin_user.mock_policy({}) + resp = await mock_api_client.get('/api/states/light.test') + assert resp.status == 401 + + +async def test_post_entity_state_admin(hass, mock_api_client, hass_admin_user): + """Test updating state requires admin.""" + hass_admin_user.groups = [] + resp = await mock_api_client.post('/api/states/light.test') + assert resp.status == 401 + + +async def test_delete_entity_state_admin(hass, mock_api_client, + hass_admin_user): + """Test deleting entity requires admin.""" + hass_admin_user.groups = [] + resp = await mock_api_client.delete('/api/states/light.test') + assert resp.status == 401 + + +async def test_post_event_admin(hass, mock_api_client, hass_admin_user): + """Test sending event requires admin.""" + hass_admin_user.groups = [] + resp = await mock_api_client.post('/api/events/state_changed') + assert resp.status == 401 + + +async def test_rendering_template_admin(hass, mock_api_client, + hass_admin_user): + """Test rendering a template requires admin.""" + hass_admin_user.groups = [] + resp = await mock_api_client.post('/api/template') + assert resp.status == 401