Skip to content

Commit

Permalink
Fix ESPHome service removal when the device name contains a dash (#10…
Browse files Browse the repository at this point in the history
…7015)

* Fix ESPHome service removal when the device name contains a dash

If the device name contains a dash the service name is mutated to
replace the dash with an underscore, but the remove function did
not do the same mutation so it would fail to remove the service

* add more coverage

* more cover
  • Loading branch information
bdraco committed Jan 4, 2024
1 parent afcf8c9 commit 01d0031
Show file tree
Hide file tree
Showing 2 changed files with 275 additions and 32 deletions.
79 changes: 48 additions & 31 deletions homeassistant/components/esphome/manager.py
Expand Up @@ -3,6 +3,7 @@

import asyncio
from collections.abc import Coroutine
from functools import partial
import logging
from typing import TYPE_CHECKING, Any, NamedTuple

Expand Down Expand Up @@ -456,12 +457,10 @@ async def on_connect(self) -> None:

self.device_id = _async_setup_device_registry(hass, entry, entry_data)
entry_data.async_update_device_state(hass)
await asyncio.gather(
entry_data.async_update_static_infos(
hass, entry, entity_infos, device_info.mac_address
),
_setup_services(hass, entry_data, services),
await entry_data.async_update_static_infos(
hass, entry, entity_infos, device_info.mac_address
)
_setup_services(hass, entry_data, services)

setup_coros_with_disconnect_callbacks: list[
Coroutine[Any, Any, CALLBACK_TYPE]
Expand Down Expand Up @@ -586,7 +585,7 @@ async def async_start(self) -> None:
await entry_data.async_update_static_infos(
hass, entry, infos, entry.unique_id.upper()
)
await _setup_services(hass, entry_data, services)
_setup_services(hass, entry_data, services)

if entry_data.device_info is not None and entry_data.device_info.name:
reconnect_logic.name = entry_data.device_info.name
Expand Down Expand Up @@ -708,12 +707,27 @@ class ServiceMetadata(NamedTuple):
}


async def _register_service(
hass: HomeAssistant, entry_data: RuntimeEntryData, service: UserService
async def execute_service(
entry_data: RuntimeEntryData, service: UserService, call: ServiceCall
) -> None:
if entry_data.device_info is None:
raise ValueError("Device Info needs to be fetched first")
service_name = f"{entry_data.device_info.name.replace('-', '_')}_{service.name}"
"""Execute a service on a node."""
await entry_data.client.execute_service(service, call.data)


def build_service_name(device_info: EsphomeDeviceInfo, service: UserService) -> str:
"""Build a service name for a node."""
return f"{device_info.name.replace('-', '_')}_{service.name}"


@callback
def _async_register_service(
hass: HomeAssistant,
entry_data: RuntimeEntryData,
device_info: EsphomeDeviceInfo,
service: UserService,
) -> None:
"""Register a service on a node."""
service_name = build_service_name(device_info, service)
schema = {}
fields = {}

Expand All @@ -736,33 +750,36 @@ async def _register_service(
"selector": metadata.selector,
}

async def execute_service(call: ServiceCall) -> None:
await entry_data.client.execute_service(service, call.data)

hass.services.async_register(
DOMAIN, service_name, execute_service, vol.Schema(schema)
DOMAIN,
service_name,
partial(execute_service, entry_data, service),
vol.Schema(schema),
)
async_set_service_schema(
hass,
DOMAIN,
service_name,
{
"description": (
f"Calls the service {service.name} of the node {device_info.name}"
),
"fields": fields,
},
)

service_desc = {
"description": (
f"Calls the service {service.name} of the node"
f" {entry_data.device_info.name}"
),
"fields": fields,
}

async_set_service_schema(hass, DOMAIN, service_name, service_desc)


async def _setup_services(
@callback
def _setup_services(
hass: HomeAssistant, entry_data: RuntimeEntryData, services: list[UserService]
) -> None:
if entry_data.device_info is None:
device_info = entry_data.device_info
if device_info is None:
# Can happen if device has never connected or .storage cleared
return
old_services = entry_data.services.copy()
to_unregister = []
to_register = []
to_unregister: list[UserService] = []
to_register: list[UserService] = []
for service in services:
if service.key in old_services:
# Already exists
Expand All @@ -780,11 +797,11 @@ async def _setup_services(
entry_data.services = {serv.key: serv for serv in services}

for service in to_unregister:
service_name = f"{entry_data.device_info.name}_{service.name}"
service_name = build_service_name(device_info, service)
hass.services.async_remove(DOMAIN, service_name)

for service in to_register:
await _register_service(hass, entry_data, service)
_async_register_service(hass, entry_data, device_info, service)


async def cleanup_instance(hass: HomeAssistant, entry: ConfigEntry) -> RuntimeEntryData:
Expand Down
228 changes: 227 additions & 1 deletion tests/components/esphome/test_manager.py
Expand Up @@ -2,7 +2,15 @@
from collections.abc import Awaitable, Callable
from unittest.mock import AsyncMock, call

from aioesphomeapi import APIClient, DeviceInfo, EntityInfo, EntityState, UserService
from aioesphomeapi import (
APIClient,
DeviceInfo,
EntityInfo,
EntityState,
UserService,
UserServiceArg,
UserServiceArgType,
)
import pytest

from homeassistant import config_entries
Expand Down Expand Up @@ -374,3 +382,221 @@ async def test_debug_logging(
)
await hass.async_block_till_done()
mock_client.set_debug.assert_has_calls([call(False)])


async def test_esphome_device_with_dash_in_name_user_services(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test a device with user services and a dash in the name."""
entity_info = []
states = []
service1 = UserService(
name="my_service",
key=1,
args=[
UserServiceArg(name="arg1", type=UserServiceArgType.BOOL),
UserServiceArg(name="arg2", type=UserServiceArgType.INT),
UserServiceArg(name="arg3", type=UserServiceArgType.FLOAT),
UserServiceArg(name="arg4", type=UserServiceArgType.STRING),
UserServiceArg(name="arg5", type=UserServiceArgType.BOOL_ARRAY),
UserServiceArg(name="arg6", type=UserServiceArgType.INT_ARRAY),
UserServiceArg(name="arg7", type=UserServiceArgType.FLOAT_ARRAY),
UserServiceArg(name="arg8", type=UserServiceArgType.STRING_ARRAY),
],
)
service2 = UserService(
name="simple_service",
key=2,
args=[
UserServiceArg(name="arg1", type=UserServiceArgType.BOOL),
],
)
device = await mock_esphome_device(
mock_client=mock_client,
entity_info=entity_info,
user_service=[service1, service2],
device_info={"name": "with-dash"},
states=states,
)
await hass.async_block_till_done()
assert hass.services.has_service(DOMAIN, "with_dash_my_service")
assert hass.services.has_service(DOMAIN, "with_dash_simple_service")

await hass.services.async_call(DOMAIN, "with_dash_simple_service", {"arg1": True})
await hass.async_block_till_done()

mock_client.execute_service.assert_has_calls(
[
call(
UserService(
name="simple_service",
key=2,
args=[UserServiceArg(name="arg1", type=UserServiceArgType.BOOL)],
),
{"arg1": True},
)
]
)
mock_client.execute_service.reset_mock()

# Verify the service can be removed
mock_client.list_entities_services = AsyncMock(
return_value=(entity_info, [service1])
)
await device.mock_disconnect(True)
await hass.async_block_till_done()
await device.mock_connect()
await hass.async_block_till_done()
assert hass.services.has_service(DOMAIN, "with_dash_my_service")
assert not hass.services.has_service(DOMAIN, "with_dash_simple_service")


async def test_esphome_user_services_ignores_invalid_arg_types(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test a device with user services and a dash in the name."""
entity_info = []
states = []
service1 = UserService(
name="bad_service",
key=1,
args=[
UserServiceArg(name="arg1", type="wrong"),
],
)
service2 = UserService(
name="simple_service",
key=2,
args=[
UserServiceArg(name="arg1", type=UserServiceArgType.BOOL),
],
)
device = await mock_esphome_device(
mock_client=mock_client,
entity_info=entity_info,
user_service=[service1, service2],
device_info={"name": "with-dash"},
states=states,
)
await hass.async_block_till_done()
assert not hass.services.has_service(DOMAIN, "with_dash_bad_service")
assert hass.services.has_service(DOMAIN, "with_dash_simple_service")

await hass.services.async_call(DOMAIN, "with_dash_simple_service", {"arg1": True})
await hass.async_block_till_done()

mock_client.execute_service.assert_has_calls(
[
call(
UserService(
name="simple_service",
key=2,
args=[UserServiceArg(name="arg1", type=UserServiceArgType.BOOL)],
),
{"arg1": True},
)
]
)
mock_client.execute_service.reset_mock()

# Verify the service can be removed
mock_client.list_entities_services = AsyncMock(
return_value=(entity_info, [service2])
)
await device.mock_disconnect(True)
await hass.async_block_till_done()
await device.mock_connect()
await hass.async_block_till_done()
assert hass.services.has_service(DOMAIN, "with_dash_simple_service")
assert not hass.services.has_service(DOMAIN, "with_dash_bad_service")


async def test_esphome_user_services_changes(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test a device with user services that change arguments."""
entity_info = []
states = []
service1 = UserService(
name="simple_service",
key=2,
args=[
UserServiceArg(name="arg1", type=UserServiceArgType.BOOL),
],
)
device = await mock_esphome_device(
mock_client=mock_client,
entity_info=entity_info,
user_service=[service1],
device_info={"name": "with-dash"},
states=states,
)
await hass.async_block_till_done()
assert hass.services.has_service(DOMAIN, "with_dash_simple_service")

await hass.services.async_call(DOMAIN, "with_dash_simple_service", {"arg1": True})
await hass.async_block_till_done()

mock_client.execute_service.assert_has_calls(
[
call(
UserService(
name="simple_service",
key=2,
args=[UserServiceArg(name="arg1", type=UserServiceArgType.BOOL)],
),
{"arg1": True},
)
]
)
mock_client.execute_service.reset_mock()

new_service1 = UserService(
name="simple_service",
key=2,
args=[
UserServiceArg(name="arg1", type=UserServiceArgType.FLOAT),
],
)

# Verify the service can be updated
mock_client.list_entities_services = AsyncMock(
return_value=(entity_info, [new_service1])
)
await device.mock_disconnect(True)
await hass.async_block_till_done()
await device.mock_connect()
await hass.async_block_till_done()
assert hass.services.has_service(DOMAIN, "with_dash_simple_service")

await hass.services.async_call(DOMAIN, "with_dash_simple_service", {"arg1": 4.5})
await hass.async_block_till_done()

mock_client.execute_service.assert_has_calls(
[
call(
UserService(
name="simple_service",
key=2,
args=[UserServiceArg(name="arg1", type=UserServiceArgType.FLOAT)],
),
{"arg1": 4.5},
)
]
)
mock_client.execute_service.reset_mock()

0 comments on commit 01d0031

Please sign in to comment.