diff --git a/.coveragerc b/.coveragerc index d42d7cbb3b3250..c92e0e182b9b14 100644 --- a/.coveragerc +++ b/.coveragerc @@ -673,6 +673,7 @@ omit = homeassistant/components/tradfri/* homeassistant/components/tradfri/light.py homeassistant/components/tradfri/cover.py + homeassistant/components/tradfri/base_class.py homeassistant/components/trafikverket_train/sensor.py homeassistant/components/trafikverket_weatherstation/sensor.py homeassistant/components/transmission/__init__.py diff --git a/homeassistant/components/tradfri/base_class.py b/homeassistant/components/tradfri/base_class.py new file mode 100644 index 00000000000000..5fce3c08510543 --- /dev/null +++ b/homeassistant/components/tradfri/base_class.py @@ -0,0 +1,96 @@ +"""Base class for IKEA TRADFRI.""" +import logging + +from pytradfri.error import PytradfriError + +from homeassistant.core import callback +from homeassistant.helpers.entity import Entity +from . import DOMAIN as TRADFRI_DOMAIN + +_LOGGER = logging.getLogger(__name__) + + +class TradfriBaseDevice(Entity): + """Base class for a TRADFRI device.""" + + def __init__(self, device, api, gateway_id): + """Initialize a device.""" + self._available = True + self._api = api + self._device = None + self._device_control = None + self._device_data = None + self._gateway_id = gateway_id + self._name = None + self._unique_id = None + + self._refresh(device) + + @callback + def _async_start_observe(self, exc=None): + """Start observation of device.""" + if exc: + self._available = False + self.async_schedule_update_ha_state() + _LOGGER.warning("Observation failed for %s", self._name, exc_info=exc) + + try: + cmd = self._device.observe( + callback=self._observe_update, + err_callback=self._async_start_observe, + duration=0, + ) + self.hass.async_create_task(self._api(cmd)) + except PytradfriError as err: + _LOGGER.warning("Observation failed, trying again", exc_info=err) + self._async_start_observe() + + async def async_added_to_hass(self): + """Start thread when added to hass.""" + self._async_start_observe() + + @property + def available(self): + """Return True if entity is available.""" + return self._available + + @property + def device_info(self): + """Return the device info.""" + info = self._device.device_info + + return { + "identifiers": {(TRADFRI_DOMAIN, self._device.id)}, + "name": self._name, + "manufacturer": info.manufacturer, + "model": info.model_number, + "sw_version": info.firmware_version, + "via_device": (TRADFRI_DOMAIN, self._gateway_id), + } + + @property + def name(self): + """Return the display name of this device.""" + return self._name + + @property + def should_poll(self): + """No polling needed for tradfri device.""" + return False + + @property + def unique_id(self): + """Return unique ID for device.""" + return self._unique_id + + @callback + def _observe_update(self, device): + """Receive new state data for this device.""" + self._refresh(device) + self.async_schedule_update_ha_state() + + def _refresh(self, device): + """Refresh the device data.""" + self._device = device + self._name = device.name + self._available = device.reachable diff --git a/homeassistant/components/tradfri/switch.py b/homeassistant/components/tradfri/switch.py index 545c1ad93cec17..1e322ff47f5111 100644 --- a/homeassistant/components/tradfri/switch.py +++ b/homeassistant/components/tradfri/switch.py @@ -1,17 +1,9 @@ """Support for IKEA Tradfri switches.""" -import logging - -from pytradfri.error import PytradfriError - from homeassistant.components.switch import SwitchDevice -from homeassistant.core import callback -from . import DOMAIN as TRADFRI_DOMAIN, KEY_API, KEY_GATEWAY +from . import KEY_API, KEY_GATEWAY +from .base_class import TradfriBaseDevice from .const import CONF_GATEWAY_ID -_LOGGER = logging.getLogger(__name__) - -TRADFRI_SWITCH_MANAGER = "Tradfri Switch Manager" - async def async_setup_entry(hass, config_entry, async_add_entities): """Load Tradfri switches based on a config entry.""" @@ -28,104 +20,31 @@ async def async_setup_entry(hass, config_entry, async_add_entities): ) -class TradfriSwitch(SwitchDevice): +class TradfriSwitch(TradfriBaseDevice, SwitchDevice): """The platform class required by Home Assistant.""" - def __init__(self, switch, api, gateway_id): + def __init__(self, device, api, gateway_id): """Initialize a switch.""" - self._api = api - self._unique_id = f"{gateway_id}-{switch.id}" - self._switch = None - self._socket_control = None - self._switch_data = None - self._name = None - self._available = True - self._gateway_id = gateway_id - - self._refresh(switch) - - @property - def unique_id(self): - """Return unique ID for switch.""" - return self._unique_id - - @property - def device_info(self): - """Return the device info.""" - info = self._switch.device_info - - return { - "identifiers": {(TRADFRI_DOMAIN, self._switch.id)}, - "name": self._name, - "manufacturer": info.manufacturer, - "model": info.model_number, - "sw_version": info.firmware_version, - "via_device": (TRADFRI_DOMAIN, self._gateway_id), - } - - async def async_added_to_hass(self): - """Start thread when added to hass.""" - self._async_start_observe() + super().__init__(device, api, gateway_id) + self._unique_id = f"{gateway_id}-{device.id}" - @property - def available(self): - """Return True if entity is available.""" - return self._available - - @property - def should_poll(self): - """No polling needed for tradfri switch.""" - return False + def _refresh(self, device): + """Refresh the switch data.""" + super()._refresh(device) - @property - def name(self): - """Return the display name of this switch.""" - return self._name + # Caching of switch control and switch object + self._device_control = device.socket_control + self._device_data = device.socket_control.sockets[0] @property def is_on(self): """Return true if switch is on.""" - return self._switch_data.state + return self._device_data.state async def async_turn_off(self, **kwargs): """Instruct the switch to turn off.""" - await self._api(self._socket_control.set_state(False)) + await self._api(self._device_control.set_state(False)) async def async_turn_on(self, **kwargs): """Instruct the switch to turn on.""" - await self._api(self._socket_control.set_state(True)) - - @callback - def _async_start_observe(self, exc=None): - """Start observation of switch.""" - if exc: - self._available = False - self.async_schedule_update_ha_state() - _LOGGER.warning("Observation failed for %s", self._name, exc_info=exc) - - try: - cmd = self._switch.observe( - callback=self._observe_update, - err_callback=self._async_start_observe, - duration=0, - ) - self.hass.async_create_task(self._api(cmd)) - except PytradfriError as err: - _LOGGER.warning("Observation failed, trying again", exc_info=err) - self._async_start_observe() - - def _refresh(self, switch): - """Refresh the switch data.""" - self._switch = switch - - # Caching of switchControl and switch object - self._available = switch.reachable - self._socket_control = switch.socket_control - self._switch_data = switch.socket_control.sockets[0] - self._name = switch.name - - @callback - def _observe_update(self, tradfri_device): - """Receive new state data for this switch.""" - self._refresh(tradfri_device) - self.async_schedule_update_ha_state() + await self._api(self._device_control.set_state(True))