Skip to content

Commit

Permalink
Add some typing to common test helpers (#80337)
Browse files Browse the repository at this point in the history
  • Loading branch information
frenck committed Oct 14, 2022
1 parent 4ebf9df commit e3af2cb
Showing 1 changed file with 29 additions and 13 deletions.
42 changes: 29 additions & 13 deletions tests/common.py
Expand Up @@ -20,6 +20,7 @@
from unittest.mock import AsyncMock, Mock, patch

from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401
import voluptuous as vol

from homeassistant import auth, config_entries, core as ha, loader
from homeassistant.auth import (
Expand All @@ -42,7 +43,7 @@
STATE_OFF,
STATE_ON,
)
from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant
from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant, ServiceCall, State
from homeassistant.helpers import (
area_registry,
device_registry,
Expand All @@ -57,6 +58,7 @@
)
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import setup_component
from homeassistant.util.async_ import run_callback_threadsafe
import homeassistant.util.dt as date_util
Expand Down Expand Up @@ -328,7 +330,9 @@ def clear_instance(event):
return hass


def async_mock_service(hass, domain, service, schema=None):
def async_mock_service(
hass: HomeAssistant, domain: str, service: str, schema: vol.Schema | None = None
) -> list[ServiceCall]:
"""Set up a fake service & return a calls log list to this service."""
calls = []

Expand Down Expand Up @@ -417,18 +421,20 @@ def get_fixture_path(filename: str, integration: str | None = None) -> pathlib.P

if integration is None:
return pathlib.Path(__file__).parent.joinpath("fixtures", filename)
else:
return pathlib.Path(__file__).parent.joinpath(
"components", integration, "fixtures", filename
)

return pathlib.Path(__file__).parent.joinpath(
"components", integration, "fixtures", filename
)


def load_fixture(filename, integration=None):
def load_fixture(filename: str, integration: str | None = None) -> str:
"""Load a fixture."""
return get_fixture_path(filename, integration).read_text()


def mock_state_change_event(hass, new_state, old_state=None):
def mock_state_change_event(
hass: HomeAssistant, new_state: State, old_state: State | None = None
) -> None:
"""Mock state change envent."""
event_data = {"entity_id": new_state.entity_id, "new_state": new_state}

Expand All @@ -439,15 +445,18 @@ def mock_state_change_event(hass, new_state, old_state=None):


@ha.callback
def mock_component(hass, component):
def mock_component(hass: HomeAssistant, component: str) -> None:
"""Mock a component is setup."""
if component in hass.config.components:
AssertionError(f"Integration {component} is already setup")

hass.config.components.add(component)


def mock_registry(hass, mock_entries=None):
def mock_registry(
hass: HomeAssistant,
mock_entries: dict[str, entity_registry.RegistryEntry] | None = None,
) -> entity_registry.EntityRegistry:
"""Mock the Entity Registry."""
registry = entity_registry.EntityRegistry(hass)
if mock_entries is None:
Expand All @@ -460,7 +469,9 @@ def mock_registry(hass, mock_entries=None):
return registry


def mock_area_registry(hass, mock_entries=None):
def mock_area_registry(
hass: HomeAssistant, mock_entries: dict[str, area_registry.AreaEntry] | None = None
) -> area_registry.AreaRegistry:
"""Mock the Area Registry."""
registry = area_registry.AreaRegistry(hass)
registry.areas = mock_entries or OrderedDict()
Expand All @@ -469,7 +480,10 @@ def mock_area_registry(hass, mock_entries=None):
return registry


def mock_device_registry(hass, mock_entries=None):
def mock_device_registry(
hass: HomeAssistant,
mock_entries: dict[str, device_registry.DeviceEntry] | None = None,
) -> device_registry.DeviceRegistry:
"""Mock the Device Registry."""
registry = device_registry.DeviceRegistry(hass)
registry.devices = device_registry.DeviceRegistryItems()
Expand Down Expand Up @@ -545,7 +559,9 @@ def mock_policy(self, policy):
self._permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup)


async def register_auth_provider(hass, config):
async def register_auth_provider(
hass: HomeAssistant, config: ConfigType
) -> auth_providers.AuthProvider:
"""Register an auth provider."""
provider = await auth_providers.auth_provider_from_config(
hass, hass.auth._store, config
Expand Down

0 comments on commit e3af2cb

Please sign in to comment.