From 0dd1e0cabbd8e5a8282b19caab3dd32512cc8ddb Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Wed, 30 Jul 2025 09:06:15 +0200 Subject: [PATCH 1/8] Suppress exception stack trace when writing MQTT entity state if a ValueError occured (#149583) --- homeassistant/components/mqtt/models.py | 9 +++++++++ tests/components/mqtt/test_init.py | 27 +++++++++++++++++++------ 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/homeassistant/components/mqtt/models.py b/homeassistant/components/mqtt/models.py index 8a42797b0f2dfc..4cc0424195a5e2 100644 --- a/homeassistant/components/mqtt/models.py +++ b/homeassistant/components/mqtt/models.py @@ -364,6 +364,15 @@ def process_write_state_requests(self, msg: MQTTMessage) -> None: entity_id, entity = self.subscribe_calls.popitem() try: entity.async_write_ha_state() + except ValueError as exc: + _LOGGER.error( + "Value error while updating state of %s, topic: " + "'%s' with payload: %s: %s", + entity_id, + msg.topic, + msg.payload, + exc, + ) except Exception: _LOGGER.exception( "Exception raised while updating state of %s, topic: " diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index f789d7f3be132b..1aeb9843b54ecd 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -604,6 +604,23 @@ def test_entity_device_info_schema() -> None: ) +@pytest.mark.parametrize( + ("side_effect", "error_message"), + [ + ( + ValueError("Invalid value for sensor"), + "Value error while updating " + "state of sensor.test_sensor, topic: 'test/state' " + "with payload: b'payload causing errors'", + ), + ( + TypeError("Invalid value for sensor"), + "Exception raised while updating " + "state of sensor.test_sensor, topic: 'test/state' " + "with payload: b'payload causing errors'", + ), + ], +) @pytest.mark.parametrize( "hass_config", [ @@ -625,6 +642,8 @@ async def test_handle_logging_on_writing_the_entity_state( hass: HomeAssistant, mqtt_mock_entry: MqttMockHAClientGenerator, caplog: pytest.LogCaptureFixture, + side_effect: Exception, + error_message: str, ) -> None: """Test on log handling when an error occurs writing the state.""" await mqtt_mock_entry() @@ -637,7 +656,7 @@ async def test_handle_logging_on_writing_the_entity_state( assert state.state == "initial_state" with patch( "homeassistant.helpers.entity.Entity.async_write_ha_state", - side_effect=ValueError("Invalid value for sensor"), + side_effect=side_effect, ): async_fire_mqtt_message(hass, "test/state", b"payload causing errors") await hass.async_block_till_done() @@ -645,11 +664,7 @@ async def test_handle_logging_on_writing_the_entity_state( assert state is not None assert state.state == "initial_state" assert "Invalid value for sensor" in caplog.text - assert ( - "Exception raised while updating " - "state of sensor.test_sensor, topic: 'test/state' " - "with payload: b'payload causing errors'" in caplog.text - ) + assert error_message in caplog.text async def test_receiving_non_utf8_message_gets_logged( From 2ee82e1d6f548d04e5dcfd768b42ce5189e2e9a3 Mon Sep 17 00:00:00 2001 From: Robert Resch Date: Wed, 30 Jul 2025 09:24:16 +0200 Subject: [PATCH 2/8] Remove battery attribute from Ecovacs vacuums (#149581) --- homeassistant/components/ecovacs/vacuum.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/homeassistant/components/ecovacs/vacuum.py b/homeassistant/components/ecovacs/vacuum.py index d432410c8c5c86..86a30558375da4 100644 --- a/homeassistant/components/ecovacs/vacuum.py +++ b/homeassistant/components/ecovacs/vacuum.py @@ -8,7 +8,7 @@ from deebot_client.capabilities import Capabilities, DeviceType from deebot_client.device import Device -from deebot_client.events import BatteryEvent, FanSpeedEvent, RoomsEvent, StateEvent +from deebot_client.events import FanSpeedEvent, RoomsEvent, StateEvent from deebot_client.models import CleanAction, CleanMode, Room, State import sucks @@ -216,7 +216,6 @@ class EcovacsVacuum( VacuumEntityFeature.PAUSE | VacuumEntityFeature.STOP | VacuumEntityFeature.RETURN_HOME - | VacuumEntityFeature.BATTERY | VacuumEntityFeature.SEND_COMMAND | VacuumEntityFeature.LOCATE | VacuumEntityFeature.STATE @@ -243,10 +242,6 @@ async def async_added_to_hass(self) -> None: """Set up the event listeners now that hass is ready.""" await super().async_added_to_hass() - async def on_battery(event: BatteryEvent) -> None: - self._attr_battery_level = event.value - self.async_write_ha_state() - async def on_rooms(event: RoomsEvent) -> None: self._rooms = event.rooms self.async_write_ha_state() @@ -255,7 +250,6 @@ async def on_status(event: StateEvent) -> None: self._attr_activity = _STATE_TO_VACUUM_STATE[event.state] self.async_write_ha_state() - self._subscribe(self._capability.battery.event, on_battery) self._subscribe(self._capability.state.event, on_status) if self._capability.fan_speed: From f66e83f33ebe382696edc612c3580c2c3b156942 Mon Sep 17 00:00:00 2001 From: Marcel van der Veldt Date: Wed, 30 Jul 2025 09:54:00 +0200 Subject: [PATCH 3/8] Add dynamic encryption key support to the ESPHome integration (#148746) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Co-authored-by: J. Nick Koston --- .../components/esphome/config_flow.py | 30 +- .../esphome/encryption_key_storage.py | 94 ++++ homeassistant/components/esphome/manager.py | 98 +++- tests/components/esphome/test_config_flow.py | 115 +++++ .../esphome/test_dynamic_encryption.py | 102 ++++ tests/components/esphome/test_manager.py | 484 +++++++++++++++++- 6 files changed, 918 insertions(+), 5 deletions(-) create mode 100644 homeassistant/components/esphome/encryption_key_storage.py create mode 100644 tests/components/esphome/test_dynamic_encryption.py diff --git a/homeassistant/components/esphome/config_flow.py b/homeassistant/components/esphome/config_flow.py index 75408246e786f9..dc0e9b8e1b17ff 100644 --- a/homeassistant/components/esphome/config_flow.py +++ b/homeassistant/components/esphome/config_flow.py @@ -51,6 +51,7 @@ DOMAIN, ) from .dashboard import async_get_or_create_dashboard_manager, async_set_dashboard_info +from .encryption_key_storage import async_get_encryption_key_storage from .entry_data import ESPHomeConfigEntry from .manager import async_replace_device @@ -159,7 +160,10 @@ async def async_step_reauth_confirm( """Handle reauthorization flow.""" errors = {} - if await self._retrieve_encryption_key_from_dashboard(): + if ( + await self._retrieve_encryption_key_from_storage() + or await self._retrieve_encryption_key_from_dashboard() + ): error = await self.fetch_device_info() if error is None: return await self._async_authenticate_or_add() @@ -226,9 +230,12 @@ async def _async_try_fetch_device_info(self) -> ConfigFlowResult: response = await self.fetch_device_info() self._noise_psk = None + # Try to retrieve an existing key from dashboard or storage. if ( self._device_name and await self._retrieve_encryption_key_from_dashboard() + ) or ( + self._device_mac and await self._retrieve_encryption_key_from_storage() ): response = await self.fetch_device_info() @@ -284,6 +291,7 @@ async def async_step_zeroconf( self._name = discovery_info.properties.get("friendly_name", device_name) self._host = discovery_info.host self._port = discovery_info.port + self._device_mac = mac_address self._noise_required = bool(discovery_info.properties.get("api_encryption")) # Check if already configured @@ -772,6 +780,26 @@ async def _retrieve_encryption_key_from_dashboard(self) -> bool: self._noise_psk = noise_psk return True + async def _retrieve_encryption_key_from_storage(self) -> bool: + """Try to retrieve the encryption key from storage. + + Return boolean if a key was retrieved. + """ + # Try to get MAC address from current flow state or reauth entry + mac_address = self._device_mac + if mac_address is None and self._reauth_entry is not None: + # In reauth flow, get MAC from the existing entry's unique_id + mac_address = self._reauth_entry.unique_id + + assert mac_address is not None + + storage = await async_get_encryption_key_storage(self.hass) + if stored_key := await storage.async_get_key(mac_address): + self._noise_psk = stored_key + return True + + return False + @staticmethod @callback def async_get_options_flow( diff --git a/homeassistant/components/esphome/encryption_key_storage.py b/homeassistant/components/esphome/encryption_key_storage.py new file mode 100644 index 00000000000000..e4b5ef41c2e12c --- /dev/null +++ b/homeassistant/components/esphome/encryption_key_storage.py @@ -0,0 +1,94 @@ +"""Encryption key storage for ESPHome devices.""" + +from __future__ import annotations + +import logging +from typing import TypedDict + +from homeassistant.core import HomeAssistant +from homeassistant.helpers.json import JSONEncoder +from homeassistant.helpers.singleton import singleton +from homeassistant.helpers.storage import Store +from homeassistant.util.hass_dict import HassKey + +_LOGGER = logging.getLogger(__name__) + +ENCRYPTION_KEY_STORAGE_VERSION = 1 +ENCRYPTION_KEY_STORAGE_KEY = "esphome.encryption_keys" + + +class EncryptionKeyData(TypedDict): + """Encryption key storage data.""" + + keys: dict[str, str] # MAC address -> base64 encoded key + + +KEY_ENCRYPTION_STORAGE: HassKey[ESPHomeEncryptionKeyStorage] = HassKey( + "esphome_encryption_key_storage" +) + + +class ESPHomeEncryptionKeyStorage: + """Storage for ESPHome encryption keys.""" + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the encryption key storage.""" + self.hass = hass + self._store = Store[EncryptionKeyData]( + hass, + ENCRYPTION_KEY_STORAGE_VERSION, + ENCRYPTION_KEY_STORAGE_KEY, + encoder=JSONEncoder, + ) + self._data: EncryptionKeyData | None = None + + async def async_load(self) -> None: + """Load encryption keys from storage.""" + if self._data is None: + data = await self._store.async_load() + self._data = data or {"keys": {}} + + async def async_save(self) -> None: + """Save encryption keys to storage.""" + if self._data is not None: + await self._store.async_save(self._data) + + async def async_get_key(self, mac_address: str) -> str | None: + """Get encryption key for a MAC address.""" + await self.async_load() + assert self._data is not None + return self._data["keys"].get(mac_address.lower()) + + async def async_store_key(self, mac_address: str, key: str) -> None: + """Store encryption key for a MAC address.""" + await self.async_load() + assert self._data is not None + self._data["keys"][mac_address.lower()] = key + await self.async_save() + _LOGGER.debug( + "Stored encryption key for device with MAC %s", + mac_address, + ) + + async def async_remove_key(self, mac_address: str) -> None: + """Remove encryption key for a MAC address.""" + await self.async_load() + assert self._data is not None + lower_mac_address = mac_address.lower() + if lower_mac_address in self._data["keys"]: + del self._data["keys"][lower_mac_address] + await self.async_save() + _LOGGER.debug( + "Removed encryption key for device with MAC %s", + mac_address, + ) + + +@singleton(KEY_ENCRYPTION_STORAGE, async_=True) +async def async_get_encryption_key_storage( + hass: HomeAssistant, +) -> ESPHomeEncryptionKeyStorage: + """Get the encryption key storage instance.""" + storage = ESPHomeEncryptionKeyStorage(hass) + await storage.async_load() + return storage diff --git a/homeassistant/components/esphome/manager.py b/homeassistant/components/esphome/manager.py index 5e9e11171af52a..4d5de77b1e05d3 100644 --- a/homeassistant/components/esphome/manager.py +++ b/homeassistant/components/esphome/manager.py @@ -3,8 +3,10 @@ from __future__ import annotations import asyncio +import base64 from functools import partial import logging +import secrets from typing import TYPE_CHECKING, Any, NamedTuple from aioesphomeapi import ( @@ -68,6 +70,7 @@ CONF_ALLOW_SERVICE_CALLS, CONF_BLUETOOTH_MAC_ADDRESS, CONF_DEVICE_NAME, + CONF_NOISE_PSK, CONF_SUBSCRIBE_LOGS, DEFAULT_ALLOW_SERVICE_CALLS, DEFAULT_URL, @@ -78,6 +81,7 @@ ) from .dashboard import async_get_dashboard from .domain_data import DomainData +from .encryption_key_storage import async_get_encryption_key_storage # Import config flow so that it's added to the registry from .entry_data import ESPHomeConfigEntry, RuntimeEntryData @@ -85,9 +89,7 @@ DEVICE_CONFLICT_ISSUE_FORMAT = "device_conflict-{}" if TYPE_CHECKING: - from aioesphomeapi.api_pb2 import ( # type: ignore[attr-defined] - SubscribeLogsResponse, - ) + from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore[attr-defined] # noqa: I001 _LOGGER = logging.getLogger(__name__) @@ -515,6 +517,8 @@ async def _on_connect(self) -> None: assert api_version is not None, "API version must be set" entry_data.async_on_connect(device_info, api_version) + await self._handle_dynamic_encryption_key(device_info) + if device_info.name: reconnect_logic.name = device_info.name @@ -618,6 +622,7 @@ async def on_connect_error(self, err: Exception) -> None: ), ): return + if isinstance(err, InvalidEncryptionKeyAPIError): if ( (received_name := err.received_name) @@ -648,6 +653,93 @@ async def on_connect_error(self, err: Exception) -> None: return self.entry.async_start_reauth(self.hass) + async def _handle_dynamic_encryption_key( + self, device_info: EsphomeDeviceInfo + ) -> None: + """Handle dynamic encryption keys. + + If a device reports it supports encryption, but we connected without a key, + we need to generate and store one. + """ + noise_psk: str | None = self.entry.data.get(CONF_NOISE_PSK) + if noise_psk: + # we're already connected with a noise PSK - nothing to do + return + + if not device_info.api_encryption_supported: + # device does not support encryption - nothing to do + return + + # Connected to device without key and the device supports encryption + storage = await async_get_encryption_key_storage(self.hass) + + # First check if we have a key in storage for this device + from_storage: bool = False + if self.entry.unique_id and ( + stored_key := await storage.async_get_key(self.entry.unique_id) + ): + _LOGGER.debug( + "Retrieved encryption key from storage for device %s", + self.entry.unique_id, + ) + # Use the stored key + new_key = stored_key.encode() + new_key_str = stored_key + from_storage = True + else: + # No stored key found, generate a new one + _LOGGER.debug( + "Generating new encryption key for device %s", self.entry.unique_id + ) + new_key = base64.b64encode(secrets.token_bytes(32)) + new_key_str = new_key.decode() + + try: + # Store the key on the device using the existing connection + result = await self.cli.noise_encryption_set_key(new_key) + except APIConnectionError as ex: + _LOGGER.error( + "Connection error while storing encryption key for device %s (%s): %s", + self.entry.data.get(CONF_DEVICE_NAME, self.host), + self.entry.unique_id, + ex, + ) + return + else: + if not result: + _LOGGER.error( + "Failed to set dynamic encryption key on device %s (%s)", + self.entry.data.get(CONF_DEVICE_NAME, self.host), + self.entry.unique_id, + ) + return + + # Key stored successfully on device + assert self.entry.unique_id is not None + + # Only store in storage if it was newly generated + if not from_storage: + await storage.async_store_key(self.entry.unique_id, new_key_str) + + # Always update config entry + self.hass.config_entries.async_update_entry( + self.entry, + data={**self.entry.data, CONF_NOISE_PSK: new_key_str}, + ) + + if from_storage: + _LOGGER.info( + "Set encryption key from storage on device %s (%s)", + self.entry.data.get(CONF_DEVICE_NAME, self.host), + self.entry.unique_id, + ) + else: + _LOGGER.info( + "Generated and stored encryption key for device %s (%s)", + self.entry.data.get(CONF_DEVICE_NAME, self.host), + self.entry.unique_id, + ) + @callback def _async_handle_logging_changed(self, _event: Event) -> None: """Handle when the logging level changes.""" diff --git a/tests/components/esphome/test_config_flow.py b/tests/components/esphome/test_config_flow.py index 3f0148262e46de..d76991a984c24f 100644 --- a/tests/components/esphome/test_config_flow.py +++ b/tests/components/esphome/test_config_flow.py @@ -27,6 +27,9 @@ DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS, DOMAIN, ) +from homeassistant.components.esphome.encryption_key_storage import ( + ENCRYPTION_KEY_STORAGE_KEY, +) from homeassistant.config_entries import SOURCE_IGNORE, ConfigFlowResult from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT from homeassistant.core import HomeAssistant @@ -41,6 +44,118 @@ from tests.common import MockConfigEntry + +async def test_retrieve_encryption_key_from_storage_with_device_mac( + hass: HomeAssistant, + mock_client: APIClient, + hass_storage: dict[str, Any], +) -> None: + """Test key successfully retrieved from storage.""" + + # Mock the encryption key storage + hass_storage[ENCRYPTION_KEY_STORAGE_KEY] = { + "version": 1, + "minor_version": 1, + "key": ENCRYPTION_KEY_STORAGE_KEY, + "data": {"keys": {"11:22:33:44:55:aa": VALID_NOISE_PSK}}, + } + + mock_client.device_info.side_effect = [ + RequiresEncryptionAPIError, + InvalidEncryptionKeyAPIError("Wrong key", "test", "11:22:33:44:55:AA"), + DeviceInfo( + uses_password=False, + name="test", + mac_address="11:22:33:44:55:AA", + ), + ] + + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_USER}, + data={CONF_HOST: "127.0.0.1", CONF_PORT: 6053}, + ) + await hass.async_block_till_done() + + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["data"] == { + CONF_HOST: "127.0.0.1", + CONF_PORT: 6053, + CONF_PASSWORD: "", + CONF_NOISE_PSK: VALID_NOISE_PSK, + CONF_DEVICE_NAME: "test", + } + + assert mock_client.noise_psk == VALID_NOISE_PSK + + +async def test_reauth_fixed_from_from_storage( + hass: HomeAssistant, + mock_client: APIClient, + hass_storage: dict[str, Any], +) -> None: + """Test reauth fixed automatically via storage.""" + + # Mock the encryption key storage + hass_storage[ENCRYPTION_KEY_STORAGE_KEY] = { + "version": 1, + "minor_version": 1, + "key": ENCRYPTION_KEY_STORAGE_KEY, + "data": {"keys": {"11:22:33:44:55:aa": VALID_NOISE_PSK}}, + } + + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "127.0.0.1", + CONF_PORT: 6053, + CONF_PASSWORD: "", + CONF_DEVICE_NAME: "test", + }, + unique_id="11:22:33:44:55:aa", + ) + entry.add_to_hass(hass) + + mock_client.device_info.return_value = DeviceInfo( + uses_password=False, name="test", mac_address="11:22:33:44:55:aa" + ) + + result = await entry.start_reauth_flow(hass) + + assert result["type"] is FlowResultType.ABORT, result + assert result["reason"] == "reauth_successful" + assert entry.data[CONF_NOISE_PSK] == VALID_NOISE_PSK + + +async def test_retrieve_encryption_key_from_storage_no_key_found( + hass: HomeAssistant, + mock_client: APIClient, +) -> None: + """Test _retrieve_encryption_key_from_storage when no key is found.""" + + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "127.0.0.1", + CONF_PORT: 6053, + CONF_PASSWORD: "", + CONF_DEVICE_NAME: "test", + }, + unique_id="11:22:33:44:55:aa", + ) + entry.add_to_hass(hass) + + mock_client.device_info.return_value = DeviceInfo( + uses_password=False, name="test", mac_address="11:22:33:44:55:aa" + ) + + result = await entry.start_reauth_flow(hass) + + assert result["type"] is FlowResultType.FORM, result + assert result["step_id"] == "reauth_confirm" + assert CONF_NOISE_PSK not in entry.data + + INVALID_NOISE_PSK = "lSYBYEjQI1bVL8s2Vask4YytGMj1f1epNtmoim2yuTM=" WRONG_NOISE_PSK = "GP+ciK+nVfTQ/gcz6uOdS+oKEdJgesU+jeu8Ssj2how=" diff --git a/tests/components/esphome/test_dynamic_encryption.py b/tests/components/esphome/test_dynamic_encryption.py new file mode 100644 index 00000000000000..cbdcc35aea2e79 --- /dev/null +++ b/tests/components/esphome/test_dynamic_encryption.py @@ -0,0 +1,102 @@ +"""Tests for ESPHome dynamic encryption key generation.""" + +from __future__ import annotations + +import base64 + +from homeassistant.components.esphome.encryption_key_storage import ( + ESPHomeEncryptionKeyStorage, + async_get_encryption_key_storage, +) +from homeassistant.core import HomeAssistant + + +async def test_dynamic_encryption_key_generation_mock(hass: HomeAssistant) -> None: + """Test that encryption key generation works with mocked storage.""" + storage = await async_get_encryption_key_storage(hass) + + # Store a key + mac_address = "11:22:33:44:55:aa" + test_key = base64.b64encode(b"test_key_32_bytes_long_exactly!").decode() + + await storage.async_store_key(mac_address, test_key) + + # Retrieve a key + retrieved_key = await storage.async_get_key(mac_address) + assert retrieved_key == test_key + + +async def test_encryption_key_storage_remove_key(hass: HomeAssistant) -> None: + """Test ESPHomeEncryptionKeyStorage async_remove_key method.""" + # Create storage instance + storage = ESPHomeEncryptionKeyStorage(hass) + + # Test removing a key that exists + mac_address = "11:22:33:44:55:aa" + test_key = "test_encryption_key_32_bytes_long" + + # First store a key + await storage.async_store_key(mac_address, test_key) + + # Verify key exists + retrieved_key = await storage.async_get_key(mac_address) + assert retrieved_key == test_key + + # Remove the key + await storage.async_remove_key(mac_address) + + # Verify key no longer exists + retrieved_key = await storage.async_get_key(mac_address) + assert retrieved_key is None + + # Test removing a key that doesn't exist (should not raise an error) + non_existent_mac = "aa:bb:cc:dd:ee:ff" + await storage.async_remove_key(non_existent_mac) # Should not raise + + # Test case insensitive removal + upper_mac = "22:33:44:55:66:77" + await storage.async_store_key(upper_mac, test_key) + + # Remove using lowercase MAC address + await storage.async_remove_key(upper_mac.lower()) + + # Verify key was removed + retrieved_key = await storage.async_get_key(upper_mac) + assert retrieved_key is None + + +async def test_encryption_key_basic_storage( + hass: HomeAssistant, +) -> None: + """Test basic encryption key storage functionality.""" + storage = await async_get_encryption_key_storage(hass) + mac_address = "11:22:33:44:55:aa" + key = "test_encryption_key_32_bytes_long" + + # Store key + await storage.async_store_key(mac_address, key) + + # Retrieve key + retrieved_key = await storage.async_get_key(mac_address) + assert retrieved_key == key + + +async def test_retrieve_key_from_storage( + hass: HomeAssistant, +) -> None: + """Test config flow can retrieve encryption key from storage for new device.""" + # Test that the encryption key storage integration works with config flow + storage = await async_get_encryption_key_storage(hass) + mac_address = "11:22:33:44:55:aa" + stored_key = "test_encryption_key_32_bytes_long" + + # Store encryption key for a device + await storage.async_store_key(mac_address, stored_key) + + # Verify the key can be retrieved (simulating config flow behavior) + retrieved_key = await storage.async_get_key(mac_address) + assert retrieved_key == stored_key + + # Test case insensitive retrieval (since config flows might use different case) + retrieved_key_upper = await storage.async_get_key(mac_address.upper()) + assert retrieved_key_upper == stored_key diff --git a/tests/components/esphome/test_manager.py b/tests/components/esphome/test_manager.py index 318ccde221f78a..8d2dd211869a7d 100644 --- a/tests/components/esphome/test_manager.py +++ b/tests/components/esphome/test_manager.py @@ -1,8 +1,10 @@ """Test ESPHome manager.""" import asyncio +import base64 import logging -from unittest.mock import AsyncMock, Mock, call +from typing import Any +from unittest.mock import AsyncMock, Mock, call, patch from aioesphomeapi import ( APIClient, @@ -27,11 +29,15 @@ CONF_ALLOW_SERVICE_CALLS, CONF_BLUETOOTH_MAC_ADDRESS, CONF_DEVICE_NAME, + CONF_NOISE_PSK, CONF_SUBSCRIBE_LOGS, DOMAIN, STABLE_BLE_URL_VERSION, STABLE_BLE_VERSION_STR, ) +from homeassistant.components.esphome.encryption_key_storage import ( + ENCRYPTION_KEY_STORAGE_KEY, +) from homeassistant.components.esphome.manager import DEVICE_CONFLICT_ISSUE_FORMAT from homeassistant.components.tag import DOMAIN as TAG_DOMAIN from homeassistant.const import ( @@ -1788,3 +1794,479 @@ async def test_sub_device_references_main_device_area( ) assert sub_device_3 is not None assert sub_device_3.suggested_area == "Bedroom" + + +@patch("homeassistant.components.esphome.manager.secrets.token_bytes") +async def test_dynamic_encryption_key_generation( + mock_token_bytes: Mock, + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: MockESPHomeDeviceType, + hass_storage: dict[str, Any], +) -> None: + """Test that a device without a key in storage gets a new one generated.""" + mac_address = "11:22:33:44:55:aa" + test_key_bytes = b"test_key_32_bytes_long_exactly!" + mock_token_bytes.return_value = test_key_bytes + expected_key = base64.b64encode(test_key_bytes).decode() + + # Create entry without noise PSK + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "192.168.1.100", + CONF_PORT: 6053, + CONF_PASSWORD: "", + CONF_DEVICE_NAME: "test-device", + }, + unique_id=mac_address, + ) + entry.add_to_hass(hass) + + # Mock the client methods + mock_client.noise_encryption_set_key = AsyncMock(return_value=True) + + # Set up device with encryption support + device = await mock_esphome_device( + mock_client=mock_client, + entry=entry, + device_info={ + "uses_password": False, + "name": "test-device", + "mac_address": mac_address, + "esphome_version": "2023.12.0", + "api_encryption_supported": True, + }, + ) + + # Force reconnect to trigger key generation + await device.mock_disconnect(True) + await device.mock_connect() + + # Verify the key was generated and set + mock_token_bytes.assert_called_once_with(32) + mock_client.noise_encryption_set_key.assert_called_once() + + # Verify config entry was updated + assert entry.data[CONF_NOISE_PSK] == expected_key + + +async def test_manager_retrieves_key_from_storage_on_reconnect( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: MockESPHomeDeviceType, + hass_storage: dict[str, Any], +) -> None: + """Test that manager retrieves encryption key from storage during reconnect.""" + mac_address = "11:22:33:44:55:aa" + test_key = base64.b64encode(b"existing_key_32_bytes_long!!!").decode() + + # Set up storage with existing key + hass_storage[ENCRYPTION_KEY_STORAGE_KEY] = { + "version": 1, + "minor_version": 1, + "key": ENCRYPTION_KEY_STORAGE_KEY, + "data": {"keys": {mac_address: test_key}}, + } + + # Create entry without noise PSK (will be loaded from storage) + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "192.168.1.100", + CONF_PORT: 6053, + CONF_PASSWORD: "", + CONF_DEVICE_NAME: "test-device", + }, + unique_id=mac_address, + ) + entry.add_to_hass(hass) + + # Mock the client methods + mock_client.noise_encryption_set_key = AsyncMock(return_value=True) + + # Set up device with encryption support + device = await mock_esphome_device( + mock_client=mock_client, + entry=entry, + device_info={ + "uses_password": False, + "name": "test-device", + "mac_address": mac_address, + "esphome_version": "2023.12.0", + "api_encryption_supported": True, + }, + ) + + # Force reconnect to trigger key retrieval from storage + await device.mock_disconnect(True) + await device.mock_connect() + + # Verify noise_encryption_set_key was called with the stored key + mock_client.noise_encryption_set_key.assert_called_once_with(test_key.encode()) + + # Verify config entry was updated with key from storage + assert entry.data[CONF_NOISE_PSK] == test_key + + +async def test_manager_handle_dynamic_encryption_key_guard_clauses( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: MockESPHomeDeviceType, +) -> None: + """Test _handle_dynamic_encryption_key guard clauses and early returns.""" + # Test guard clause - no unique_id + entry_no_id = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "192.168.1.100", + CONF_PORT: 6053, + CONF_PASSWORD: "", + CONF_DEVICE_NAME: "test-device", + }, + unique_id=None, # No unique ID - should not generate key + ) + entry_no_id.add_to_hass(hass) + + # Set up device without unique ID + device = await mock_esphome_device( + mock_client=mock_client, + entry=entry_no_id, + device_info={ + "uses_password": False, + "name": "test-device", + "mac_address": "11:22:33:44:55:aa", + "esphome_version": "2023.12.0", + "api_encryption_supported": True, + }, + ) + + # noise_encryption_set_key should not be called when no unique_id + mock_client.noise_encryption_set_key = AsyncMock() + await device.mock_disconnect(True) + await device.mock_connect() + + mock_client.noise_encryption_set_key.assert_not_called() + + +async def test_manager_handle_dynamic_encryption_key_edge_cases( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: MockESPHomeDeviceType, +) -> None: + """Test _handle_dynamic_encryption_key edge cases for better coverage.""" + mac_address = "11:22:33:44:55:aa" + + # Test device without encryption support + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "192.168.1.100", + CONF_PORT: 6053, + CONF_PASSWORD: "", + CONF_DEVICE_NAME: "test-device", + }, + unique_id=mac_address, + ) + entry.add_to_hass(hass) + + # Set up device without encryption support + device = await mock_esphome_device( + mock_client=mock_client, + entry=entry, + device_info={ + "uses_password": False, + "name": "test-device", + "mac_address": mac_address, + "esphome_version": "2023.12.0", + "api_encryption_supported": False, # No encryption support + }, + ) + + # noise_encryption_set_key should not be called when encryption not supported + mock_client.noise_encryption_set_key = AsyncMock() + await device.mock_disconnect(True) + await device.mock_connect() + + mock_client.noise_encryption_set_key.assert_not_called() + + +@patch("homeassistant.components.esphome.manager.secrets.token_bytes") +async def test_manager_dynamic_encryption_key_generation_flow( + mock_token_bytes: Mock, + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: MockESPHomeDeviceType, + hass_storage: dict[str, Any], +) -> None: + """Test the complete dynamic encryption key generation flow.""" + mac_address = "11:22:33:44:55:aa" + test_key_bytes = b"test_key_32_bytes_long_exactly!" + mock_token_bytes.return_value = test_key_bytes + expected_key = base64.b64encode(test_key_bytes).decode() + + # Initialize empty storage + hass_storage[ENCRYPTION_KEY_STORAGE_KEY] = { + "version": 1, + "minor_version": 1, + "key": ENCRYPTION_KEY_STORAGE_KEY, + "data": { + "keys": {} # No existing keys + }, + } + + # Create entry without noise PSK + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "192.168.1.100", + CONF_PORT: 6053, + CONF_PASSWORD: "", + CONF_DEVICE_NAME: "test-device", + }, + unique_id=mac_address, + ) + entry.add_to_hass(hass) + + # Mock the client methods + mock_client.noise_encryption_set_key = AsyncMock(return_value=True) + + # Set up device with encryption support + device = await mock_esphome_device( + mock_client=mock_client, + entry=entry, + device_info={ + "uses_password": False, + "name": "test-device", + "mac_address": mac_address, + "esphome_version": "2023.12.0", + "api_encryption_supported": True, + }, + ) + + # Force reconnect to trigger key generation + await device.mock_disconnect(True) + await device.mock_connect() + + # Verify the complete flow + mock_token_bytes.assert_called_once_with(32) + mock_client.noise_encryption_set_key.assert_called_once() + assert entry.data[CONF_NOISE_PSK] == expected_key + + # Verify key was stored in hass_storage + assert ( + hass_storage[ENCRYPTION_KEY_STORAGE_KEY]["data"]["keys"][mac_address] + == expected_key + ) + + +@patch("homeassistant.components.esphome.manager.secrets.token_bytes") +async def test_manager_handle_dynamic_encryption_key_no_existing_key( + mock_token_bytes: Mock, + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: MockESPHomeDeviceType, + hass_storage: dict[str, Any], +) -> None: + """Test _handle_dynamic_encryption_key when no existing key is found.""" + mac_address = "11:22:33:44:55:aa" + test_key_bytes = b"test_key_32_bytes_long_exactly!" + mock_token_bytes.return_value = test_key_bytes + expected_key = base64.b64encode(test_key_bytes).decode() + + # Initialize empty storage + hass_storage[ENCRYPTION_KEY_STORAGE_KEY] = { + "version": 1, + "minor_version": 1, + "key": ENCRYPTION_KEY_STORAGE_KEY, + "data": { + "keys": {} # No existing keys + }, + } + + # Create entry without noise PSK + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "192.168.1.100", + CONF_PORT: 6053, + CONF_PASSWORD: "", + CONF_DEVICE_NAME: "test-device", + }, + unique_id=mac_address, + ) + entry.add_to_hass(hass) + + # Mock the client methods + mock_client.noise_encryption_set_key = AsyncMock(return_value=True) + + # Set up device with encryption support + device = await mock_esphome_device( + mock_client=mock_client, + entry=entry, + device_info={ + "uses_password": False, + "name": "test-device", + "mac_address": mac_address, + "esphome_version": "2023.12.0", + "api_encryption_supported": True, + }, + ) + + # Force reconnect to trigger key generation + await device.mock_disconnect(True) + await device.mock_connect() + + # Verify key generation flow + mock_token_bytes.assert_called_once_with(32) + mock_client.noise_encryption_set_key.assert_called_once() + + # Verify config entry was updated + assert entry.data[CONF_NOISE_PSK] == expected_key + + # Verify key was stored + assert ( + hass_storage[ENCRYPTION_KEY_STORAGE_KEY]["data"]["keys"][mac_address] + == expected_key + ) + + +@patch("homeassistant.components.esphome.manager.secrets.token_bytes") +async def test_manager_handle_dynamic_encryption_key_device_set_key_fails( + mock_token_bytes: Mock, + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: MockESPHomeDeviceType, + hass_storage: dict[str, Any], +) -> None: + """Test _handle_dynamic_encryption_key when noise_encryption_set_key returns False.""" + mac_address = "11:22:33:44:55:aa" + test_key_bytes = b"test_key_32_bytes_long_exactly!" + mock_token_bytes.return_value = test_key_bytes + + # Initialize empty storage + hass_storage[ENCRYPTION_KEY_STORAGE_KEY] = { + "version": 1, + "minor_version": 1, + "key": ENCRYPTION_KEY_STORAGE_KEY, + "data": { + "keys": {} # No existing keys + }, + } + + # Create entry without noise PSK + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "192.168.1.100", + CONF_PORT: 6053, + CONF_PASSWORD: "", + CONF_DEVICE_NAME: "test-device", + }, + unique_id=mac_address, + ) + entry.add_to_hass(hass) + + # Mock the client methods - set_key returns False + mock_client.noise_encryption_set_key = AsyncMock(return_value=False) + + # Set up device with encryption support + device = await mock_esphome_device( + mock_client=mock_client, + entry=entry, + device_info={ + "uses_password": False, + "name": "test-device", + "mac_address": mac_address, + "esphome_version": "2023.12.0", + "api_encryption_supported": True, + }, + ) + + # Reset mocks since initial connection already happened + mock_token_bytes.reset_mock() + mock_client.noise_encryption_set_key.reset_mock() + + # Force reconnect to trigger key generation + await device.mock_disconnect(True) + await device.mock_connect() + + # Verify key generation was attempted with the expected key + mock_token_bytes.assert_called_once_with(32) + mock_client.noise_encryption_set_key.assert_called_once_with( + base64.b64encode(test_key_bytes) + ) + + # Verify config entry was NOT updated since set_key failed + assert CONF_NOISE_PSK not in entry.data + + +@patch("homeassistant.components.esphome.manager.secrets.token_bytes") +async def test_manager_handle_dynamic_encryption_key_connection_error( + mock_token_bytes: Mock, + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: MockESPHomeDeviceType, + hass_storage: dict[str, Any], +) -> None: + """Test _handle_dynamic_encryption_key when noise_encryption_set_key raises APIConnectionError.""" + mac_address = "11:22:33:44:55:aa" + test_key_bytes = b"test_key_32_bytes_long_exactly!" + mock_token_bytes.return_value = test_key_bytes + + # Initialize empty storage + hass_storage[ENCRYPTION_KEY_STORAGE_KEY] = { + "version": 1, + "minor_version": 1, + "key": ENCRYPTION_KEY_STORAGE_KEY, + "data": { + "keys": {} # No existing keys + }, + } + + # Create entry without noise PSK + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "192.168.1.100", + CONF_PORT: 6053, + CONF_PASSWORD: "", + CONF_DEVICE_NAME: "test-device", + }, + unique_id=mac_address, + ) + entry.add_to_hass(hass) + + # Mock the client methods - set_key raises APIConnectionError + mock_client.noise_encryption_set_key = AsyncMock( + side_effect=APIConnectionError("Connection failed") + ) + + # Set up device with encryption support + device = await mock_esphome_device( + mock_client=mock_client, + entry=entry, + device_info={ + "uses_password": False, + "name": "test-device", + "mac_address": mac_address, + "esphome_version": "2023.12.0", + "api_encryption_supported": True, + }, + ) + + # Force reconnect to trigger key generation + await device.mock_disconnect(True) + await device.mock_connect() + + # Verify key generation was attempted twice (once during setup, once during reconnect) + # This is expected because the first attempt failed with connection error + assert mock_token_bytes.call_count == 2 + mock_token_bytes.assert_called_with(32) + assert mock_client.noise_encryption_set_key.call_count == 2 + + # Verify config entry was NOT updated since connection error occurred + assert CONF_NOISE_PSK not in entry.data + + # Verify key was NOT stored due to connection error + assert mac_address not in hass_storage[ENCRYPTION_KEY_STORAGE_KEY]["data"]["keys"] From 6f8214bbb47364d376cb45794248d11ea307bc74 Mon Sep 17 00:00:00 2001 From: Norbert Rittel Date: Wed, 30 Jul 2025 10:22:35 +0200 Subject: [PATCH 4/8] Fix spelling mistakes in abort message of `leaone` (#149653) --- homeassistant/components/leaone/strings.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/homeassistant/components/leaone/strings.json b/homeassistant/components/leaone/strings.json index bb6849411470c3..53332ce2fece06 100644 --- a/homeassistant/components/leaone/strings.json +++ b/homeassistant/components/leaone/strings.json @@ -13,7 +13,7 @@ } }, "abort": { - "no_devices_found": "No supported LeaOne devices found in range; If the device is in range, ensure it has been activated in the last few minutes. If you need clarification on whether the device is in-range, download the diagnostics for the integration that provides your Bluetooth adapter or proxy and check if the MAC address of the LeaOne device is present.", + "no_devices_found": "No supported LeaOne devices found in range. If the device is in range, ensure it has been activated in the last few minutes. If you need clarification on whether the device is in range, download the diagnostics for the integration that provides your Bluetooth adapter or proxy and check if the MAC address of the LeaOne device is present.", "already_in_progress": "[%key:common::config_flow::abort::already_in_progress%]", "already_configured": "[%key:common::config_flow::abort::already_configured_device%]" } From 6b641411a01e036a38e77c3361cbc5dec922753e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 30 Jul 2025 10:33:09 +0200 Subject: [PATCH 5/8] Bump github/codeql-action from 3.29.4 to 3.29.5 (#149648) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/codeql.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index cc6014b38b04aa..c5dcf19ce6ed01 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -24,11 +24,11 @@ jobs: uses: actions/checkout@v4.2.2 - name: Initialize CodeQL - uses: github/codeql-action/init@v3.29.4 + uses: github/codeql-action/init@v3.29.5 with: languages: python - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3.29.4 + uses: github/codeql-action/analyze@v3.29.5 with: category: "/language:python" From 8e9e304608e3a58d61e8d99164847e193c38a3af Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 30 Jul 2025 10:38:42 +0200 Subject: [PATCH 6/8] Update lxml to 6.0.0 (#149640) --- homeassistant/components/scrape/manifest.json | 2 +- requirements_all.txt | 2 +- requirements_test_all.txt | 2 +- script/hassfest/requirements.py | 1 + 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/homeassistant/components/scrape/manifest.json b/homeassistant/components/scrape/manifest.json index 28e08372d68a28..8b9d7ddf37e506 100644 --- a/homeassistant/components/scrape/manifest.json +++ b/homeassistant/components/scrape/manifest.json @@ -6,5 +6,5 @@ "config_flow": true, "documentation": "https://www.home-assistant.io/integrations/scrape", "iot_class": "cloud_polling", - "requirements": ["beautifulsoup4==4.13.3", "lxml==5.3.0"] + "requirements": ["beautifulsoup4==4.13.3", "lxml==6.0.0"] } diff --git a/requirements_all.txt b/requirements_all.txt index 1e2f4ec081ea82..eafa0b0d47f7e3 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -1391,7 +1391,7 @@ lupupy==0.3.2 lw12==0.9.2 # homeassistant.components.scrape -lxml==5.3.0 +lxml==6.0.0 # homeassistant.components.matrix matrix-nio==0.25.2 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index d42453d82fe1bc..b4ed33e539b0a8 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -1189,7 +1189,7 @@ luftdaten==0.7.4 lupupy==0.3.2 # homeassistant.components.scrape -lxml==5.3.0 +lxml==6.0.0 # homeassistant.components.matrix matrix-nio==0.25.2 diff --git a/script/hassfest/requirements.py b/script/hassfest/requirements.py index 9c3f60a827c7d9..99a1c255e60f74 100644 --- a/script/hassfest/requirements.py +++ b/script/hassfest/requirements.py @@ -30,6 +30,7 @@ "bleak": "SemVer", "grpcio": "SemVer", "httpx": "SemVer", + "lxml": "SemVer", "mashumaro": "SemVer", "numpy": "SemVer", "pandas": "SemVer", From bb6bcfdd0158a03923751b7e73a21ce132ed75ad Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Wed, 30 Jul 2025 11:07:41 +0200 Subject: [PATCH 7/8] Add Z-Wave controller firmware updates (#149623) --- homeassistant/components/zwave_js/__init__.py | 17 +- homeassistant/components/zwave_js/update.py | 128 +++- tests/components/zwave_js/test_update.py | 707 ++++++++++++------ 3 files changed, 581 insertions(+), 271 deletions(-) diff --git a/homeassistant/components/zwave_js/__init__.py b/homeassistant/components/zwave_js/__init__.py index d754419c94c698..360969e83d4747 100644 --- a/homeassistant/components/zwave_js/__init__.py +++ b/homeassistant/components/zwave_js/__init__.py @@ -147,6 +147,7 @@ }, extra=vol.ALLOW_EXTRA, ) +MIN_CONTROLLER_FIRMWARE_SDK_VERSION = AwesomeVersion("6.50.0") PLATFORMS = [ Platform.BINARY_SENSOR, @@ -799,11 +800,19 @@ async def async_on_node_ready(self, node: ZwaveNode) -> None: node.on("notification", self.async_on_notification) ) - # Create a firmware update entity for each non-controller device that + # Create a firmware update entity for each device that # supports firmware updates - if not node.is_controller_node and any( - cc.id == CommandClass.FIRMWARE_UPDATE_MD.value - for cc in node.command_classes + controller = self.controller_events.driver_events.driver.controller + if ( + not (is_controller_node := node.is_controller_node) + and any( + cc.id == CommandClass.FIRMWARE_UPDATE_MD.value + for cc in node.command_classes + ) + ) or ( + is_controller_node + and (sdk_version := controller.sdk_version) is not None + and sdk_version >= MIN_CONTROLLER_FIRMWARE_SDK_VERSION ): async_dispatcher_send( self.hass, diff --git a/homeassistant/components/zwave_js/update.py b/homeassistant/components/zwave_js/update.py index 89fb4dd4abaa5b..42a4b4cf6dd60a 100644 --- a/homeassistant/components/zwave_js/update.py +++ b/homeassistant/components/zwave_js/update.py @@ -4,26 +4,28 @@ import asyncio from collections import Counter -from collections.abc import Callable +from collections.abc import Awaitable, Callable from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, Final +from typing import Any, Final, cast from awesomeversion import AwesomeVersion from zwave_js_server.const import NodeStatus from zwave_js_server.exceptions import BaseZwaveJSServerError, FailedZWaveCommand from zwave_js_server.model.driver import Driver -from zwave_js_server.model.node import Node as ZwaveNode -from zwave_js_server.model.node.firmware import ( - NodeFirmwareUpdateInfo, - NodeFirmwareUpdateProgress, - NodeFirmwareUpdateResult, +from zwave_js_server.model.firmware import ( + FirmwareUpdateInfo, + FirmwareUpdateProgress, + FirmwareUpdateResult, ) +from zwave_js_server.model.node import Node as ZwaveNode +from zwave_js_server.model.node.firmware import NodeFirmwareUpdateInfo from homeassistant.components.update import ( ATTR_LATEST_VERSION, UpdateDeviceClass, UpdateEntity, + UpdateEntityDescription, UpdateEntityFeature, ) from homeassistant.const import EntityCategory @@ -45,11 +47,54 @@ ATTR_LATEST_VERSION_FIRMWARE = "latest_version_firmware" +@dataclass(frozen=True, kw_only=True) +class ZWaveUpdateEntityDescription(UpdateEntityDescription): + """Class describing Z-Wave update entity.""" + + install_method: Callable[ + [ZWaveFirmwareUpdateEntity, FirmwareUpdateInfo], + Awaitable[FirmwareUpdateResult], + ] + progress_method: Callable[[ZWaveFirmwareUpdateEntity], Callable[[], None]] + finished_method: Callable[[ZWaveFirmwareUpdateEntity], Callable[[], None]] + + +CONTROLLER_UPDATE_ENTITY_DESCRIPTION = ZWaveUpdateEntityDescription( + key="controller_firmware_update", + install_method=( + lambda entity, firmware_update_info: entity.driver.async_firmware_update_otw( + update_info=firmware_update_info + ) + ), + progress_method=lambda entity: entity.driver.on( + "firmware update progress", entity.update_progress + ), + finished_method=lambda entity: entity.driver.on( + "firmware update finished", entity.update_finished + ), +) +NODE_UPDATE_ENTITY_DESCRIPTION = ZWaveUpdateEntityDescription( + key="node_firmware_update", + install_method=( + lambda entity, + firmware_update_info: entity.driver.controller.async_firmware_update_ota( + entity.node, cast(NodeFirmwareUpdateInfo, firmware_update_info) + ) + ), + progress_method=lambda entity: entity.node.on( + "firmware update progress", entity.update_progress + ), + finished_method=lambda entity: entity.node.on( + "firmware update finished", entity.update_finished + ), +) + + @dataclass -class ZWaveNodeFirmwareUpdateExtraStoredData(ExtraStoredData): +class ZWaveFirmwareUpdateExtraStoredData(ExtraStoredData): """Extra stored data for Z-Wave node firmware update entity.""" - latest_version_firmware: NodeFirmwareUpdateInfo | None + latest_version_firmware: FirmwareUpdateInfo | None def as_dict(self) -> dict[str, Any]: """Return a dict representation of the extra data.""" @@ -60,7 +105,7 @@ def as_dict(self) -> dict[str, Any]: } @classmethod - def from_dict(cls, data: dict[str, Any]) -> ZWaveNodeFirmwareUpdateExtraStoredData: + def from_dict(cls, data: dict[str, Any]) -> ZWaveFirmwareUpdateExtraStoredData: """Initialize the extra data from a dict.""" # If there was no firmware info stored, or if it's stale info, we don't restore # anything. @@ -70,7 +115,7 @@ def from_dict(cls, data: dict[str, Any]) -> ZWaveNodeFirmwareUpdateExtraStoredDa ): return cls(None) - return cls(NodeFirmwareUpdateInfo.from_dict(firmware_dict)) + return cls(FirmwareUpdateInfo.from_dict(firmware_dict)) async def async_setup_entry( @@ -92,7 +137,23 @@ def async_add_firmware_update_entity(node: ZwaveNode) -> None: delay = timedelta(minutes=(cnt[UPDATE_DELAY_STRING] * UPDATE_DELAY_INTERVAL)) driver = client.driver assert driver is not None # Driver is ready before platforms are loaded. - async_add_entities([ZWaveNodeFirmwareUpdate(driver, node, delay)]) + if node.is_controller_node: + # If the node is a controller, we create a controller firmware update entity + entity = ZWaveFirmwareUpdateEntity( + driver, + node, + delay=delay, + entity_description=CONTROLLER_UPDATE_ENTITY_DESCRIPTION, + ) + else: + # If the node is not a controller, we create a node firmware update entity + entity = ZWaveFirmwareUpdateEntity( + driver, + node, + delay=delay, + entity_description=NODE_UPDATE_ENTITY_DESCRIPTION, + ) + async_add_entities([entity]) config_entry.async_on_unload( async_dispatcher_connect( @@ -103,9 +164,12 @@ def async_add_firmware_update_entity(node: ZwaveNode) -> None: ) -class ZWaveNodeFirmwareUpdate(UpdateEntity): +class ZWaveFirmwareUpdateEntity(UpdateEntity): """Representation of a firmware update entity.""" + driver: Driver + entity_description: ZWaveUpdateEntityDescription + node: ZwaveNode _attr_entity_category = EntityCategory.CONFIG _attr_device_class = UpdateDeviceClass.FIRMWARE _attr_supported_features = ( @@ -116,17 +180,24 @@ class ZWaveNodeFirmwareUpdate(UpdateEntity): _attr_has_entity_name = True _attr_should_poll = False - def __init__(self, driver: Driver, node: ZwaveNode, delay: timedelta) -> None: + def __init__( + self, + driver: Driver, + node: ZwaveNode, + delay: timedelta, + entity_description: ZWaveUpdateEntityDescription, + ) -> None: """Initialize a Z-Wave device firmware update entity.""" self.driver = driver + self.entity_description = entity_description self.node = node - self._latest_version_firmware: NodeFirmwareUpdateInfo | None = None + self._latest_version_firmware: FirmwareUpdateInfo | None = None self._status_unsub: Callable[[], None] | None = None self._poll_unsub: Callable[[], None] | None = None self._progress_unsub: Callable[[], None] | None = None self._finished_unsub: Callable[[], None] | None = None self._finished_event = asyncio.Event() - self._result: NodeFirmwareUpdateResult | None = None + self._result: FirmwareUpdateResult | None = None self._delay: Final[timedelta] = delay # Entity class attributes @@ -138,9 +209,9 @@ def __init__(self, driver: Driver, node: ZwaveNode, delay: timedelta) -> None: self._attr_device_info = get_device_info(driver, node) @property - def extra_restore_state_data(self) -> ZWaveNodeFirmwareUpdateExtraStoredData: + def extra_restore_state_data(self) -> ZWaveFirmwareUpdateExtraStoredData: """Return ZWave Node Firmware Update specific state data to be restored.""" - return ZWaveNodeFirmwareUpdateExtraStoredData(self._latest_version_firmware) + return ZWaveFirmwareUpdateExtraStoredData(self._latest_version_firmware) @callback def _update_on_status_change(self, _: dict[str, Any]) -> None: @@ -149,9 +220,9 @@ def _update_on_status_change(self, _: dict[str, Any]) -> None: self.hass.async_create_task(self._async_update()) @callback - def _update_progress(self, event: dict[str, Any]) -> None: + def update_progress(self, event: dict[str, Any]) -> None: """Update install progress on event.""" - progress: NodeFirmwareUpdateProgress = event["firmware_update_progress"] + progress: FirmwareUpdateProgress = event["firmware_update_progress"] if not self._latest_version_firmware: return self._attr_in_progress = True @@ -159,9 +230,9 @@ def _update_progress(self, event: dict[str, Any]) -> None: self.async_write_ha_state() @callback - def _update_finished(self, event: dict[str, Any]) -> None: + def update_finished(self, event: dict[str, Any]) -> None: """Update install progress on event.""" - result: NodeFirmwareUpdateResult = event["firmware_update_finished"] + result: FirmwareUpdateResult = event["firmware_update_finished"] self._result = result self._finished_event.set() @@ -266,15 +337,11 @@ async def async_install( self._attr_update_percentage = None self.async_write_ha_state() - self._progress_unsub = self.node.on( - "firmware update progress", self._update_progress - ) - self._finished_unsub = self.node.on( - "firmware update finished", self._update_finished - ) + self._progress_unsub = self.entity_description.progress_method(self) + self._finished_unsub = self.entity_description.finished_method(self) try: - await self.driver.controller.async_firmware_update_ota(self.node, firmware) + await self.entity_description.install_method(self, firmware) except BaseZwaveJSServerError as err: self._unsub_firmware_events_and_reset_progress() raise HomeAssistantError(err) from err @@ -342,8 +409,7 @@ async def async_added_to_hass(self) -> None: is not None and (extra_data := await self.async_get_last_extra_data()) and ( - latest_version_firmware - := ZWaveNodeFirmwareUpdateExtraStoredData.from_dict( + latest_version_firmware := ZWaveFirmwareUpdateExtraStoredData.from_dict( extra_data.as_dict() ).latest_version_firmware ) diff --git a/tests/components/zwave_js/test_update.py b/tests/components/zwave_js/test_update.py index 17f154f4f78657..fbe0a8bbea7c5d 100644 --- a/tests/components/zwave_js/test_update.py +++ b/tests/components/zwave_js/test_update.py @@ -1,12 +1,17 @@ """Test the Z-Wave JS update entities.""" import asyncio +from copy import deepcopy from datetime import timedelta +from typing import Any +from unittest.mock import MagicMock from freezegun.api import FrozenDateTimeFactory import pytest from zwave_js_server.event import Event from zwave_js_server.exceptions import FailedZWaveCommand +from zwave_js_server.model.driver.firmware import DriverFirmwareUpdateStatus +from zwave_js_server.model.node import Node from zwave_js_server.model.node.firmware import NodeFirmwareUpdateStatus from homeassistant.components.update import ( @@ -22,11 +27,16 @@ SERVICE_SKIP, ) from homeassistant.components.zwave_js.const import DOMAIN, SERVICE_REFRESH_VALUE -from homeassistant.components.zwave_js.helpers import get_valueless_base_unique_id -from homeassistant.const import ATTR_ENTITY_ID, STATE_OFF, STATE_ON, STATE_UNKNOWN +from homeassistant.const import ( + ATTR_ENTITY_ID, + STATE_OFF, + STATE_ON, + STATE_UNKNOWN, + Platform, +) from homeassistant.core import CoreState, HomeAssistant, State from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import entity_registry as er +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.util import dt as dt_util from tests.common import ( @@ -37,7 +47,8 @@ ) from tests.typing import WebSocketGenerator -UPDATE_ENTITY = "update.z_wave_thermostat_firmware" +NODE_UPDATE_ENTITY = "update.z_wave_thermostat_firmware" +CONTROLLER_UPDATE_ENTITY = "update.z_stick_gen5_usb_controller_firmware" LATEST_VERSION_FIRMWARE = { "version": "11.2.4", "changelog": "blah 2", @@ -112,26 +123,54 @@ } +@pytest.fixture +def platforms() -> list[str]: + """Fixture to specify platforms to test.""" + return [Platform.UPDATE] + + +@pytest.fixture(name="controller_state", autouse=True) +def controller_state_fixture( + controller_state: dict[str, Any], +) -> dict[str, Any]: + """Load the controller state fixture data.""" + controller_state = deepcopy(controller_state) + # Set the minimum SDK version that supports firmware updates for controllers. + controller_state["controller"]["sdkVersion"] = "6.50.0" + return controller_state + + +@pytest.mark.parametrize( + ("entity_id", "installed_version"), + [(CONTROLLER_UPDATE_ENTITY, "1.2"), (NODE_UPDATE_ENTITY, "10.7")], +) async def test_update_entity_states( hass: HomeAssistant, + device_registry: dr.DeviceRegistry, entity_registry: er.EntityRegistry, - client, - climate_radio_thermostat_ct100_plus_different_endpoints, - integration, + client: MagicMock, + climate_radio_thermostat_ct100_plus_different_endpoints: Node, + integration: MockConfigEntry, caplog: pytest.LogCaptureFixture, hass_ws_client: WebSocketGenerator, + entity_id: str, + installed_version: str, ) -> None: """Test update entity states.""" ws_client = await hass_ws_client(hass) - assert hass.states.get(UPDATE_ENTITY).state == STATE_OFF + assert client.driver.controller.sdk_version == "6.50.0" + + state = hass.states.get(entity_id) + assert state + assert state.state == STATE_OFF client.async_send_command.return_value = {"updates": []} async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5, days=1)) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_OFF @@ -139,7 +178,7 @@ async def test_update_entity_states( { "id": 1, "type": "update/release_notes", - "entity_id": UPDATE_ENTITY, + "entity_id": entity_id, } ) result = await ws_client.receive_json() @@ -150,12 +189,12 @@ async def test_update_entity_states( async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5, days=2)) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_ON attrs = state.attributes assert not attrs[ATTR_AUTO_UPDATE] - assert attrs[ATTR_INSTALLED_VERSION] == "10.7" + assert attrs[ATTR_INSTALLED_VERSION] == installed_version assert attrs[ATTR_IN_PROGRESS] is False assert attrs[ATTR_LATEST_VERSION] == "11.2.4" assert attrs[ATTR_RELEASE_URL] is None @@ -165,7 +204,7 @@ async def test_update_entity_states( { "id": 2, "type": "update/release_notes", - "entity_id": UPDATE_ENTITY, + "entity_id": entity_id, } ) result = await ws_client.receive_json() @@ -176,7 +215,7 @@ async def test_update_entity_states( DOMAIN, SERVICE_REFRESH_VALUE, { - ATTR_ENTITY_ID: UPDATE_ENTITY, + ATTR_ENTITY_ID: entity_id, }, blocking=True, ) @@ -188,31 +227,21 @@ async def test_update_entity_states( async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5, days=3)) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_OFF - # Assert a node firmware update entity is not created for the controller - driver = client.driver - node = driver.controller.nodes[1] - assert node.is_controller_node - assert ( - entity_registry.async_get_entity_id( - DOMAIN, - "sensor", - f"{get_valueless_base_unique_id(driver, node)}.firmware_update", - ) - is None - ) - - client.async_send_command.reset_mock() - +@pytest.mark.parametrize( + "entity_id", + [CONTROLLER_UPDATE_ENTITY, NODE_UPDATE_ENTITY], +) async def test_update_entity_install_raises( hass: HomeAssistant, - client, - climate_radio_thermostat_ct100_plus_different_endpoints, - integration, + client: MagicMock, + climate_radio_thermostat_ct100_plus_different_endpoints: Node, + integration: MockConfigEntry, + entity_id: str, ) -> None: """Test update entity install raises exception.""" client.async_send_command.return_value = FIRMWARE_UPDATES @@ -228,7 +257,7 @@ async def test_update_entity_install_raises( UPDATE_DOMAIN, SERVICE_INSTALL, { - ATTR_ENTITY_ID: UPDATE_ENTITY, + ATTR_ENTITY_ID: entity_id, }, blocking=True, ) @@ -236,9 +265,9 @@ async def test_update_entity_install_raises( async def test_update_entity_sleep( hass: HomeAssistant, - client, - zen_31, - integration, + client: MagicMock, + zen_31: Node, + integration: MockConfigEntry, ) -> None: """Test update occurs when device is asleep after it wakes up.""" event = Event( @@ -253,8 +282,15 @@ async def test_update_entity_sleep( async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5, days=1)) await hass.async_block_till_done() - # Because node is asleep we shouldn't attempt to check for firmware updates - assert len(client.async_send_command.call_args_list) == 0 + # Two nodes in total, the controller node and the zen_31 node. + # The zen_31 node is asleep, + # so we should only check for updates for the controller node. + assert client.async_send_command.call_count == 1 + args = client.async_send_command.call_args[0][0] + assert args["command"] == "controller.get_available_firmware_updates" + assert args["nodeId"] == 1 + + client.async_send_command.reset_mock() event = Event( "wake up", @@ -263,19 +299,20 @@ async def test_update_entity_sleep( zen_31.receive_event(event) await hass.async_block_till_done() - # Now that the node is up we can check for updates - assert len(client.async_send_command.call_args_list) > 0 - - args = client.async_send_command.call_args_list[0][0][0] + # Now that the zen_31 node is awake we can check for updates for it. + # The controller node has already been checked, + # so won't get another check now. + assert client.async_send_command.call_count == 1 + args = client.async_send_command.call_args[0][0] assert args["command"] == "controller.get_available_firmware_updates" - assert args["nodeId"] == zen_31.node_id + assert args["nodeId"] == 94 async def test_update_entity_dead( hass: HomeAssistant, - client, - zen_31, - integration, + client: MagicMock, + zen_31: Node, + integration: MockConfigEntry, ) -> None: """Test update occurs even when device is dead.""" event = Event( @@ -290,18 +327,24 @@ async def test_update_entity_dead( async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5, days=1)) await hass.async_block_till_done() - # Checking for firmware updates should proceed even for dead nodes - assert len(client.async_send_command.call_args_list) > 0 + # Two nodes in total, the controller node and the zen_31 node. + # Checking for firmware updates should proceed even for dead nodes. + assert client.async_send_command.call_count == 2 + calls = sorted( + client.async_send_command.call_args_list, key=lambda call: call[0][0]["nodeId"] + ) - args = client.async_send_command.call_args_list[0][0][0] - assert args["command"] == "controller.get_available_firmware_updates" - assert args["nodeId"] == zen_31.node_id + node_ids = (1, 94) + for node_id, call in zip(node_ids, calls, strict=True): + args = call[0][0] + assert args["command"] == "controller.get_available_firmware_updates" + assert args["nodeId"] == node_id async def test_update_entity_ha_not_running( hass: HomeAssistant, - client, - zen_31, + client: MagicMock, + zen_31: Node, hass_ws_client: WebSocketGenerator, ) -> None: """Test update occurs only after HA is running.""" @@ -314,81 +357,170 @@ async def test_update_entity_ha_not_running( await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() - assert len(client.async_send_command.call_args_list) == 4 + client.async_send_command.reset_mock() + assert client.async_send_command.call_count == 0 await hass.async_start() await hass.async_block_till_done() - assert len(client.async_send_command.call_args_list) == 4 + assert client.async_send_command.call_count == 0 - # Update should be delayed by a day because HA is not running + # Update should be delayed by a day because Home Assistant is not running hass.set_state(CoreState.starting) async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5)) await hass.async_block_till_done() - assert len(client.async_send_command.call_args_list) == 4 + assert client.async_send_command.call_count == 0 hass.set_state(CoreState.running) async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5, days=1)) await hass.async_block_till_done() - assert len(client.async_send_command.call_args_list) == 5 - args = client.async_send_command.call_args_list[4][0][0] - assert args["command"] == "controller.get_available_firmware_updates" - assert args["nodeId"] == zen_31.node_id + # Two nodes in total, the controller node and the zen_31 node. + assert client.async_send_command.call_count == 2 + calls = sorted( + client.async_send_command.call_args_list, key=lambda call: call[0][0]["nodeId"] + ) + + node_ids = (1, 94) + for node_id, call in zip(node_ids, calls, strict=True): + args = call[0][0] + assert args["command"] == "controller.get_available_firmware_updates" + assert args["nodeId"] == node_id async def test_update_entity_update_failure( hass: HomeAssistant, - client, - climate_radio_thermostat_ct100_plus_different_endpoints, - integration, + client: MagicMock, + climate_radio_thermostat_ct100_plus_different_endpoints: Node, + integration: MockConfigEntry, ) -> None: """Test update entity update failed.""" - assert len(client.async_send_command.call_args_list) == 0 + assert client.async_send_command.call_count == 0 client.async_send_command.side_effect = FailedZWaveCommand("test", 260, "test") async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5, days=1)) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) - assert state - assert state.state == STATE_OFF - assert len(client.async_send_command.call_args_list) == 1 - args = client.async_send_command.call_args_list[0][0][0] - assert args["command"] == "controller.get_available_firmware_updates" - assert ( - args["nodeId"] - == climate_radio_thermostat_ct100_plus_different_endpoints.node_id - ) + entity_ids = (CONTROLLER_UPDATE_ENTITY, NODE_UPDATE_ENTITY) + for entity_id in entity_ids: + state = hass.states.get(entity_id) + assert state + assert state.state == STATE_OFF + assert client.async_send_command.call_count == 2 + calls = sorted( + client.async_send_command.call_args_list, key=lambda call: call[0][0]["nodeId"] + ) + node_ids = (1, 26) + for node_id, call in zip(node_ids, calls, strict=True): + args = call[0][0] + assert args["command"] == "controller.get_available_firmware_updates" + assert args["nodeId"] == node_id + + +@pytest.mark.parametrize( + ( + "entity_id", + "installed_version", + "install_result", + "progress_event", + "finished_event", + ), + [ + ( + CONTROLLER_UPDATE_ENTITY, + "1.2", + {"status": 255, "success": True}, + Event( + type="firmware update progress", + data={ + "source": "driver", + "event": "firmware update progress", + "progress": { + "sentFragments": 1, + "totalFragments": 20, + "progress": 5.0, + }, + }, + ), + Event( + type="firmware update finished", + data={ + "source": "driver", + "event": "firmware update finished", + "result": { + "status": DriverFirmwareUpdateStatus.OK, + "success": True, + }, + }, + ), + ), + ( + NODE_UPDATE_ENTITY, + "10.7", + {"status": 254, "success": True, "reInterview": False}, + Event( + type="firmware update progress", + data={ + "source": "node", + "event": "firmware update progress", + "nodeId": 26, + "progress": { + "currentFile": 1, + "totalFiles": 1, + "sentFragments": 1, + "totalFragments": 20, + "progress": 5.0, + }, + }, + ), + Event( + type="firmware update finished", + data={ + "source": "node", + "event": "firmware update finished", + "nodeId": 26, + "result": { + "status": NodeFirmwareUpdateStatus.OK_NO_RESTART, + "success": True, + "reInterview": False, + }, + }, + ), + ), + ], +) async def test_update_entity_progress( hass: HomeAssistant, - client, - climate_radio_thermostat_ct100_plus_different_endpoints, - integration, + client: MagicMock, + climate_radio_thermostat_ct100_plus_different_endpoints: Node, + integration: MockConfigEntry, + entity_id: str, + installed_version: str, + install_result: dict[str, Any], + progress_event: Event, + finished_event: Event, ) -> None: """Test update entity progress.""" - node = climate_radio_thermostat_ct100_plus_different_endpoints client.async_send_command.return_value = FIRMWARE_UPDATES + driver = client.driver async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5, days=1)) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_ON attrs = state.attributes - assert attrs[ATTR_INSTALLED_VERSION] == "10.7" + assert attrs[ATTR_INSTALLED_VERSION] == installed_version assert attrs[ATTR_LATEST_VERSION] == "11.2.4" client.async_send_command.reset_mock() - client.async_send_command.return_value = { - "result": {"status": 2, "success": False, "reInterview": False} - } + client.async_send_command.return_value = {"result": install_result} # Test successful install call without a version install_task = hass.async_create_task( @@ -396,64 +528,36 @@ async def test_update_entity_progress( UPDATE_DOMAIN, SERVICE_INSTALL, { - ATTR_ENTITY_ID: UPDATE_ENTITY, + ATTR_ENTITY_ID: entity_id, }, blocking=True, ) ) # Sleep so that task starts - await asyncio.sleep(0.1) + await asyncio.sleep(0.05) - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state attrs = state.attributes assert attrs[ATTR_IN_PROGRESS] is True assert attrs[ATTR_UPDATE_PERCENTAGE] is None - event = Event( - type="firmware update progress", - data={ - "source": "node", - "event": "firmware update progress", - "nodeId": node.node_id, - "progress": { - "currentFile": 1, - "totalFiles": 1, - "sentFragments": 1, - "totalFragments": 20, - "progress": 5.0, - }, - }, - ) - node.receive_event(event) + driver.receive_event(progress_event) + await asyncio.sleep(0.05) # Validate that the progress is updated - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state attrs = state.attributes assert attrs[ATTR_IN_PROGRESS] is True assert attrs[ATTR_UPDATE_PERCENTAGE] == 5 - event = Event( - type="firmware update finished", - data={ - "source": "node", - "event": "firmware update finished", - "nodeId": node.node_id, - "result": { - "status": NodeFirmwareUpdateStatus.OK_NO_RESTART, - "success": True, - "reInterview": False, - }, - }, - ) - - node.receive_event(event) + driver.receive_event(finished_event) await hass.async_block_till_done() # Validate that progress is reset and entity reflects new version - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state attrs = state.attributes assert attrs[ATTR_IN_PROGRESS] is False @@ -465,31 +569,106 @@ async def test_update_entity_progress( await install_task +@pytest.mark.parametrize( + ( + "entity_id", + "installed_version", + "install_result", + "progress_event", + "finished_event", + ), + [ + ( + CONTROLLER_UPDATE_ENTITY, + "1.2", + {"status": 0, "success": False}, + Event( + type="firmware update progress", + data={ + "source": "driver", + "event": "firmware update progress", + "progress": { + "sentFragments": 1, + "totalFragments": 20, + "progress": 5.0, + }, + }, + ), + Event( + type="firmware update finished", + data={ + "source": "driver", + "event": "firmware update finished", + "result": { + "status": DriverFirmwareUpdateStatus.ERROR_TIMEOUT, + "success": False, + }, + }, + ), + ), + ( + NODE_UPDATE_ENTITY, + "10.7", + {"status": -1, "success": False, "reInterview": False}, + Event( + type="firmware update progress", + data={ + "source": "node", + "event": "firmware update progress", + "nodeId": 26, + "progress": { + "currentFile": 1, + "totalFiles": 1, + "sentFragments": 1, + "totalFragments": 20, + "progress": 5.0, + }, + }, + ), + Event( + type="firmware update finished", + data={ + "source": "node", + "event": "firmware update finished", + "nodeId": 26, + "result": { + "status": NodeFirmwareUpdateStatus.ERROR_TIMEOUT, + "success": False, + "reInterview": False, + }, + }, + ), + ), + ], +) async def test_update_entity_install_failed( hass: HomeAssistant, - client, - climate_radio_thermostat_ct100_plus_different_endpoints, - integration, + client: MagicMock, + climate_radio_thermostat_ct100_plus_different_endpoints: Node, + integration: MockConfigEntry, caplog: pytest.LogCaptureFixture, + entity_id: str, + installed_version: str, + install_result: dict[str, Any], + progress_event: Event, + finished_event: Event, ) -> None: """Test update entity install returns error status.""" - node = climate_radio_thermostat_ct100_plus_different_endpoints + driver = client.driver client.async_send_command.return_value = FIRMWARE_UPDATES async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5, days=1)) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_ON attrs = state.attributes - assert attrs[ATTR_INSTALLED_VERSION] == "10.7" + assert attrs[ATTR_INSTALLED_VERSION] == installed_version assert attrs[ATTR_LATEST_VERSION] == "11.2.4" client.async_send_command.reset_mock() - client.async_send_command.return_value = { - "result": {"status": 2, "success": False, "reInterview": False} - } + client.async_send_command.return_value = {"result": install_result} # Test install call - we expect it to finish fail install_task = hass.async_create_task( @@ -497,63 +676,35 @@ async def test_update_entity_install_failed( UPDATE_DOMAIN, SERVICE_INSTALL, { - ATTR_ENTITY_ID: UPDATE_ENTITY, + ATTR_ENTITY_ID: entity_id, }, blocking=True, ) ) # Sleep so that task starts - await asyncio.sleep(0.1) + await asyncio.sleep(0.05) - event = Event( - type="firmware update progress", - data={ - "source": "node", - "event": "firmware update progress", - "nodeId": node.node_id, - "progress": { - "currentFile": 1, - "totalFiles": 1, - "sentFragments": 1, - "totalFragments": 20, - "progress": 5.0, - }, - }, - ) - node.receive_event(event) + driver.receive_event(progress_event) + await asyncio.sleep(0.05) # Validate that the progress is updated - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state attrs = state.attributes assert attrs[ATTR_IN_PROGRESS] is True assert attrs[ATTR_UPDATE_PERCENTAGE] == 5 - event = Event( - type="firmware update finished", - data={ - "source": "node", - "event": "firmware update finished", - "nodeId": node.node_id, - "result": { - "status": NodeFirmwareUpdateStatus.ERROR_TIMEOUT, - "success": False, - "reInterview": False, - }, - }, - ) - - node.receive_event(event) + driver.receive_event(finished_event) await hass.async_block_till_done() # Validate that progress is reset and entity reflects old version - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state attrs = state.attributes assert attrs[ATTR_IN_PROGRESS] is False assert attrs[ATTR_UPDATE_PERCENTAGE] is None - assert attrs[ATTR_INSTALLED_VERSION] == "10.7" + assert attrs[ATTR_INSTALLED_VERSION] == installed_version assert attrs[ATTR_LATEST_VERSION] == "11.2.4" assert state.state == STATE_ON @@ -562,21 +713,30 @@ async def test_update_entity_install_failed( await install_task +@pytest.mark.parametrize( + ("entity_id", "installed_version"), + [(CONTROLLER_UPDATE_ENTITY, "1.2"), (NODE_UPDATE_ENTITY, "10.7")], +) async def test_update_entity_reload( hass: HomeAssistant, - client, - climate_radio_thermostat_ct100_plus_different_endpoints, - integration, + client: MagicMock, + climate_radio_thermostat_ct100_plus_different_endpoints: Node, + integration: MockConfigEntry, + entity_id: str, + installed_version: str, ) -> None: """Test update entity maintains state after reload.""" - assert hass.states.get(UPDATE_ENTITY).state == STATE_OFF + config_entry = integration + state = hass.states.get(entity_id) + assert state + assert state.state == STATE_OFF client.async_send_command.return_value = {"updates": []} async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5, days=1)) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_OFF @@ -585,12 +745,12 @@ async def test_update_entity_reload( async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5, days=2)) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_ON attrs = state.attributes assert not attrs[ATTR_AUTO_UPDATE] - assert attrs[ATTR_INSTALLED_VERSION] == "10.7" + assert attrs[ATTR_INSTALLED_VERSION] == installed_version assert attrs[ATTR_IN_PROGRESS] is False assert attrs[ATTR_UPDATE_PERCENTAGE] is None assert attrs[ATTR_LATEST_VERSION] == "11.2.4" @@ -600,24 +760,24 @@ async def test_update_entity_reload( UPDATE_DOMAIN, SERVICE_SKIP, { - ATTR_ENTITY_ID: UPDATE_ENTITY, + ATTR_ENTITY_ID: entity_id, }, blocking=True, ) - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_OFF assert state.attributes[ATTR_SKIPPED_VERSION] == "11.2.4" - await hass.config_entries.async_reload(integration.entry_id) + await hass.config_entries.async_reload(config_entry.entry_id) await hass.async_block_till_done() # Trigger another update and make sure the skipped version is still skipped async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5, days=4)) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_OFF assert state.attributes[ATTR_SKIPPED_VERSION] == "11.2.4" @@ -625,9 +785,9 @@ async def test_update_entity_reload( async def test_update_entity_delay( hass: HomeAssistant, - client, - ge_in_wall_dimmer_switch, - zen_31, + client: MagicMock, + ge_in_wall_dimmer_switch: Node, + zen_31: Node, hass_ws_client: WebSocketGenerator, freezer: FrozenDateTimeFactory, ) -> None: @@ -641,12 +801,13 @@ async def test_update_entity_delay( await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() - assert len(client.async_send_command.call_args_list) == 6 + client.async_send_command.reset_mock() + assert client.async_send_command.call_count == 0 await hass.async_start() await hass.async_block_till_done() - assert len(client.async_send_command.call_args_list) == 6 + assert client.async_send_command.call_count == 0 update_interval = timedelta(minutes=5) freezer.tick(update_interval) @@ -655,8 +816,17 @@ async def test_update_entity_delay( nodes: set[int] = set() - assert len(client.async_send_command.call_args_list) == 7 - args = client.async_send_command.call_args_list[6][0][0] + assert client.async_send_command.call_count == 1 + args = client.async_send_command.call_args[0][0] + assert args["command"] == "controller.get_available_firmware_updates" + nodes.add(args["nodeId"]) + + freezer.tick(update_interval) + async_fire_time_changed(hass) + await hass.async_block_till_done() + + assert client.async_send_command.call_count == 2 + args = client.async_send_command.call_args[0][0] assert args["command"] == "controller.get_available_firmware_updates" nodes.add(args["nodeId"]) @@ -664,30 +834,36 @@ async def test_update_entity_delay( async_fire_time_changed(hass) await hass.async_block_till_done() - assert len(client.async_send_command.call_args_list) == 8 - args = client.async_send_command.call_args_list[7][0][0] + assert client.async_send_command.call_count == 3 + args = client.async_send_command.call_args[0][0] assert args["command"] == "controller.get_available_firmware_updates" nodes.add(args["nodeId"]) - assert len(nodes) == 2 - assert nodes == {ge_in_wall_dimmer_switch.node_id, zen_31.node_id} + assert len(nodes) == 3 + assert nodes == {1, ge_in_wall_dimmer_switch.node_id, zen_31.node_id} +@pytest.mark.parametrize( + ("entity_id", "installed_version"), + [(CONTROLLER_UPDATE_ENTITY, "1.2"), (NODE_UPDATE_ENTITY, "10.7")], +) async def test_update_entity_partial_restore_data( hass: HomeAssistant, - client, - climate_radio_thermostat_ct100_plus_different_endpoints, + client: MagicMock, + climate_radio_thermostat_ct100_plus_different_endpoints: Node, hass_ws_client: WebSocketGenerator, + entity_id: str, + installed_version: str, ) -> None: """Test update entity with partial restore data resets state.""" mock_restore_cache( hass, [ State( - UPDATE_ENTITY, + entity_id, STATE_OFF, { - ATTR_INSTALLED_VERSION: "10.7", + ATTR_INSTALLED_VERSION: installed_version, ATTR_LATEST_VERSION: "11.2.4", ATTR_SKIPPED_VERSION: "11.2.4", }, @@ -699,16 +875,22 @@ async def test_update_entity_partial_restore_data( await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_UNKNOWN +@pytest.mark.parametrize( + ("entity_id", "installed_version"), + [(CONTROLLER_UPDATE_ENTITY, "1.2"), (NODE_UPDATE_ENTITY, "10.7")], +) async def test_update_entity_partial_restore_data_2( hass: HomeAssistant, - client, - climate_radio_thermostat_ct100_plus_different_endpoints, + client: MagicMock, + climate_radio_thermostat_ct100_plus_different_endpoints: Node, hass_ws_client: WebSocketGenerator, + entity_id: str, + installed_version: str, ) -> None: """Test second scenario where update entity has partial restore data.""" mock_restore_cache_with_extra_data( @@ -716,10 +898,10 @@ async def test_update_entity_partial_restore_data_2( [ ( State( - UPDATE_ENTITY, + entity_id, STATE_ON, { - ATTR_INSTALLED_VERSION: "10.7", + ATTR_INSTALLED_VERSION: installed_version, ATTR_LATEST_VERSION: "10.8", ATTR_SKIPPED_VERSION: None, }, @@ -733,18 +915,24 @@ async def test_update_entity_partial_restore_data_2( await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_UNKNOWN assert state.attributes[ATTR_SKIPPED_VERSION] is None assert state.attributes[ATTR_LATEST_VERSION] is None +@pytest.mark.parametrize( + ("entity_id", "installed_version"), + [(CONTROLLER_UPDATE_ENTITY, "1.2"), (NODE_UPDATE_ENTITY, "10.7")], +) async def test_update_entity_full_restore_data_skipped_version( hass: HomeAssistant, - client, - climate_radio_thermostat_ct100_plus_different_endpoints, + client: MagicMock, + climate_radio_thermostat_ct100_plus_different_endpoints: Node, hass_ws_client: WebSocketGenerator, + entity_id: str, + installed_version: str, ) -> None: """Test update entity with full restore data (skipped version) restores state.""" mock_restore_cache_with_extra_data( @@ -752,10 +940,10 @@ async def test_update_entity_full_restore_data_skipped_version( [ ( State( - UPDATE_ENTITY, + entity_id, STATE_OFF, { - ATTR_INSTALLED_VERSION: "10.7", + ATTR_INSTALLED_VERSION: installed_version, ATTR_LATEST_VERSION: "11.2.4", ATTR_SKIPPED_VERSION: "11.2.4", }, @@ -769,18 +957,44 @@ async def test_update_entity_full_restore_data_skipped_version( await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_OFF assert state.attributes[ATTR_SKIPPED_VERSION] == "11.2.4" assert state.attributes[ATTR_LATEST_VERSION] == "11.2.4" +@pytest.mark.parametrize( + ("entity_id", "installed_version", "install_result", "install_command_params"), + [ + ( + CONTROLLER_UPDATE_ENTITY, + "1.2", + {"status": 255, "success": True}, + { + "command": "driver.firmware_update_otw", + }, + ), + ( + NODE_UPDATE_ENTITY, + "10.7", + {"status": 255, "success": True, "reInterview": False}, + { + "command": "controller.firmware_update_ota", + "nodeId": 26, + }, + ), + ], +) async def test_update_entity_full_restore_data_update_available( hass: HomeAssistant, - client, - climate_radio_thermostat_ct100_plus_different_endpoints, + client: MagicMock, + climate_radio_thermostat_ct100_plus_different_endpoints: Node, hass_ws_client: WebSocketGenerator, + entity_id: str, + installed_version: str, + install_result: dict[str, Any], + install_command_params: dict[str, Any], ) -> None: """Test update entity with full restore data (update available) restores state.""" mock_restore_cache_with_extra_data( @@ -788,10 +1002,10 @@ async def test_update_entity_full_restore_data_update_available( [ ( State( - UPDATE_ENTITY, + entity_id, STATE_OFF, { - ATTR_INSTALLED_VERSION: "10.7", + ATTR_INSTALLED_VERSION: installed_version, ATTR_LATEST_VERSION: "11.2.4", ATTR_SKIPPED_VERSION: None, }, @@ -805,15 +1019,14 @@ async def test_update_entity_full_restore_data_update_available( await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_ON assert state.attributes[ATTR_SKIPPED_VERSION] is None assert state.attributes[ATTR_LATEST_VERSION] == "11.2.4" - client.async_send_command.return_value = { - "result": {"status": 255, "success": True, "reInterview": False} - } + client.async_send_command.reset_mock() + client.async_send_command.return_value = {"result": install_result} # Test successful install call without a version install_task = hass.async_create_task( @@ -821,25 +1034,24 @@ async def test_update_entity_full_restore_data_update_available( UPDATE_DOMAIN, SERVICE_INSTALL, { - ATTR_ENTITY_ID: UPDATE_ENTITY, + ATTR_ENTITY_ID: entity_id, }, blocking=True, ) ) # Sleep so that task starts - await asyncio.sleep(0.1) + await asyncio.sleep(0.05) - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state attrs = state.attributes assert attrs[ATTR_IN_PROGRESS] is True assert attrs[ATTR_UPDATE_PERCENTAGE] is None - assert len(client.async_send_command.call_args_list) == 5 - assert client.async_send_command.call_args_list[4][0][0] == { - "command": "controller.firmware_update_ota", - "nodeId": climate_radio_thermostat_ct100_plus_different_endpoints.node_id, + assert client.async_send_command.call_count == 1 + assert client.async_send_command.call_args[0][0] == { + **install_command_params, "updateInfo": { "version": "11.2.4", "changelog": "blah 2", @@ -862,11 +1074,18 @@ async def test_update_entity_full_restore_data_update_available( install_task.cancel() +@pytest.mark.parametrize( + ("entity_id", "installed_version", "latest_version"), + [(CONTROLLER_UPDATE_ENTITY, "1.2", "1.2"), (NODE_UPDATE_ENTITY, "10.7", "10.7")], +) async def test_update_entity_full_restore_data_no_update_available( hass: HomeAssistant, - client, - climate_radio_thermostat_ct100_plus_different_endpoints, + client: MagicMock, + climate_radio_thermostat_ct100_plus_different_endpoints: Node, hass_ws_client: WebSocketGenerator, + entity_id: str, + installed_version: str, + latest_version: str, ) -> None: """Test entity with full restore data (no update available) restores state.""" mock_restore_cache_with_extra_data( @@ -874,11 +1093,11 @@ async def test_update_entity_full_restore_data_no_update_available( [ ( State( - UPDATE_ENTITY, + entity_id, STATE_OFF, { - ATTR_INSTALLED_VERSION: "10.7", - ATTR_LATEST_VERSION: "10.7", + ATTR_INSTALLED_VERSION: installed_version, + ATTR_LATEST_VERSION: latest_version, ATTR_SKIPPED_VERSION: None, }, ), @@ -891,18 +1110,25 @@ async def test_update_entity_full_restore_data_no_update_available( await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_OFF assert state.attributes[ATTR_SKIPPED_VERSION] is None - assert state.attributes[ATTR_LATEST_VERSION] == "10.7" + assert state.attributes[ATTR_LATEST_VERSION] == latest_version +@pytest.mark.parametrize( + ("entity_id", "installed_version", "latest_version"), + [(CONTROLLER_UPDATE_ENTITY, "1.2", "1.2"), (NODE_UPDATE_ENTITY, "10.7", "10.7")], +) async def test_update_entity_no_latest_version( hass: HomeAssistant, - client, - climate_radio_thermostat_ct100_plus_different_endpoints, + client: MagicMock, + climate_radio_thermostat_ct100_plus_different_endpoints: Node, hass_ws_client: WebSocketGenerator, + entity_id: str, + installed_version: str, + latest_version: str, ) -> None: """Test entity with no `latest_version` attr restores state.""" mock_restore_cache_with_extra_data( @@ -910,10 +1136,10 @@ async def test_update_entity_no_latest_version( [ ( State( - UPDATE_ENTITY, + entity_id, STATE_OFF, { - ATTR_INSTALLED_VERSION: "10.7", + ATTR_INSTALLED_VERSION: installed_version, ATTR_LATEST_VERSION: None, ATTR_SKIPPED_VERSION: None, }, @@ -927,24 +1153,33 @@ async def test_update_entity_no_latest_version( await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() - state = hass.states.get(UPDATE_ENTITY) + state = hass.states.get(entity_id) assert state assert state.state == STATE_OFF assert state.attributes[ATTR_SKIPPED_VERSION] is None - assert state.attributes[ATTR_LATEST_VERSION] == "10.7" + assert state.attributes[ATTR_LATEST_VERSION] == latest_version async def test_update_entity_unload_asleep_node( - hass: HomeAssistant, client, wallmote_central_scene, integration + hass: HomeAssistant, + client: MagicMock, + wallmote_central_scene: Node, + integration: MockConfigEntry, ) -> None: """Test unloading config entry after attempting an update for an asleep node.""" - assert len(client.async_send_command.call_args_list) == 0 + config_entry = integration + assert client.async_send_command.call_count == 0 + + client.async_send_command.reset_mock() + client.async_send_command.return_value = {"updates": []} async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5, days=1)) await hass.async_block_till_done() - assert len(client.async_send_command.call_args_list) == 0 - assert len(wallmote_central_scene._listeners["wake up"]) == 2 + # Once call completed for the (awake) controller node. + assert client.async_send_command.call_count == 1 + assert len(wallmote_central_scene._listeners["wake up"]) == 1 - await hass.config_entries.async_unload(integration.entry_id) + await hass.config_entries.async_unload(config_entry.entry_id) + assert client.async_send_command.call_count == 1 assert len(wallmote_central_scene._listeners["wake up"]) == 0 From 9d66b19c0328de6047d4feefa0d3700c8eb5bd2c Mon Sep 17 00:00:00 2001 From: Petro31 <35082313+Petro31@users.noreply.github.com> Date: Wed, 30 Jul 2025 05:20:04 -0400 Subject: [PATCH 8/8] Add assumed optimistic to template number entities (#148499) --- homeassistant/components/template/number.py | 220 ++++++++++---------- tests/components/template/test_number.py | 103 +++++++-- 2 files changed, 198 insertions(+), 125 deletions(-) diff --git a/homeassistant/components/template/number.py b/homeassistant/components/template/number.py index 31a6338f594dc1..362a7e9d5c574c 100644 --- a/homeassistant/components/template/number.py +++ b/homeassistant/components/template/number.py @@ -17,14 +17,9 @@ NumberEntity, ) from homeassistant.config_entries import ConfigEntry -from homeassistant.const import ( - CONF_NAME, - CONF_OPTIMISTIC, - CONF_STATE, - CONF_UNIT_OF_MEASUREMENT, -) +from homeassistant.const import CONF_NAME, CONF_STATE, CONF_UNIT_OF_MEASUREMENT from homeassistant.core import HomeAssistant, callback -from homeassistant.helpers import config_validation as cv +from homeassistant.helpers import config_validation as cv, template from homeassistant.helpers.entity_platform import ( AddConfigEntryEntitiesCallback, AddEntitiesCallback, @@ -33,6 +28,7 @@ from . import TriggerUpdateCoordinator from .const import CONF_MAX, CONF_MIN, CONF_STEP, DOMAIN +from .entity import AbstractTemplateEntity from .helpers import ( async_setup_template_entry, async_setup_template_platform, @@ -40,6 +36,7 @@ ) from .template_entity import ( TEMPLATE_ENTITY_COMMON_CONFIG_ENTRY_SCHEMA, + TEMPLATE_ENTITY_OPTIMISTIC_SCHEMA, TemplateEntity, make_template_entity_common_modern_schema, ) @@ -57,21 +54,15 @@ vol.Optional(CONF_MAX, default=DEFAULT_MAX_VALUE): cv.template, vol.Optional(CONF_MIN, default=DEFAULT_MIN_VALUE): cv.template, vol.Required(CONF_SET_VALUE): cv.SCRIPT_SCHEMA, - vol.Required(CONF_STATE): cv.template, - vol.Required(CONF_STEP): cv.template, + vol.Optional(CONF_STATE): cv.template, + vol.Optional(CONF_STEP, default=DEFAULT_STEP): cv.template, vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string, } -) +).extend(make_template_entity_common_modern_schema(DEFAULT_NAME).schema) -NUMBER_YAML_SCHEMA = ( - vol.Schema( - { - vol.Optional(CONF_OPTIMISTIC, default=DEFAULT_OPTIMISTIC): cv.boolean, - } - ) - .extend(make_template_entity_common_modern_schema(DEFAULT_NAME).schema) - .extend(NUMBER_COMMON_SCHEMA.schema) -) +NUMBER_YAML_SCHEMA = NUMBER_COMMON_SCHEMA.extend( + TEMPLATE_ENTITY_OPTIMISTIC_SCHEMA +).extend(make_template_entity_common_modern_schema(DEFAULT_NAME).schema) NUMBER_CONFIG_ENTRY_SCHEMA = NUMBER_COMMON_SCHEMA.extend( TEMPLATE_ENTITY_COMMON_CONFIG_ENTRY_SCHEMA.schema @@ -121,90 +112,97 @@ def async_create_preview_number( ) -class StateNumberEntity(TemplateEntity, NumberEntity): +class AbstractTemplateNumber(AbstractTemplateEntity, NumberEntity): + """Representation of a template number features.""" + + _entity_id_format = ENTITY_ID_FORMAT + _optimistic_entity = True + + # The super init is not called because TemplateEntity and TriggerEntity will call AbstractTemplateEntity.__init__. + # This ensures that the __init__ on AbstractTemplateEntity is not called twice. + def __init__(self, config: dict[str, Any]) -> None: # pylint: disable=super-init-not-called + """Initialize the features.""" + self._step_template = config[CONF_STEP] + self._min_template = config[CONF_MIN] + self._max_template = config[CONF_MAX] + + self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT) + self._attr_native_step = DEFAULT_STEP + self._attr_native_min_value = DEFAULT_MIN_VALUE + self._attr_native_max_value = DEFAULT_MAX_VALUE + + async def async_set_native_value(self, value: float) -> None: + """Set value of the number.""" + if self._attr_assumed_state: + self._attr_native_value = value + self.async_write_ha_state() + if set_value := self._action_scripts.get(CONF_SET_VALUE): + await self.async_run_script( + set_value, + run_variables={ATTR_VALUE: value}, + context=self._context, + ) + + +class StateNumberEntity(TemplateEntity, AbstractTemplateNumber): """Representation of a template number.""" _attr_should_poll = False - _entity_id_format = ENTITY_ID_FORMAT def __init__( self, hass: HomeAssistant, - config, + config: ConfigType, unique_id: str | None, ) -> None: """Initialize the number.""" TemplateEntity.__init__(self, hass, config, unique_id) - if TYPE_CHECKING: - assert self._attr_name is not None + AbstractTemplateNumber.__init__(self, config) - self._value_template = config[CONF_STATE] - self.add_script(CONF_SET_VALUE, config[CONF_SET_VALUE], self._attr_name, DOMAIN) + name = self._attr_name + if TYPE_CHECKING: + assert name is not None - self._step_template = config[CONF_STEP] - self._min_value_template = config[CONF_MIN] - self._max_value_template = config[CONF_MAX] - self._attr_assumed_state = self._optimistic = config.get(CONF_OPTIMISTIC) - self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT) - self._attr_native_step = DEFAULT_STEP - self._attr_native_min_value = DEFAULT_MIN_VALUE - self._attr_native_max_value = DEFAULT_MAX_VALUE + self.add_script(CONF_SET_VALUE, config[CONF_SET_VALUE], name, DOMAIN) @callback def _async_setup_templates(self) -> None: """Set up templates.""" - self.add_template_attribute( - "_attr_native_value", - self._value_template, - validator=vol.Coerce(float), - none_on_template_error=True, - ) - self.add_template_attribute( - "_attr_native_step", - self._step_template, - validator=vol.Coerce(float), - none_on_template_error=True, - ) - if self._min_value_template is not None: + if self._template is not None: + self.add_template_attribute( + "_attr_native_value", + self._template, + vol.Coerce(float), + none_on_template_error=True, + ) + if self._step_template is not None: + self.add_template_attribute( + "_attr_native_step", + self._step_template, + vol.Coerce(float), + none_on_template_error=True, + ) + if self._min_template is not None: self.add_template_attribute( "_attr_native_min_value", - self._min_value_template, + self._min_template, validator=vol.Coerce(float), none_on_template_error=True, ) - if self._max_value_template is not None: + if self._max_template is not None: self.add_template_attribute( "_attr_native_max_value", - self._max_value_template, + self._max_template, validator=vol.Coerce(float), none_on_template_error=True, ) super()._async_setup_templates() - async def async_set_native_value(self, value: float) -> None: - """Set value of the number.""" - if self._optimistic: - self._attr_native_value = value - self.async_write_ha_state() - if set_value := self._action_scripts.get(CONF_SET_VALUE): - await self.async_run_script( - set_value, - run_variables={ATTR_VALUE: value}, - context=self._context, - ) - -class TriggerNumberEntity(TriggerEntity, NumberEntity): +class TriggerNumberEntity(TriggerEntity, AbstractTemplateNumber): """Number entity based on trigger data.""" - _entity_id_format = ENTITY_ID_FORMAT domain = NUMBER_DOMAIN - extra_template_keys = ( - CONF_STATE, - CONF_STEP, - CONF_MIN, - CONF_MAX, - ) def __init__( self, @@ -213,47 +211,49 @@ def __init__( config: dict, ) -> None: """Initialize the entity.""" - super().__init__(hass, coordinator, config) - - name = self._rendered.get(CONF_NAME, DEFAULT_NAME) - self.add_script(CONF_SET_VALUE, config[CONF_SET_VALUE], name, DOMAIN) - - self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT) - - @property - def native_value(self) -> float | None: - """Return the currently selected option.""" - return vol.Any(vol.Coerce(float), None)(self._rendered.get(CONF_STATE)) - - @property - def native_min_value(self) -> int: - """Return the minimum value.""" - return vol.Any(vol.Coerce(float), None)( - self._rendered.get(CONF_MIN, super().native_min_value) + TriggerEntity.__init__(self, hass, coordinator, config) + AbstractTemplateNumber.__init__(self, config) + + for key in ( + CONF_STATE, + CONF_STEP, + CONF_MIN, + CONF_MAX, + ): + if isinstance(config.get(key), template.Template): + self._to_render_simple.append(key) + self._parse_result.add(key) + + self.add_script( + CONF_SET_VALUE, + config[CONF_SET_VALUE], + self._rendered.get(CONF_NAME, DEFAULT_NAME), + DOMAIN, ) - @property - def native_max_value(self) -> int: - """Return the maximum value.""" - return vol.Any(vol.Coerce(float), None)( - self._rendered.get(CONF_MAX, super().native_max_value) - ) + def _handle_coordinator_update(self): + """Handle updated data from the coordinator.""" + self._process_data() - @property - def native_step(self) -> int: - """Return the increment/decrement step.""" - return vol.Any(vol.Coerce(float), None)( - self._rendered.get(CONF_STEP, super().native_step) - ) - - async def async_set_native_value(self, value: float) -> None: - """Set value of the number.""" - if self._config[CONF_OPTIMISTIC]: - self._attr_native_value = value + if not self.available: + self.async_write_ha_state() + return + + write_ha_state = False + for key, attr in ( + (CONF_STATE, "_attr_native_value"), + (CONF_STEP, "_attr_native_step"), + (CONF_MIN, "_attr_native_min_value"), + (CONF_MAX, "_attr_native_max_value"), + ): + if (rendered := self._rendered.get(key)) is not None: + setattr(self, attr, vol.Any(vol.Coerce(float), None)(rendered)) + write_ha_state = True + + if len(self._rendered) > 0: + # In case any non optimistic template + write_ha_state = True + + if write_ha_state: + self.async_set_context(self.coordinator.data["context"]) self.async_write_ha_state() - if set_value := self._action_scripts.get(CONF_SET_VALUE): - await self.async_run_script( - set_value, - run_variables={ATTR_VALUE: value}, - context=self._context, - ) diff --git a/tests/components/template/test_number.py b/tests/components/template/test_number.py index 21dea28b73fc6c..0ae98a23ae45d1 100644 --- a/tests/components/template/test_number.py +++ b/tests/components/template/test_number.py @@ -29,6 +29,7 @@ CONF_ENTITY_ID, CONF_ICON, CONF_UNIT_OF_MEASUREMENT, + STATE_UNAVAILABLE, STATE_UNKNOWN, ) from homeassistant.core import Context, HomeAssistant, ServiceCall @@ -63,11 +64,11 @@ } TEST_STATE_ENTITY_ID = "number.test_state" - +TEST_AVAILABILITY_ENTITY_ID = "binary_sensor.test_availability" TEST_STATE_TRIGGER = { "trigger": { "trigger": "state", - "entity_id": [TEST_STATE_ENTITY_ID], + "entity_id": [TEST_STATE_ENTITY_ID, TEST_AVAILABILITY_ENTITY_ID], }, "variables": {"triggering_entity": "{{ trigger.entity_id }}"}, "action": [ @@ -191,19 +192,6 @@ async def test_missing_optional_config(hass: HomeAssistant) -> None: async def test_missing_required_keys(hass: HomeAssistant) -> None: """Test: missing required fields will fail.""" - with assert_setup_component(0, "template"): - assert await setup.async_setup_component( - hass, - "template", - { - "template": { - "number": { - "set_value": {"service": "script.set_value"}, - } - } - }, - ) - with assert_setup_component(0, "template"): assert await setup.async_setup_component( hass, @@ -578,6 +566,91 @@ async def test_device_id( assert template_entity.device_id == device_entry.id +@pytest.mark.parametrize( + ("count", "number_config"), + [ + ( + 1, + { + "set_value": [], + }, + ) + ], +) +@pytest.mark.parametrize( + "style", + [ConfigurationStyle.MODERN, ConfigurationStyle.TRIGGER], +) +@pytest.mark.usefixtures("setup_number") +async def test_optimistic(hass: HomeAssistant) -> None: + """Test configuration with optimistic state.""" + await hass.services.async_call( + number.DOMAIN, + number.SERVICE_SET_VALUE, + {ATTR_ENTITY_ID: _TEST_NUMBER, "value": 4}, + blocking=True, + ) + + state = hass.states.get(_TEST_NUMBER) + assert float(state.state) == 4 + + await hass.services.async_call( + number.DOMAIN, + number.SERVICE_SET_VALUE, + {ATTR_ENTITY_ID: _TEST_NUMBER, "value": 2}, + blocking=True, + ) + + state = hass.states.get(_TEST_NUMBER) + assert float(state.state) == 2 + + +@pytest.mark.parametrize( + ("count", "number_config"), + [ + ( + 1, + { + "set_value": [], + "state": "{{ states('number.test_state') }}", + "availability": "{{ is_state('binary_sensor.test_availability', 'on') }}", + }, + ) + ], +) +@pytest.mark.parametrize( + "style", [ConfigurationStyle.MODERN, ConfigurationStyle.TRIGGER] +) +@pytest.mark.usefixtures("setup_number") +async def test_availability(hass: HomeAssistant) -> None: + """Test configuration with optimistic state.""" + + hass.states.async_set(TEST_AVAILABILITY_ENTITY_ID, "on") + hass.states.async_set(TEST_STATE_ENTITY_ID, "4.0") + await hass.async_block_till_done() + + state = hass.states.get(_TEST_NUMBER) + assert float(state.state) == 4 + + hass.states.async_set(TEST_AVAILABILITY_ENTITY_ID, "off") + await hass.async_block_till_done() + + state = hass.states.get(_TEST_NUMBER) + assert state.state == STATE_UNAVAILABLE + + hass.states.async_set(TEST_STATE_ENTITY_ID, "2.0") + await hass.async_block_till_done() + + state = hass.states.get(_TEST_NUMBER) + assert state.state == STATE_UNAVAILABLE + + hass.states.async_set(TEST_AVAILABILITY_ENTITY_ID, "on") + await hass.async_block_till_done() + + state = hass.states.get(_TEST_NUMBER) + assert float(state.state) == 2 + + @pytest.mark.parametrize( ("count", "number_config"), [