Skip to content

Commit

Permalink
Don't blow up if config entries have unhashable unique IDs (#109966)
Browse files Browse the repository at this point in the history
* Don't blow up if config entries have unhashable unique IDs

* Add test

* Add comment on when we remove the guard

* Don't stringify hashable non string unique_id
  • Loading branch information
emontnemery authored and frenck committed Feb 9, 2024
1 parent a9e9ec2 commit 95a800b
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 5 deletions.
38 changes: 33 additions & 5 deletions homeassistant/config_entries.py
Expand Up @@ -7,6 +7,7 @@
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Mapping,
ValuesView,
Expand Down Expand Up @@ -49,6 +50,7 @@
)
from .helpers.frame import report
from .helpers.typing import UNDEFINED, ConfigType, DiscoveryInfoType, UndefinedType
from .loader import async_suggest_report_issue
from .setup import DATA_SETUP_DONE, async_process_deps_reqs, async_setup_component
from .util import uuid as uuid_util
from .util.decorator import Registry
Expand Down Expand Up @@ -1124,9 +1126,10 @@ class ConfigEntryItems(UserDict[str, ConfigEntry]):
- domain -> unique_id -> ConfigEntry
"""

def __init__(self) -> None:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the container."""
super().__init__()
self._hass = hass
self._domain_index: dict[str, list[ConfigEntry]] = {}
self._domain_unique_id_index: dict[str, dict[str, ConfigEntry]] = {}

Expand All @@ -1145,8 +1148,27 @@ def __setitem__(self, entry_id: str, entry: ConfigEntry) -> None:
data[entry_id] = entry
self._domain_index.setdefault(entry.domain, []).append(entry)
if entry.unique_id is not None:
unique_id_hash = entry.unique_id
# Guard against integrations using unhashable unique_id
# In HA Core 2024.9, we should remove the guard and instead fail
if not isinstance(entry.unique_id, Hashable):
unique_id_hash = str(entry.unique_id) # type: ignore[unreachable]
report_issue = async_suggest_report_issue(
self._hass, integration_domain=entry.domain
)
_LOGGER.error(
(
"Config entry '%s' from integration %s has an invalid unique_id"
" '%s', please %s"
),
entry.title,
entry.domain,
entry.unique_id,
report_issue,
)

self._domain_unique_id_index.setdefault(entry.domain, {})[
entry.unique_id
unique_id_hash
] = entry

def _unindex_entry(self, entry_id: str) -> None:
Expand All @@ -1157,6 +1179,9 @@ def _unindex_entry(self, entry_id: str) -> None:
if not self._domain_index[domain]:
del self._domain_index[domain]
if (unique_id := entry.unique_id) is not None:
# Check type first to avoid expensive isinstance call
if type(unique_id) is not str and not isinstance(unique_id, Hashable): # noqa: E721
unique_id = str(entry.unique_id) # type: ignore[unreachable]
del self._domain_unique_id_index[domain][unique_id]
if not self._domain_unique_id_index[domain]:
del self._domain_unique_id_index[domain]
Expand All @@ -1174,6 +1199,9 @@ def get_entry_by_domain_and_unique_id(
self, domain: str, unique_id: str
) -> ConfigEntry | None:
"""Get entry by domain and unique id."""
# Check type first to avoid expensive isinstance call
if type(unique_id) is not str and not isinstance(unique_id, Hashable): # noqa: E721
unique_id = str(unique_id) # type: ignore[unreachable]
return self._domain_unique_id_index.get(domain, {}).get(unique_id)


Expand All @@ -1189,7 +1217,7 @@ def __init__(self, hass: HomeAssistant, hass_config: ConfigType) -> None:
self.flow = ConfigEntriesFlowManager(hass, self, hass_config)
self.options = OptionsFlowManager(hass)
self._hass_config = hass_config
self._entries = ConfigEntryItems()
self._entries = ConfigEntryItems(hass)
self._store = storage.Store[dict[str, list[dict[str, Any]]]](
hass, STORAGE_VERSION, STORAGE_KEY
)
Expand Down Expand Up @@ -1314,10 +1342,10 @@ async def async_initialize(self) -> None:
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._async_shutdown)

if config is None:
self._entries = ConfigEntryItems()
self._entries = ConfigEntryItems(self.hass)
return

entries: ConfigEntryItems = ConfigEntryItems()
entries: ConfigEntryItems = ConfigEntryItems(self.hass)
for entry in config["entries"]:
pref_disable_new_entities = entry.get("pref_disable_new_entities")

Expand Down
61 changes: 61 additions & 0 deletions tests/test_config_entries.py
Expand Up @@ -4257,3 +4257,64 @@ async def async_step_reauth(self, data):
assert entry.state == config_entries.ConfigEntryState.LOADED
assert task["type"] == FlowResultType.ABORT
assert task["reason"] == "reauth_successful"


@pytest.mark.parametrize("unique_id", [["blah", "bleh"], {"key": "value"}])
async def test_unhashable_unique_id(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, unique_id: Any
) -> None:
"""Test the ConfigEntryItems user dict handles unhashable unique_id."""
entries = config_entries.ConfigEntryItems(hass)
entry = config_entries.ConfigEntry(
version=1,
minor_version=1,
domain="test",
entry_id="mock_id",
title="title",
data={},
source="test",
unique_id=unique_id,
)

entries[entry.entry_id] = entry
assert (
"Config entry 'title' from integration test has an invalid unique_id "
f"'{str(unique_id)}'"
) in caplog.text

assert entry.entry_id in entries
assert entries[entry.entry_id] is entry
assert entries.get_entry_by_domain_and_unique_id("test", unique_id) == entry
del entries[entry.entry_id]
assert not entries
assert entries.get_entry_by_domain_and_unique_id("test", unique_id) is None


@pytest.mark.parametrize("unique_id", [123])
async def test_hashable_non_string_unique_id(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, unique_id: Any
) -> None:
"""Test the ConfigEntryItems user dict handles hashable non string unique_id."""
entries = config_entries.ConfigEntryItems(hass)
entry = config_entries.ConfigEntry(
version=1,
minor_version=1,
domain="test",
entry_id="mock_id",
title="title",
data={},
source="test",
unique_id=unique_id,
)

entries[entry.entry_id] = entry
assert (
"Config entry 'title' from integration test has an invalid unique_id"
) not in caplog.text

assert entry.entry_id in entries
assert entries[entry.entry_id] is entry
assert entries.get_entry_by_domain_and_unique_id("test", unique_id) == entry
del entries[entry.entry_id]
assert not entries
assert entries.get_entry_by_domain_and_unique_id("test", unique_id) is None

0 comments on commit 95a800b

Please sign in to comment.