Skip to content

Commit

Permalink
Add permission checks to Rest API
Browse files Browse the repository at this point in the history
  • Loading branch information
balloob committed Nov 22, 2018
1 parent 36c31a6 commit 4a7d647
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 39 deletions.
10 changes: 10 additions & 0 deletions homeassistant/auth/models.py
Expand Up @@ -8,6 +8,7 @@
from homeassistant.util import dt as dt_util

from . import permissions as perm_mdl
from .const import GROUP_ID_ADMIN
from .util import generate_secret

TOKEN_TYPE_NORMAL = 'normal'
Expand Down Expand Up @@ -69,6 +70,15 @@ def permissions(self) -> perm_mdl.AbstractPermissions:

return self._permissions

@property
def is_admin(self) -> bool:
"""Return if user is part of the admin group."""
if self.is_owner:
return True

return self.is_active and any(
gr.id == GROUP_ID_ADMIN for gr in self.groups)


@attr.s(slots=True)
class RefreshToken:
Expand Down
33 changes: 14 additions & 19 deletions homeassistant/auth/permissions/__init__.py
Expand Up @@ -22,13 +22,19 @@
class AbstractPermissions:
"""Default permissions class."""

def entity_func(self) -> Callable[[str, Tuple[str, ...]], bool]:
"""Return a function that can test entity access."""
raise NotImplementedError

def check_entity(self, entity_id: str, key: str) -> bool:
"""Test if we can access entity."""
raise NotImplementedError
return self.entity_func()(entity_id, (key,))

def filter_states(self, states: List[State]) -> List[State]:
"""Filter a list of states for what the user is allowed to see."""
raise NotImplementedError
func = self.entity_func()
keys = ('read',)
return [entity for entity in states if func(entity.entity_id, keys)]


class PolicyPermissions(AbstractPermissions):
Expand All @@ -39,16 +45,9 @@ def __init__(self, policy: PolicyType) -> None:
self._policy = policy
self._compiled = {} # type: Dict[str, Callable[..., bool]]

def check_entity(self, entity_id: str, key: str) -> bool:
"""Test if we can access entity."""
func = self._policy_func(CAT_ENTITIES, compile_entities)
return func(entity_id, (key,))

def filter_states(self, states: List[State]) -> List[State]:
"""Filter a list of states for what the user is allowed to see."""
func = self._policy_func(CAT_ENTITIES, compile_entities)
keys = ('read',)
return [entity for entity in states if func(entity.entity_id, keys)]
def entity_func(self) -> Callable[[str, Tuple[str, ...]], bool]:
"""Return a function that can test entity access."""
return self._policy_func(CAT_ENTITIES, compile_entities)

def _policy_func(self, category: str,
compile_func: Callable[[CategoryType], Callable]) \
Expand Down Expand Up @@ -78,13 +77,9 @@ class _OwnerPermissions(AbstractPermissions):

# pylint: disable=no-self-use

def check_entity(self, entity_id: str, key: str) -> bool:
"""Test if we can access entity."""
return True

def filter_states(self, states: List[State]) -> List[State]:
"""Filter a list of states for what the user is allowed to see."""
return states
def entity_func(self) -> Callable[[str, Tuple[str, ...]], bool]:
"""Return a function that can test entity access."""
return lambda entity_id, keys: True


OwnerPermissions = _OwnerPermissions() # pylint: disable=invalid-name
24 changes: 22 additions & 2 deletions homeassistant/components/api.py
Expand Up @@ -20,7 +20,8 @@
URL_API_SERVICES, URL_API_STATES, URL_API_STATES_ENTITY, URL_API_STREAM,
URL_API_TEMPLATE, __version__)
import homeassistant.core as ha
from homeassistant.exceptions import TemplateError
from homeassistant.auth.permissions.const import POLICY_READ
from homeassistant.exceptions import TemplateError, Unauthorized
from homeassistant.helpers import template
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.helpers.state import AsyncTrackStates
Expand Down Expand Up @@ -81,6 +82,8 @@ class APIEventStream(HomeAssistantView):

async def get(self, request):
"""Provide a streaming interface for the event bus."""
if not request['hass_user'].is_admin:
raise Unauthorized()
hass = request.app['hass']
stop_obj = object()
to_write = asyncio.Queue(loop=hass.loop)
Expand Down Expand Up @@ -185,7 +188,10 @@ class APIStatesView(HomeAssistantView):
@ha.callback
def get(self, request):
"""Get current states."""
return self.json(request.app['hass'].states.async_all())
user = request['hass_user']
states = user.permissions.filter_states(
request.app['hass'].states.async_all())
return self.json(states)


class APIEntityStateView(HomeAssistantView):
Expand All @@ -197,13 +203,19 @@ class APIEntityStateView(HomeAssistantView):
@ha.callback
def get(self, request, entity_id):
"""Retrieve state of entity."""
user = request['hass_user']
if not user.permissions.check_entity(entity_id, POLICY_READ):
raise Unauthorized(entity_id=entity_id)

state = request.app['hass'].states.get(entity_id)
if state:
return self.json(state)
return self.json_message("Entity not found.", HTTP_NOT_FOUND)

async def post(self, request, entity_id):
"""Update state of entity."""
if not request['hass_user'].is_admin:
raise Unauthorized(entity_id=entity_id)
hass = request.app['hass']
try:
data = await request.json()
Expand Down Expand Up @@ -236,6 +248,8 @@ async def post(self, request, entity_id):
@ha.callback
def delete(self, request, entity_id):
"""Remove entity."""
if not request['hass_user'].is_admin:
raise Unauthorized(entity_id=entity_id)
if request.app['hass'].states.async_remove(entity_id):
return self.json_message("Entity removed.")
return self.json_message("Entity not found.", HTTP_NOT_FOUND)
Expand All @@ -261,6 +275,8 @@ class APIEventView(HomeAssistantView):

async def post(self, request, event_type):
"""Fire events."""
if not request['hass_user'].is_admin:
raise Unauthorized()
body = await request.text()
try:
event_data = json.loads(body) if body else None
Expand Down Expand Up @@ -346,6 +362,8 @@ class APITemplateView(HomeAssistantView):

async def post(self, request):
"""Render a template."""
if not request['hass_user'].is_admin:
raise Unauthorized()
try:
data = await request.json()
tpl = template.Template(data['template'], request.app['hass'])
Expand All @@ -363,6 +381,8 @@ class APIErrorLog(HomeAssistantView):

async def get(self, request):
"""Retrieve API error log."""
if not request['hass_user'].is_admin:
raise Unauthorized()
return web.FileResponse(request.app['hass'].data[DATA_LOGGING])


Expand Down
10 changes: 7 additions & 3 deletions homeassistant/components/http/view.py
Expand Up @@ -14,6 +14,7 @@
from homeassistant.components.http.ban import process_success_login
from homeassistant.core import Context, is_callback
from homeassistant.const import CONTENT_TYPE_JSON
from homeassistant import exceptions
from homeassistant.helpers.json import JSONEncoder

from .const import KEY_AUTHENTICATED, KEY_REAL_IP
Expand Down Expand Up @@ -107,10 +108,13 @@ async def handle(request):
_LOGGER.info('Serving %s to %s (auth: %s)',
request.path, request.get(KEY_REAL_IP), authenticated)

result = handler(request, **request.match_info)
try:
result = handler(request, **request.match_info)

if asyncio.iscoroutine(result):
result = await result
if asyncio.iscoroutine(result):
result = await result
except exceptions.Unauthorized:
raise HTTPUnauthorized()

if isinstance(result, web.StreamResponse):
# The method handler returned a ready-made Response, how nice of it
Expand Down
10 changes: 5 additions & 5 deletions homeassistant/helpers/service.py
Expand Up @@ -192,9 +192,9 @@ async def entity_service_call(hass, platforms, func, call):
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(context=call.context)
perms = user.permissions
entity_perms = user.permissions.entity_func()
else:
perms = None
entity_perms = None

# Are we trying to target all entities
target_all_entities = ATTR_ENTITY_ID not in call.data
Expand All @@ -218,7 +218,7 @@ async def entity_service_call(hass, platforms, func, call):
# the service on.
platforms_entities = []

if perms is None:
if entity_perms is None:
for platform in platforms:
if target_all_entities:
platforms_entities.append(list(platform.entities.values()))
Expand All @@ -234,7 +234,7 @@ async def entity_service_call(hass, platforms, func, call):
for platform in platforms:
platforms_entities.append([
entity for entity in platform.entities.values()
if perms.check_entity(entity.entity_id, POLICY_CONTROL)])
if entity_perms(entity.entity_id, POLICY_CONTROL)])

else:
for platform in platforms:
Expand All @@ -243,7 +243,7 @@ async def entity_service_call(hass, platforms, func, call):
if entity.entity_id not in entity_ids:
continue

if not perms.check_entity(entity.entity_id, POLICY_CONTROL):
if not entity_perms(entity.entity_id, POLICY_CONTROL):
raise Unauthorized(
context=call.context,
entity_id=entity.entity_id,
Expand Down
7 changes: 6 additions & 1 deletion tests/common.py
Expand Up @@ -14,7 +14,8 @@

from homeassistant import auth, core as ha, config_entries
from homeassistant.auth import (
models as auth_models, auth_store, providers as auth_providers)
models as auth_models, auth_store, providers as auth_providers,
permissions as auth_permissions)
from homeassistant.auth.permissions import system_policies
from homeassistant.setup import setup_component, async_setup_component
from homeassistant.config import async_process_component_config
Expand Down Expand Up @@ -400,6 +401,10 @@ def add_to_auth_manager(self, auth_mgr):
auth_mgr._store._users[self.id] = self
return self

def mock_policy(self, policy):
"""Mock a policy for a user."""
self._permissions = auth_permissions.PolicyPermissions(policy)


async def register_auth_provider(hass, config):
"""Register an auth provider."""
Expand Down
5 changes: 2 additions & 3 deletions tests/components/conftest.py
Expand Up @@ -72,11 +72,10 @@ async def create_client(hass, access_token=None):


@pytest.fixture
def hass_access_token(hass):
def hass_access_token(hass, hass_admin_user):
"""Return an access token to access Home Assistant."""
user = MockUser().add_to_hass(hass)
refresh_token = hass.loop.run_until_complete(
hass.auth.async_create_refresh_token(user, CLIENT_ID))
hass.auth.async_create_refresh_token(hass_admin_user, CLIENT_ID))
yield hass.auth.async_create_access_token(refresh_token)


Expand Down
86 changes: 80 additions & 6 deletions tests/components/test_api.py
Expand Up @@ -16,10 +16,12 @@


@pytest.fixture
def mock_api_client(hass, aiohttp_client):
"""Start the Hass HTTP component."""
def mock_api_client(hass, aiohttp_client, hass_access_token):
"""Start the Hass HTTP component and return admin API client."""
hass.loop.run_until_complete(async_setup_component(hass, 'api', {}))
return hass.loop.run_until_complete(aiohttp_client(hass.http.app))
return hass.loop.run_until_complete(aiohttp_client(hass.http.app, headers={
'Authorization': 'Bearer {}'.format(hass_access_token)
}))


@asyncio.coroutine
Expand Down Expand Up @@ -405,7 +407,8 @@ def _listen_count(hass):
return sum(hass.bus.async_listeners().values())


async def test_api_error_log(hass, aiohttp_client):
async def test_api_error_log(hass, aiohttp_client, hass_access_token,
hass_admin_user):
"""Test if we can fetch the error log."""
hass.data[DATA_LOGGING] = '/some/path'
await async_setup_component(hass, 'api', {
Expand All @@ -416,22 +419,29 @@ async def test_api_error_log(hass, aiohttp_client):
client = await aiohttp_client(hass.http.app)

resp = await client.get(const.URL_API_ERROR_LOG)
# Verufy auth required
# Verify auth required
assert resp.status == 401

with patch(
'aiohttp.web.FileResponse',
return_value=web.Response(status=200, text='Hello')
) as mock_file:
resp = await client.get(const.URL_API_ERROR_LOG, headers={
'x-ha-access': 'yolo'
'Authorization': 'Bearer {}'.format(hass_access_token)
})

assert len(mock_file.mock_calls) == 1
assert mock_file.mock_calls[0][1][0] == hass.data[DATA_LOGGING]
assert resp.status == 200
assert await resp.text() == 'Hello'

# Verify we require admin user
hass_admin_user.groups = []
resp = await client.get(const.URL_API_ERROR_LOG, headers={
'Authorization': 'Bearer {}'.format(hass_access_token)
})
assert resp.status == 401


async def test_api_fire_event_context(hass, mock_api_client,
hass_access_token):
Expand Down Expand Up @@ -494,3 +504,67 @@ async def test_api_set_state_context(hass, mock_api_client, hass_access_token):

state = hass.states.get('light.kitchen')
assert state.context.user_id == refresh_token.user.id


async def test_event_stream_requires_admin(hass, mock_api_client,
hass_admin_user):
"""Test user needs to be admin to access event stream."""
hass_admin_user.groups = []
resp = await mock_api_client.get('/api/stream')
assert resp.status == 401


async def test_states_view_filters(hass, mock_api_client, hass_admin_user):
"""Test filtering only visible states."""
hass_admin_user.mock_policy({
'entities': {
'entity_ids': {
'test.entity': True
}
}
})
hass.states.async_set('test.entity', 'hello')
hass.states.async_set('test.not_visible_entity', 'invisible')
resp = await mock_api_client.get(const.URL_API_STATES)
assert resp.status == 200
json = await resp.json()
assert len(json) == 1
assert json[0]['entity_id'] == 'test.entity'


async def test_get_entity_state_read_perm(hass, mock_api_client,
hass_admin_user):
"""Test getting a state requires read permission."""
hass_admin_user.mock_policy({})
resp = await mock_api_client.get('/api/states/light.test')
assert resp.status == 401


async def test_post_entity_state_admin(hass, mock_api_client, hass_admin_user):
"""Test updating state requires admin."""
hass_admin_user.groups = []
resp = await mock_api_client.post('/api/states/light.test')
assert resp.status == 401


async def test_delete_entity_state_admin(hass, mock_api_client,
hass_admin_user):
"""Test deleting entity requires admin."""
hass_admin_user.groups = []
resp = await mock_api_client.delete('/api/states/light.test')
assert resp.status == 401


async def test_post_event_admin(hass, mock_api_client, hass_admin_user):
"""Test sending event requires admin."""
hass_admin_user.groups = []
resp = await mock_api_client.post('/api/events/state_changed')
assert resp.status == 401


async def test_rendering_template_admin(hass, mock_api_client,
hass_admin_user):
"""Test rendering a template requires admin."""
hass_admin_user.groups = []
resp = await mock_api_client.post('/api/template')
assert resp.status == 401

0 comments on commit 4a7d647

Please sign in to comment.