Skip to content

Commit

Permalink
Fix memory leak when importing a platform fails (#114602)
Browse files Browse the repository at this point in the history
* Fix memory leak when importing a platform fails

re-raising ImportError would trigger a memory leak

* fixes, coverage

* Apply suggestions from code review
  • Loading branch information
bdraco committed Apr 2, 2024
1 parent 0963f5e commit b12c69a
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 40 deletions.
31 changes: 15 additions & 16 deletions homeassistant/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,9 +750,7 @@ def __init__(
self._import_futures: dict[str, asyncio.Future[ModuleType]] = {}
cache: dict[str, ModuleType | ComponentProtocol] = hass.data[DATA_COMPONENTS]
self._cache = cache
missing_platforms_cache: dict[str, ImportError] = hass.data[
DATA_MISSING_PLATFORMS
]
missing_platforms_cache: dict[str, bool] = hass.data[DATA_MISSING_PLATFORMS]
self._missing_platforms_cache = missing_platforms_cache
self._top_level_files = top_level_files or set()
_LOGGER.info("Loaded %s from %s", self.domain, pkg_path)
Expand Down Expand Up @@ -1085,8 +1083,7 @@ async def async_get_platforms(
import_futures: list[tuple[str, asyncio.Future[ModuleType]]] = []

for platform_name in platform_names:
full_name = f"{domain}.{platform_name}"
if platform := self._get_platform_cached_or_raise(full_name):
if platform := self._get_platform_cached_or_raise(platform_name):
platforms[platform_name] = platform
continue

Expand All @@ -1095,6 +1092,7 @@ async def async_get_platforms(
in_progress_imports[platform_name] = future
continue

full_name = f"{domain}.{platform_name}"
if (
self.import_executor
and full_name not in self.hass.config.components
Expand Down Expand Up @@ -1166,14 +1164,18 @@ async def async_get_platforms(

return platforms

def _get_platform_cached_or_raise(self, full_name: str) -> ModuleType | None:
def _get_platform_cached_or_raise(self, platform_name: str) -> ModuleType | None:
"""Return a platform for an integration from cache."""
full_name = f"{self.domain}.{platform_name}"
if full_name in self._cache:
# the cache is either a ModuleType or a ComponentProtocol
# but we only care about the ModuleType here
return self._cache[full_name] # type: ignore[return-value]
if full_name in self._missing_platforms_cache:
raise self._missing_platforms_cache[full_name]
raise ModuleNotFoundError(
f"Platform {full_name} not found",
name=f"{self.pkg_path}.{platform_name}",
)
return None

def platforms_are_loaded(self, platform_names: Iterable[str]) -> bool:
Expand All @@ -1189,9 +1191,7 @@ def get_platform_cached(self, platform_name: str) -> ModuleType | None:

def get_platform(self, platform_name: str) -> ModuleType:
"""Return a platform for an integration."""
if platform := self._get_platform_cached_or_raise(
f"{self.domain}.{platform_name}"
):
if platform := self._get_platform_cached_or_raise(platform_name):
return platform
return self._load_platform(platform_name)

Expand All @@ -1212,10 +1212,7 @@ def platforms_exists(self, platform_names: Iterable[str]) -> list[str]:
):
existing_platforms.append(platform_name)
continue
missing_platforms[full_name] = ModuleNotFoundError(
f"Platform {full_name} not found",
name=f"{self.pkg_path}.{platform_name}",
)
missing_platforms[full_name] = True

return existing_platforms

Expand All @@ -1233,11 +1230,13 @@ def _load_platform(self, platform_name: str) -> ModuleType:
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
try:
cache[full_name] = self._import_platform(platform_name)
except ImportError as ex:
except ModuleNotFoundError:
if self.domain in cache:
# If the domain is loaded, cache that the platform
# does not exist so we do not try to load it again
self._missing_platforms_cache[full_name] = ex
self._missing_platforms_cache[full_name] = True
raise
except ImportError:
raise
except RuntimeError as err:
# _DeadlockError inherits from RuntimeError
Expand Down
107 changes: 83 additions & 24 deletions tests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,61 @@ async def test_get_integration_exceptions(hass: HomeAssistant) -> None:
async def test_get_platform_caches_failures_when_component_loaded(
hass: HomeAssistant,
) -> None:
"""Test get_platform cache failures only when the component is loaded."""
"""Test get_platform caches failures only when the component is loaded.
Only ModuleNotFoundError is cached, ImportError is not cached.
"""
integration = await loader.async_get_integration(hass, "hue")

with (
pytest.raises(ModuleNotFoundError),
patch(
"homeassistant.loader.importlib.import_module",
side_effect=ModuleNotFoundError("Boom"),
),
):
assert integration.get_component() == hue

with (
pytest.raises(ModuleNotFoundError),
patch(
"homeassistant.loader.importlib.import_module",
side_effect=ModuleNotFoundError("Boom"),
),
):
assert integration.get_platform("light") == hue_light

# Hue is not loaded so we should still hit the import_module path
with (
pytest.raises(ModuleNotFoundError),
patch(
"homeassistant.loader.importlib.import_module",
side_effect=ModuleNotFoundError("Boom"),
),
):
assert integration.get_platform("light") == hue_light

assert integration.get_component() == hue

# Hue is loaded so we should cache the import_module failure now
with (
pytest.raises(ModuleNotFoundError),
patch(
"homeassistant.loader.importlib.import_module",
side_effect=ModuleNotFoundError("Boom"),
),
):
assert integration.get_platform("light") == hue_light

# Hue is loaded and the last call should have cached the import_module failure
with pytest.raises(ModuleNotFoundError):
assert integration.get_platform("light") == hue_light


async def test_get_platform_only_cached_module_not_found_when_component_loaded(
hass: HomeAssistant,
) -> None:
"""Test get_platform cache only cache module not found when the component is loaded."""
integration = await loader.async_get_integration(hass, "hue")

with (
Expand Down Expand Up @@ -317,41 +371,43 @@ async def test_get_platform_caches_failures_when_component_loaded(
):
assert integration.get_platform("light") == hue_light

# Hue is loaded and the last call should have cached the import_module failure
with pytest.raises(ImportError):
assert integration.get_platform("light") == hue_light
# ImportError is not cached because we only cache ModuleNotFoundError
assert integration.get_platform("light") == hue_light


async def test_async_get_platform_caches_failures_when_component_loaded(
hass: HomeAssistant,
) -> None:
"""Test async_get_platform cache failures only when the component is loaded."""
"""Test async_get_platform caches failures only when the component is loaded.
Only ModuleNotFoundError is cached, ImportError is not cached.
"""
integration = await loader.async_get_integration(hass, "hue")

with (
pytest.raises(ImportError),
pytest.raises(ModuleNotFoundError),
patch(
"homeassistant.loader.importlib.import_module",
side_effect=ImportError("Boom"),
side_effect=ModuleNotFoundError("Boom"),
),
):
assert integration.get_component() == hue

with (
pytest.raises(ImportError),
pytest.raises(ModuleNotFoundError),
patch(
"homeassistant.loader.importlib.import_module",
side_effect=ImportError("Boom"),
side_effect=ModuleNotFoundError("Boom"),
),
):
assert await integration.async_get_platform("light") == hue_light

# Hue is not loaded so we should still hit the import_module path
with (
pytest.raises(ImportError),
pytest.raises(ModuleNotFoundError),
patch(
"homeassistant.loader.importlib.import_module",
side_effect=ImportError("Boom"),
side_effect=ModuleNotFoundError("Boom"),
),
):
assert await integration.async_get_platform("light") == hue_light
Expand All @@ -360,16 +416,16 @@ async def test_async_get_platform_caches_failures_when_component_loaded(

# Hue is loaded so we should cache the import_module failure now
with (
pytest.raises(ImportError),
pytest.raises(ModuleNotFoundError),
patch(
"homeassistant.loader.importlib.import_module",
side_effect=ImportError("Boom"),
side_effect=ModuleNotFoundError("Boom"),
),
):
assert await integration.async_get_platform("light") == hue_light

# Hue is loaded and the last call should have cached the import_module failure
with pytest.raises(ImportError):
with pytest.raises(ModuleNotFoundError):
assert await integration.async_get_platform("light") == hue_light

# The cache should never be filled because the import error is remembered
Expand All @@ -379,33 +435,36 @@ async def test_async_get_platform_caches_failures_when_component_loaded(
async def test_async_get_platforms_caches_failures_when_component_loaded(
hass: HomeAssistant,
) -> None:
"""Test async_get_platforms cache failures only when the component is loaded."""
"""Test async_get_platforms cache failures only when the component is loaded.
Only ModuleNotFoundError is cached, ImportError is not cached.
"""
integration = await loader.async_get_integration(hass, "hue")

with (
pytest.raises(ImportError),
pytest.raises(ModuleNotFoundError),
patch(
"homeassistant.loader.importlib.import_module",
side_effect=ImportError("Boom"),
side_effect=ModuleNotFoundError("Boom"),
),
):
assert integration.get_component() == hue

with (
pytest.raises(ImportError),
pytest.raises(ModuleNotFoundError),
patch(
"homeassistant.loader.importlib.import_module",
side_effect=ImportError("Boom"),
side_effect=ModuleNotFoundError("Boom"),
),
):
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}

# Hue is not loaded so we should still hit the import_module path
with (
pytest.raises(ImportError),
pytest.raises(ModuleNotFoundError),
patch(
"homeassistant.loader.importlib.import_module",
side_effect=ImportError("Boom"),
side_effect=ModuleNotFoundError("Boom"),
),
):
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
Expand All @@ -414,16 +473,16 @@ async def test_async_get_platforms_caches_failures_when_component_loaded(

# Hue is loaded so we should cache the import_module failure now
with (
pytest.raises(ImportError),
pytest.raises(ModuleNotFoundError),
patch(
"homeassistant.loader.importlib.import_module",
side_effect=ImportError("Boom"),
side_effect=ModuleNotFoundError("Boom"),
),
):
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}

# Hue is loaded and the last call should have cached the import_module failure
with pytest.raises(ImportError):
with pytest.raises(ModuleNotFoundError):
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}

# The cache should never be filled because the import error is remembered
Expand Down

0 comments on commit b12c69a

Please sign in to comment.