Skip to content

Commit

Permalink
Resolve auth_store loading race condition (#21794)
Browse files Browse the repository at this point in the history
* Add lock in auth_store._async_load()

* Python 3.5 does not like assert_called_once()
  • Loading branch information
awarecan authored and balloob committed Mar 8, 2019
1 parent 3ff2d99 commit 3d8673d
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 11 deletions.
8 changes: 8 additions & 0 deletions homeassistant/auth/auth_store.py
Expand Up @@ -38,6 +38,7 @@ def __init__(self, hass: HomeAssistant) -> None:
self._perm_lookup = None # type: Optional[PermissionLookup]
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY,
private=True)
self._lock = asyncio.Lock()

async def async_get_groups(self) -> List[models.Group]:
"""Retrieve all users."""
Expand Down Expand Up @@ -271,6 +272,13 @@ def async_log_refresh_token_usage(
self._async_schedule_save()

async def _async_load(self) -> None:
"""Load the users."""
async with self._lock:
if self._users is not None:
return
await self._async_load_task()

async def _async_load_task(self) -> None:
"""Load the users."""
[ent_reg, data] = await asyncio.gather(
self.hass.helpers.entity_registry.async_get_registry(),
Expand Down
22 changes: 22 additions & 0 deletions tests/auth/test_auth_store.py
@@ -1,4 +1,8 @@
"""Tests for the auth store."""
import asyncio

import asynctest

from homeassistant.auth import auth_store


Expand Down Expand Up @@ -218,3 +222,21 @@ async def test_system_groups_store_id_and_name(hass, hass_storage):
'name': auth_store.GROUP_NAME_READ_ONLY,
},
]


async def test_loading_race_condition(hass):
"""Test only one storage load called when concurrent loading occurred ."""
store = auth_store.AuthStore(hass)
with asynctest.patch(
'homeassistant.helpers.entity_registry.async_get_registry',
) as mock_registry, asynctest.patch(
'homeassistant.helpers.storage.Store.async_load',
) as mock_load:
results = await asyncio.gather(
store.async_get_users(),
store.async_get_users(),
)

mock_registry.assert_called_once_with(hass)
mock_load.assert_called_once_with()
assert results[0] == results[1]
17 changes: 17 additions & 0 deletions tests/helpers/test_area_registry.py
@@ -1,4 +1,7 @@
"""Tests for the Area Registry."""
import asyncio

import asynctest
import pytest

from homeassistant.helpers import area_registry
Expand Down Expand Up @@ -125,3 +128,17 @@ async def test_loading_area_from_storage(hass, hass_storage):
registry = await area_registry.async_get_registry(hass)

assert len(registry.areas) == 1


async def test_loading_race_condition(hass):
"""Test only one storage load called when concurrent loading occurred ."""
with asynctest.patch(
'homeassistant.helpers.area_registry.AreaRegistry.async_load',
) as mock_load:
results = await asyncio.gather(
area_registry.async_get_registry(hass),
area_registry.async_get_registry(hass),
)

mock_load.assert_called_once_with()
assert results[0] == results[1]
16 changes: 16 additions & 0 deletions tests/helpers/test_device_registry.py
@@ -1,6 +1,8 @@
"""Tests for the Device Registry."""
import asyncio
from unittest.mock import patch

import asynctest
import pytest

from homeassistant.helpers import device_registry
Expand Down Expand Up @@ -370,3 +372,17 @@ async def test_update(registry):
assert updated_entry != entry
assert updated_entry.area_id == '12345A'
assert updated_entry.name_by_user == 'Test Friendly Name'


async def test_loading_race_condition(hass):
"""Test only one storage load called when concurrent loading occurred ."""
with asynctest.patch(
'homeassistant.helpers.device_registry.DeviceRegistry.async_load',
) as mock_load:
results = await asyncio.gather(
device_registry.async_get_registry(hass),
device_registry.async_get_registry(hass),
)

mock_load.assert_called_once_with()
assert results[0] == results[1]
28 changes: 17 additions & 11 deletions tests/helpers/test_entity_registry.py
Expand Up @@ -2,6 +2,7 @@
import asyncio
from unittest.mock import patch

import asynctest
import pytest

from homeassistant.core import valid_entity_id
Expand All @@ -19,7 +20,6 @@ def registry(hass):
return mock_registry(hass)


@asyncio.coroutine
def test_get_or_create_returns_same_entry(registry):
"""Make sure we do not duplicate entries."""
entry = registry.async_get_or_create('light', 'hue', '1234')
Expand All @@ -30,7 +30,6 @@ def test_get_or_create_returns_same_entry(registry):
assert entry.entity_id == 'light.hue_1234'


@asyncio.coroutine
def test_get_or_create_suggested_object_id(registry):
"""Test that suggested_object_id works."""
entry = registry.async_get_or_create(
Expand All @@ -39,7 +38,6 @@ def test_get_or_create_suggested_object_id(registry):
assert entry.entity_id == 'light.beer'


@asyncio.coroutine
def test_get_or_create_suggested_object_id_conflict_register(registry):
"""Test that we don't generate an entity id that is already registered."""
entry = registry.async_get_or_create(
Expand All @@ -51,15 +49,13 @@ def test_get_or_create_suggested_object_id_conflict_register(registry):
assert entry2.entity_id == 'light.beer_2'


@asyncio.coroutine
def test_get_or_create_suggested_object_id_conflict_existing(hass, registry):
"""Test that we don't generate an entity id that currently exists."""
hass.states.async_set('light.hue_1234', 'on')
entry = registry.async_get_or_create('light', 'hue', '1234')
assert entry.entity_id == 'light.hue_1234_2'


@asyncio.coroutine
def test_create_triggers_save(hass, registry):
"""Test that registering entry triggers a save."""
with patch.object(registry, 'async_schedule_save') as mock_schedule_save:
Expand Down Expand Up @@ -91,7 +87,6 @@ async def test_loading_saving_data(hass, registry):
assert orig_entry2 == new_entry2


@asyncio.coroutine
def test_generate_entity_considers_registered_entities(registry):
"""Test that we don't create entity id that are already registered."""
entry = registry.async_get_or_create('light', 'hue', '1234')
Expand All @@ -100,15 +95,13 @@ def test_generate_entity_considers_registered_entities(registry):
'light.hue_1234_2'


@asyncio.coroutine
def test_generate_entity_considers_existing_entities(hass, registry):
"""Test that we don't create entity id that currently exists."""
hass.states.async_set('light.kitchen', 'on')
assert registry.async_generate_entity_id('light', 'kitchen') == \
'light.kitchen_2'


@asyncio.coroutine
def test_is_registered(registry):
"""Test that is_registered works."""
entry = registry.async_get_or_create('light', 'hue', '1234')
Expand Down Expand Up @@ -166,7 +159,6 @@ async def test_loading_extra_values(hass, hass_storage):
assert entry_disabled_user.disabled_by == entity_registry.DISABLED_USER


@asyncio.coroutine
def test_async_get_entity_id(registry):
"""Test that entity_id is returned."""
entry = registry.async_get_or_create('light', 'hue', '1234')
Expand All @@ -176,7 +168,7 @@ def test_async_get_entity_id(registry):
assert registry.async_get_entity_id('light', 'hue', '123') is None


async def test_updating_config_entry_id(registry):
def test_updating_config_entry_id(registry):
"""Test that we update config entry id in registry."""
entry = registry.async_get_or_create(
'light', 'hue', '5678', config_entry_id='mock-id-1')
Expand All @@ -186,7 +178,7 @@ async def test_updating_config_entry_id(registry):
assert entry2.config_entry_id == 'mock-id-2'


async def test_removing_config_entry_id(registry):
def test_removing_config_entry_id(registry):
"""Test that we update config entry id in registry."""
entry = registry.async_get_or_create(
'light', 'hue', '5678', config_entry_id='mock-id-1')
Expand Down Expand Up @@ -265,3 +257,17 @@ async def test_loading_invalid_entity_id(hass, hass_storage):
'test', 'super_platform', 'id-invalid-start')

assert valid_entity_id(entity_invalid_start.entity_id)


async def test_loading_race_condition(hass):
"""Test only one storage load called when concurrent loading occurred ."""
with asynctest.patch(
'homeassistant.helpers.entity_registry.EntityRegistry.async_load',
) as mock_load:
results = await asyncio.gather(
entity_registry.async_get_registry(hass),
entity_registry.async_get_registry(hass),
)

mock_load.assert_called_once_with()
assert results[0] == results[1]

0 comments on commit 3d8673d

Please sign in to comment.