From 4d72906533cbc967d90851fec00b0acf9bab2b32 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 24 Apr 2024 07:00:35 +0200 Subject: [PATCH] Move thread safety check in async_register/async_remove Move the thread safety check in the async_register and async_remove path to be sooner so it catches the problem before the service is being registered. Previously it would catch it when async_fire was called which was confusing to the dev because it would complain about async_fire and not async_register or async_remove. --- homeassistant/core.py | 44 +++++++++++++++++++++++++++++++++++++++---- tests/test_core.py | 23 ++++++++++++++++++++++ 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/homeassistant/core.py b/homeassistant/core.py index 189dc2f9d8a565..a3150adc2215af 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -2456,7 +2456,7 @@ def register( """ run_callback_threadsafe( self._hass.loop, - self.async_register, + self._async_register, domain, service, service_func, @@ -2484,6 +2484,33 @@ def async_register( Schema is called to coerce and validate the service data. + This method must be run in the event loop. + """ + self._hass.verify_event_loop_thread("async_register") + self._async_register( + domain, service, service_func, schema, supports_response, job_type + ) + + @callback + def _async_register( + self, + domain: str, + service: str, + service_func: Callable[ + [ServiceCall], + Coroutine[Any, Any, ServiceResponse | EntityServiceResponse] + | ServiceResponse + | EntityServiceResponse + | None, + ], + schema: vol.Schema | None = None, + supports_response: SupportsResponse = SupportsResponse.NONE, + job_type: HassJobType | None = None, + ) -> None: + """Register a service. + + Schema is called to coerce and validate the service data. + This method must be run in the event loop. """ domain = domain.lower() @@ -2502,20 +2529,29 @@ def async_register( else: self._services[domain] = {service: service_obj} - self._hass.bus.async_fire( + self._hass.bus.async_fire_internal( EVENT_SERVICE_REGISTERED, {ATTR_DOMAIN: domain, ATTR_SERVICE: service} ) def remove(self, domain: str, service: str) -> None: """Remove a registered service from service handler.""" run_callback_threadsafe( - self._hass.loop, self.async_remove, domain, service + self._hass.loop, self._async_remove, domain, service ).result() @callback def async_remove(self, domain: str, service: str) -> None: """Remove a registered service from service handler. + This method must be run in the event loop. + """ + self._hass.verify_event_loop_thread("async_remove") + self._async_remove(domain, service) + + @callback + def _async_remove(self, domain: str, service: str) -> None: + """Remove a registered service from service handler. + This method must be run in the event loop. """ domain = domain.lower() @@ -2530,7 +2566,7 @@ def async_remove(self, domain: str, service: str) -> None: if not self._services[domain]: self._services.pop(domain) - self._hass.bus.async_fire( + self._hass.bus.async_fire_internal( EVENT_SERVICE_REMOVED, {ATTR_DOMAIN: domain, ATTR_SERVICE: service} ) diff --git a/tests/test_core.py b/tests/test_core.py index 6bab89bca854aa..a553d5bbbedfef 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -3457,3 +3457,26 @@ async def test_async_fire_thread_safety(hass: HomeAssistant) -> None: await hass.async_add_executor_job(hass.bus.async_fire, "test_event") assert len(events) == 1 + + +async def test_async_register_thread_safety(hass: HomeAssistant) -> None: + """Test async_register thread safety.""" + with pytest.raises( + RuntimeError, match="Detected code that calls async_register from a thread." + ): + await hass.async_add_executor_job( + hass.services.async_register, + "test_domain", + "test_service", + lambda call: None, + ) + + +async def test_async_remove_thread_safety(hass: HomeAssistant) -> None: + """Test async_remove thread safety.""" + with pytest.raises( + RuntimeError, match="Detected code that calls async_remove from a thread." + ): + await hass.async_add_executor_job( + hass.services.async_remove, "test_domain", "test_service" + )