Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve apple_tv typing #107694

Merged
merged 5 commits into from Jan 14, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
51 changes: 30 additions & 21 deletions homeassistant/components/apple_tv/__init__.py
Expand Up @@ -2,10 +2,13 @@
import asyncio
import logging
from random import randrange
from typing import TYPE_CHECKING, cast

from pyatv import connect, exceptions, scan
from pyatv.conf import AppleTV
from pyatv.const import DeviceModel, Protocol
from pyatv.convert import model_str
from pyatv.interface import AppleTV as AppleTVInterface, DeviceListener

from homeassistant.components import zeroconf
from homeassistant.config_entries import ConfigEntry
Expand Down Expand Up @@ -92,10 +95,14 @@ class AppleTVEntity(Entity):
_attr_has_entity_name = True
_attr_name = None

def __init__(self, name, identifier, manager):
def __init__(
self, name: str, identifier: str | None, manager: "AppleTVManager"
) -> None:
"""Initialize device."""
self.atv = None
self.atv: AppleTV = None # type: ignore[assignment]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignore for now since every platform need lots of asserts otherwise and it would massively increase the scope of this PR

self.manager = manager
if TYPE_CHECKING:
assert identifier is not None
self._attr_unique_id = identifier
self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, identifier)},
Expand Down Expand Up @@ -143,19 +150,19 @@ def async_device_disconnected(self):
"""Handle when connection was lost to device."""


class AppleTVManager:
class AppleTVManager(DeviceListener):
"""Connection and power manager for an Apple TV.

An instance is used per device to share the same power state between
several platforms. It also manages scanning and connection establishment
in case of problems.
"""

def __init__(self, hass, config_entry):
def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None:
"""Initialize power manager."""
self.config_entry = config_entry
self.hass = hass
self.atv = None
self.atv: AppleTVInterface | None = None
self.is_on = not config_entry.options.get(CONF_START_OFF, False)
self._connection_attempts = 0
self._connection_was_lost = False
Expand Down Expand Up @@ -220,7 +227,7 @@ def _start_connect_loop(self):
"Not starting connect loop (%s, %s)", self.atv is None, self.is_on
)

async def connect_once(self, raise_missing_credentials):
async def connect_once(self, raise_missing_credentials: bool) -> None:
"""Try to connect once."""
try:
if conf := await self._scan():
Expand Down Expand Up @@ -264,49 +271,51 @@ async def _connect_loop(self):
_LOGGER.debug("Connect loop ended")
self._task = None

async def _scan(self):
async def _scan(self) -> AppleTV | None:
"""Try to find device by scanning for it."""
identifiers = set(
self.config_entry.data.get(CONF_IDENTIFIERS, [self.config_entry.unique_id])
config_entry = self.config_entry
identifiers: set[str] = set(
config_entry.data.get(CONF_IDENTIFIERS, [config_entry.unique_id])
)
address = self.config_entry.data[CONF_ADDRESS]
address: str = config_entry.data[CONF_ADDRESS]
hass = self.hass

# Only scan for and set up protocols that was successfully paired
protocols = {
Protocol(int(protocol))
for protocol in self.config_entry.data[CONF_CREDENTIALS]
Protocol(int(protocol)) for protocol in config_entry.data[CONF_CREDENTIALS]
}

_LOGGER.debug("Discovering device %s", self.config_entry.title)
aiozc = await zeroconf.async_get_async_instance(self.hass)
_LOGGER.debug("Discovering device %s", config_entry.title)
aiozc = await zeroconf.async_get_async_instance(hass)
atvs = await scan(
self.hass.loop,
hass.loop,
identifier=identifiers,
protocol=protocols,
hosts=[address],
aiozc=aiozc,
)
if atvs:
return atvs[0]
return cast(AppleTV, atvs[0])

_LOGGER.debug(
"Failed to find device %s with address %s",
self.config_entry.title,
config_entry.title,
address,
)
# We no longer multicast scan for the device since as soon as async_step_zeroconf runs,
# it will update the address and reload the config entry when the device is found.
return None

async def _connect(self, conf, raise_missing_credentials):
async def _connect(self, conf: AppleTV, raise_missing_credentials: bool) -> None:
"""Connect to device."""
credentials = self.config_entry.data[CONF_CREDENTIALS]
name = self.config_entry.data[CONF_NAME]
config_entry = self.config_entry
credentials: dict[int, str | None] = config_entry.data[CONF_CREDENTIALS]
name: str = config_entry.data[CONF_NAME]
missing_protocols = []
for protocol_int, creds in credentials.items():
protocol = Protocol(int(protocol_int))
if conf.get_service(protocol) is not None:
conf.set_credentials(protocol, creds)
conf.set_credentials(protocol, creds) # type: ignore[arg-type]
else:
missing_protocols.append(protocol.name)

Expand Down