From f90b88d274ff35b417e2fd67c63befffe41c04ee Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Wed, 27 Dec 2023 12:57:34 +0100 Subject: [PATCH 1/6] Add cloud tts entity --- homeassistant/components/cloud/__init__.py | 11 +- .../components/cloud/assist_pipeline.py | 54 ++++++--- homeassistant/components/cloud/const.py | 1 + homeassistant/components/cloud/prefs.py | 10 +- homeassistant/components/cloud/stt.py | 16 ++- homeassistant/components/cloud/tts.py | 103 +++++++++++++++++- tests/components/cloud/conftest.py | 11 ++ .../components/cloud/test_assist_pipeline.py | 16 +++ tests/components/cloud/test_http_api.py | 4 +- tests/components/cloud/test_stt.py | 6 - tests/components/cloud/test_tts.py | 92 +++++++++++++++- 11 files changed, 284 insertions(+), 40 deletions(-) create mode 100644 tests/components/cloud/test_assist_pipeline.py diff --git a/homeassistant/components/cloud/__init__.py b/homeassistant/components/cloud/__init__.py index cdaae0d6272dc3..888e99e3a3433e 100644 --- a/homeassistant/components/cloud/__init__.py +++ b/homeassistant/components/cloud/__init__.py @@ -65,7 +65,7 @@ DEFAULT_MODE = MODE_PROD -PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT] +PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT, Platform.TTS] SERVICE_REMOTE_CONNECT = "remote_connect" SERVICE_REMOTE_DISCONNECT = "remote_disconnect" @@ -288,9 +288,11 @@ async def async_startup_repairs(_: datetime) -> None: loaded = False stt_platform_loaded = asyncio.Event() tts_platform_loaded = asyncio.Event() + stt_tts_entities_added = asyncio.Event() hass.data[DATA_PLATFORMS_SETUP] = { Platform.STT: stt_platform_loaded, Platform.TTS: tts_platform_loaded, + "stt_tts_entities_added": stt_tts_entities_added, } async def _on_start() -> None: @@ -330,6 +332,7 @@ async def _on_initialized() -> None: account_link.async_setup(hass) + # Load legacy tts platform for backwards compatibility. hass.async_create_task( async_load_platform( hass, @@ -377,8 +380,10 @@ async def remote_prefs_updated(prefs: CloudPreferences) -> None: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up a config entry.""" await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) - stt_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.STT] - stt_platform_loaded.set() + stt_tts_entities_added: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][ + "stt_tts_entities_added" + ] + stt_tts_entities_added.set() return True diff --git a/homeassistant/components/cloud/assist_pipeline.py b/homeassistant/components/cloud/assist_pipeline.py index 31e990cdb81170..7bc3eed0c855a3 100644 --- a/homeassistant/components/cloud/assist_pipeline.py +++ b/homeassistant/components/cloud/assist_pipeline.py @@ -9,16 +9,23 @@ ) from homeassistant.components.conversation import HOME_ASSISTANT_AGENT from homeassistant.components.stt import DOMAIN as STT_DOMAIN +from homeassistant.components.tts import DOMAIN as TTS_DOMAIN from homeassistant.const import Platform from homeassistant.core import HomeAssistant import homeassistant.helpers.entity_registry as er -from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID +from .const import ( + DATA_PLATFORMS_SETUP, + DOMAIN, + STT_ENTITY_UNIQUE_ID, + TTS_ENTITY_UNIQUE_ID, +) async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None: """Create a cloud assist pipeline.""" - # Wait for stt and tts platforms to set up before creating the pipeline. + # Wait for stt and tts platforms to set up and entities to be added + # before creating the pipeline. platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP] await asyncio.gather(*(event.wait() for event in platforms_setup.values())) # Make sure the pipeline store is loaded, needed because assist_pipeline @@ -29,8 +36,11 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None: new_stt_engine_id = entity_registry.async_get_entity_id( STT_DOMAIN, DOMAIN, STT_ENTITY_UNIQUE_ID ) - if new_stt_engine_id is None: - # If there's no cloud stt entity, we can't create a cloud pipeline. + new_tts_engine_id = entity_registry.async_get_entity_id( + TTS_DOMAIN, DOMAIN, TTS_ENTITY_UNIQUE_ID + ) + if new_stt_engine_id is None or new_tts_engine_id is None: + # If there's no cloud stt or tts entity, we can't create a cloud pipeline. return None def cloud_assist_pipeline(hass: HomeAssistant) -> str | None: @@ -43,7 +53,7 @@ def cloud_assist_pipeline(hass: HomeAssistant) -> str | None: if ( pipeline.conversation_engine == HOME_ASSISTANT_AGENT and pipeline.stt_engine in (DOMAIN, new_stt_engine_id) - and pipeline.tts_engine == DOMAIN + and pipeline.tts_engine in (DOMAIN, new_tts_engine_id) ): return pipeline.id return None @@ -52,7 +62,7 @@ def cloud_assist_pipeline(hass: HomeAssistant) -> str | None: cloud_pipeline := await async_create_default_pipeline( hass, stt_engine_id=new_stt_engine_id, - tts_engine_id=DOMAIN, + tts_engine_id=new_tts_engine_id, pipeline_name="Home Assistant Cloud", ) ) is None: @@ -61,18 +71,28 @@ def cloud_assist_pipeline(hass: HomeAssistant) -> str | None: return cloud_pipeline.id -async def async_migrate_cloud_pipeline_stt_engine( - hass: HomeAssistant, stt_engine_id: str +async def async_migrate_cloud_pipeline_engine( + hass: HomeAssistant, platform: Platform, engine_id: str ) -> None: - """Migrate the speech-to-text engine in the cloud assist pipeline.""" - # Migrate existing pipelines with cloud stt to use new cloud stt engine id. - # Added in 2024.01.0. Can be removed in 2025.01.0. + """Migrate the pipeline engines in the cloud assist pipeline.""" + # Migrate existing pipelines with cloud stt or tts to use new cloud engine id. + # Added in 2024.02.0. Can be removed in 2025.02.0. + + # We need to make sure that that both stt and tts are loaded before this migration. + # Assist pipeline will call default engine when setting up the store. + # Wait for the stt or tts platform loaded event here. + kwargs: dict[str, str] = {} + if platform == Platform.STT: + wait_for_platform = Platform.TTS + kwargs["stt_engine"] = engine_id + elif platform == Platform.TTS: + wait_for_platform = Platform.STT + kwargs["tts_engine"] = engine_id + else: + raise ValueError(f"Invalid platform {platform}") - # We need to make sure that tts is loaded before this migration. - # Assist pipeline will call default engine of tts when setting up the store. - # Wait for the tts platform loaded event here. platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP] - await platforms_setup[Platform.TTS].wait() + await platforms_setup[wait_for_platform].wait() # Make sure the pipeline store is loaded, needed because assist_pipeline # is an after dependency of cloud @@ -80,6 +100,6 @@ async def async_migrate_cloud_pipeline_stt_engine( pipelines = async_get_pipelines(hass) for pipeline in pipelines: - if pipeline.stt_engine != DOMAIN: + if pipeline.stt_engine != DOMAIN or pipeline.tts_engine != DOMAIN: continue - await async_update_pipeline(hass, pipeline, stt_engine=stt_engine_id) + await async_update_pipeline(hass, pipeline, **kwargs) diff --git a/homeassistant/components/cloud/const.py b/homeassistant/components/cloud/const.py index da012c20bab123..97d2345f16bb21 100644 --- a/homeassistant/components/cloud/const.py +++ b/homeassistant/components/cloud/const.py @@ -73,3 +73,4 @@ DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update") STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text" +TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech" diff --git a/homeassistant/components/cloud/prefs.py b/homeassistant/components/cloud/prefs.py index 4cc028673472aa..af5f9213e4dd78 100644 --- a/homeassistant/components/cloud/prefs.py +++ b/homeassistant/components/cloud/prefs.py @@ -104,10 +104,18 @@ async def async_initialize(self) -> None: @callback def async_listen_updates( self, listener: Callable[[CloudPreferences], Coroutine[Any, Any, None]] - ) -> None: + ) -> Callable[[], None]: """Listen for updates to the preferences.""" + + @callback + def unsubscribe() -> None: + """Remove the listener.""" + self._listeners.remove(listener) + self._listeners.append(listener) + return unsubscribe + async def async_update( self, *, diff --git a/homeassistant/components/cloud/stt.py b/homeassistant/components/cloud/stt.py index b652a36fa8a701..3368f25f94a955 100644 --- a/homeassistant/components/cloud/stt.py +++ b/homeassistant/components/cloud/stt.py @@ -1,6 +1,7 @@ """Support for the cloud for speech to text service.""" from __future__ import annotations +import asyncio from collections.abc import AsyncIterable import logging @@ -19,12 +20,13 @@ SpeechToTextEntity, ) from homeassistant.config_entries import ConfigEntry +from homeassistant.const import Platform from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback -from .assist_pipeline import async_migrate_cloud_pipeline_stt_engine +from .assist_pipeline import async_migrate_cloud_pipeline_engine from .client import CloudClient -from .const import DOMAIN, STT_ENTITY_UNIQUE_ID +from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID _LOGGER = logging.getLogger(__name__) @@ -35,18 +37,20 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up Home Assistant Cloud speech platform via config entry.""" + stt_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.STT] + stt_platform_loaded.set() cloud: Cloud[CloudClient] = hass.data[DOMAIN] async_add_entities([CloudProviderEntity(cloud)]) class CloudProviderEntity(SpeechToTextEntity): - """NabuCasa speech API provider.""" + """Home Assistant Cloud speech API provider.""" _attr_name = "Home Assistant Cloud" _attr_unique_id = STT_ENTITY_UNIQUE_ID def __init__(self, cloud: Cloud[CloudClient]) -> None: - """Home Assistant NabuCasa Speech to text.""" + """Initialize cloud Speech to text entity.""" self.cloud = cloud @property @@ -81,7 +85,9 @@ def supported_channels(self) -> list[AudioChannels]: async def async_added_to_hass(self) -> None: """Run when entity is about to be added to hass.""" - await async_migrate_cloud_pipeline_stt_engine(self.hass, self.entity_id) + await async_migrate_cloud_pipeline_engine( + self.hass, platform=Platform.STT, engine_id=self.entity_id + ) async def async_process_audio_stream( self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] diff --git a/homeassistant/components/cloud/tts.py b/homeassistant/components/cloud/tts.py index f8152243bf5852..2626c01e66f806 100644 --- a/homeassistant/components/cloud/tts.py +++ b/homeassistant/components/cloud/tts.py @@ -1,6 +1,7 @@ """Support for the cloud for text-to-speech service.""" from __future__ import annotations +import asyncio import logging from typing import Any @@ -12,16 +13,21 @@ ATTR_AUDIO_OUTPUT, ATTR_VOICE, CONF_LANG, - PLATFORM_SCHEMA, + PLATFORM_SCHEMA as TTS_PLATFORM_SCHEMA, Provider, + TextToSpeechEntity, TtsAudioType, Voice, ) +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import Platform from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType +from .assist_pipeline import async_migrate_cloud_pipeline_engine from .client import CloudClient -from .const import DOMAIN +from .const import DATA_PLATFORMS_SETUP, DOMAIN, TTS_ENTITY_UNIQUE_ID from .prefs import CloudPreferences ATTR_GENDER = "gender" @@ -48,7 +54,7 @@ def validate_lang(value: dict[str, Any]) -> dict[str, Any]: PLATFORM_SCHEMA = vol.All( - PLATFORM_SCHEMA.extend( + TTS_PLATFORM_SCHEMA.extend( { vol.Optional(CONF_LANG): str, vol.Optional(ATTR_GENDER): str, @@ -81,8 +87,95 @@ async def async_get_engine( return cloud_provider +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up Home Assistant Cloud text-to-speech platform.""" + tts_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.TTS] + tts_platform_loaded.set() + cloud: Cloud[CloudClient] = hass.data[DOMAIN] + async_add_entities([CloudTTSEntity(cloud)]) + + +class CloudTTSEntity(TextToSpeechEntity): + """Home Assistant Cloud text-to-speech entity.""" + + _attr_name = "Home Assistant Cloud" + _attr_unique_id = TTS_ENTITY_UNIQUE_ID + + def __init__(self, cloud: Cloud[CloudClient]) -> None: + """Initialize cloud text-to-speech entity.""" + self.cloud = cloud + self._language, self._gender = cloud.client.prefs.tts_default_voice + + async def _sync_prefs(self, prefs: CloudPreferences) -> None: + """Sync preferences.""" + self._language, self._gender = prefs.tts_default_voice + + @property + def default_language(self) -> str: + """Return the default language.""" + return self._language + + @property + def default_options(self) -> dict[str, Any]: + """Return a dict include default options.""" + return { + ATTR_GENDER: self._gender, + ATTR_AUDIO_OUTPUT: AudioOutput.MP3, + } + + @property + def supported_languages(self) -> list[str]: + """Return list of supported languages.""" + return SUPPORT_LANGUAGES + + @property + def supported_options(self) -> list[str]: + """Return list of supported options like voice, emotion.""" + return [ATTR_GENDER, ATTR_VOICE, ATTR_AUDIO_OUTPUT] + + async def async_added_to_hass(self) -> None: + """Handle entity which will be added.""" + await super().async_added_to_hass() + await async_migrate_cloud_pipeline_engine( + self.hass, platform=Platform.TTS, engine_id=self.entity_id + ) + self.async_on_remove( + self.cloud.client.prefs.async_listen_updates(self._sync_prefs) + ) + + @callback + def async_get_supported_voices(self, language: str) -> list[Voice] | None: + """Return a list of supported voices for a language.""" + if not (voices := TTS_VOICES.get(language)): + return None + return [Voice(voice, voice) for voice in voices] + + async def async_get_tts_audio( + self, message: str, language: str, options: dict[str, Any] + ) -> TtsAudioType: + """Load TTS from Home Assistant Cloud.""" + # Process TTS + try: + data = await self.cloud.voice.process_tts( + text=message, + language=language, + gender=options.get(ATTR_GENDER), + voice=options.get(ATTR_VOICE), + output=options[ATTR_AUDIO_OUTPUT], + ) + except VoiceError as err: + _LOGGER.error("Voice error: %s", err) + return (None, None) + + return (str(options[ATTR_AUDIO_OUTPUT].value), data) + + class CloudProvider(Provider): - """NabuCasa Cloud speech API provider.""" + """Home Assistant Cloud speech API provider.""" def __init__( self, cloud: Cloud[CloudClient], language: str | None, gender: str | None @@ -136,7 +229,7 @@ def default_options(self) -> dict[str, Any]: async def async_get_tts_audio( self, message: str, language: str, options: dict[str, Any] ) -> TtsAudioType: - """Load TTS from NabuCasa Cloud.""" + """Load TTS from Home Assistant Cloud.""" # Process TTS try: data = await self.cloud.voice.process_tts( diff --git a/tests/components/cloud/conftest.py b/tests/components/cloud/conftest.py index 1e1877ae13c408..7421914d3d4b3f 100644 --- a/tests/components/cloud/conftest.py +++ b/tests/components/cloud/conftest.py @@ -15,11 +15,22 @@ import pytest from homeassistant.components.cloud import CloudClient, const, prefs +from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component from homeassistant.util.dt import utcnow from . import mock_cloud, mock_cloud_prefs +@pytest.fixture(autouse=True) +async def load_homeassistant(hass: HomeAssistant) -> None: + """Load the homeassistant integration. + + This is needed for the cloud integration to work. + """ + assert await async_setup_component(hass, "homeassistant", {}) + + @pytest.fixture(name="cloud") async def cloud_fixture() -> AsyncGenerator[MagicMock, None]: """Mock the cloud object. diff --git a/tests/components/cloud/test_assist_pipeline.py b/tests/components/cloud/test_assist_pipeline.py new file mode 100644 index 00000000000000..7f1411dab455da --- /dev/null +++ b/tests/components/cloud/test_assist_pipeline.py @@ -0,0 +1,16 @@ +"""Test the cloud assist pipeline.""" +import pytest + +from homeassistant.components.cloud.assist_pipeline import ( + async_migrate_cloud_pipeline_engine, +) +from homeassistant.const import Platform +from homeassistant.core import HomeAssistant + + +async def test_migrate_pipeline_invalid_platform(hass: HomeAssistant) -> None: + """Test migrate pipeline with invalid platform.""" + with pytest.raises(ValueError): + await async_migrate_cloud_pipeline_engine( + hass, Platform.BINARY_SENSOR, "test-engine-id" + ) diff --git a/tests/components/cloud/test_http_api.py b/tests/components/cloud/test_http_api.py index 409d86d6e37133..08a3defb37d7c5 100644 --- a/tests/components/cloud/test_http_api.py +++ b/tests/components/cloud/test_http_api.py @@ -243,7 +243,7 @@ async def test_login_view_create_pipeline( create_pipeline_mock.assert_awaited_once_with( hass, stt_engine_id="stt.home_assistant_cloud", - tts_engine_id="cloud", + tts_engine_id="tts.home_assistant_cloud", pipeline_name="Home Assistant Cloud", ) @@ -282,7 +282,7 @@ async def test_login_view_create_pipeline_fail( create_pipeline_mock.assert_awaited_once_with( hass, stt_engine_id="stt.home_assistant_cloud", - tts_engine_id="cloud", + tts_engine_id="tts.home_assistant_cloud", pipeline_name="Home Assistant Cloud", ) diff --git a/tests/components/cloud/test_stt.py b/tests/components/cloud/test_stt.py index 666d8ae7d65b07..e3b8326116a5cc 100644 --- a/tests/components/cloud/test_stt.py +++ b/tests/components/cloud/test_stt.py @@ -65,12 +65,6 @@ } -@pytest.fixture(autouse=True) -async def load_homeassistant(hass: HomeAssistant) -> None: - """Load the homeassistant integration.""" - assert await async_setup_component(hass, "homeassistant", {}) - - @pytest.fixture(autouse=True) async def delay_save_fixture() -> AsyncGenerator[None, None]: """Load the homeassistant integration.""" diff --git a/tests/components/cloud/test_tts.py b/tests/components/cloud/test_tts.py index 4069edcb744cca..4d29b6d47dabc6 100644 --- a/tests/components/cloud/test_tts.py +++ b/tests/components/cloud/test_tts.py @@ -12,7 +12,9 @@ from homeassistant.components.tts import DOMAIN as TTS_DOMAIN from homeassistant.components.tts.helper import get_engine_instance from homeassistant.config import async_process_ha_core_config +from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_registry import EntityRegistry from homeassistant.setup import async_setup_component from tests.typing import ClientSessionGenerator @@ -70,6 +72,10 @@ def test_schema() -> None: "gender": "female", }, ), + ( + "tts.home_assistant_cloud", + None, + ), ], ) async def test_prefs_default_voice( @@ -104,9 +110,17 @@ async def test_prefs_default_voice( assert engine.default_options == {"gender": "male", "audio_output": "mp3"} +@pytest.mark.parametrize( + "engine_id", + [ + DOMAIN, + "tts.home_assistant_cloud", + ], +) async def test_provider_properties( hass: HomeAssistant, cloud: MagicMock, + engine_id: str, ) -> None: """Test cloud provider.""" assert await async_setup_component(hass, "homeassistant", {}) @@ -115,7 +129,7 @@ async def test_provider_properties( on_start_callback = cloud.register_on_start.call_args[0][0] await on_start_callback() - engine = get_engine_instance(hass, DOMAIN) + engine = get_engine_instance(hass, engine_id) assert engine is not None assert engine.supported_options == ["gender", "voice", "audio_output"] @@ -132,6 +146,7 @@ async def test_provider_properties( [ ({"platform": DOMAIN}, DOMAIN), ({"engine_id": DOMAIN}, DOMAIN), + ({"engine_id": "tts.home_assistant_cloud"}, "tts.home_assistant_cloud"), ], ) @pytest.mark.parametrize( @@ -241,3 +256,78 @@ async def test_get_tts_audio_logged_out( assert mock_process_tts.call_args.kwargs["language"] == "en-US" assert mock_process_tts.call_args.kwargs["gender"] == "female" assert mock_process_tts.call_args.kwargs["output"] == "mp3" + + +@pytest.mark.parametrize( + ("mock_process_tts_return_value", "mock_process_tts_side_effect"), + [ + (b"", None), + (None, VoiceError("Boom!")), + ], +) +async def test_tts_entity( + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + entity_registry: EntityRegistry, + cloud: MagicMock, + mock_process_tts_return_value: bytes | None, + mock_process_tts_side_effect: Exception | None, +) -> None: + """Test text-to-speech entity.""" + mock_process_tts = AsyncMock( + return_value=mock_process_tts_return_value, + side_effect=mock_process_tts_side_effect, + ) + cloud.voice.process_tts = mock_process_tts + assert await async_setup_component(hass, "homeassistant", {}) + assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}}) + await hass.async_block_till_done() + on_start_callback = cloud.register_on_start.call_args[0][0] + await on_start_callback() + client = await hass_client() + entity_id = "tts.home_assistant_cloud" + + state = hass.states.get(entity_id) + assert state + assert state.state == STATE_UNKNOWN + + url = "/api/tts_get_url" + data = { + "engine_id": entity_id, + "message": "There is someone at the door.", + } + + req = await client.post(url, json=data) + assert req.status == HTTPStatus.OK + response = await req.json() + + assert response == { + "url": ( + "http://example.local:8123/api/tts_proxy/" + "42f18378fd4393d18c8dd11d03fa9563c1e54491" + f"_en-us_e09b5a0968_{entity_id}.mp3" + ), + "path": ( + "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491" + f"_en-us_e09b5a0968_{entity_id}.mp3" + ), + } + await hass.async_block_till_done() + + assert mock_process_tts.call_count == 1 + assert mock_process_tts.call_args is not None + assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door." + assert mock_process_tts.call_args.kwargs["language"] == "en-US" + assert mock_process_tts.call_args.kwargs["gender"] == "female" + assert mock_process_tts.call_args.kwargs["output"] == "mp3" + + state = hass.states.get(entity_id) + assert state + assert state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN) + + # Test removing the entity + entity_registry.async_remove(entity_id) + await hass.async_block_till_done() + + state = hass.states.get(entity_id) + assert state is None From fca4ace90d0524668e12ef4a1012c2e5ac3aa6f6 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Thu, 18 Jan 2024 16:53:22 +0100 Subject: [PATCH 2/6] Test test_login_view_missing_entity --- tests/components/cloud/test_http_api.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/components/cloud/test_http_api.py b/tests/components/cloud/test_http_api.py index 08a3defb37d7c5..4602c054392dae 100644 --- a/tests/components/cloud/test_http_api.py +++ b/tests/components/cloud/test_http_api.py @@ -147,15 +147,19 @@ async def test_google_actions_sync_fails( assert mock_request_sync.call_count == 1 -async def test_login_view_missing_stt_entity( +@pytest.mark.parametrize( + "entity_id", ["stt.home_assistant_cloud", "tts.home_assistant_cloud"] +) +async def test_login_view_missing_entity( hass: HomeAssistant, setup_cloud: None, entity_registry: er.EntityRegistry, hass_client: ClientSessionGenerator, + entity_id: str, ) -> None: - """Test logging in when the cloud stt entity is missing.""" - # Make sure that the cloud stt entity does not exist. - entity_registry.async_remove("stt.home_assistant_cloud") + """Test logging in when a cloud assist pipeline needed entity is missing.""" + # Make sure that the cloud entity does not exist. + entity_registry.async_remove(entity_id) await hass.async_block_till_done() cloud_client = await hass_client() From 4cedcd54c2a10c0398b57c30236ba3a19ba15cdb Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Fri, 19 Jan 2024 10:46:16 +0100 Subject: [PATCH 3/6] Fix pipeline iteration for migration --- homeassistant/components/cloud/assist_pipeline.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/homeassistant/components/cloud/assist_pipeline.py b/homeassistant/components/cloud/assist_pipeline.py index 7bc3eed0c855a3..0d4ec057e65571 100644 --- a/homeassistant/components/cloud/assist_pipeline.py +++ b/homeassistant/components/cloud/assist_pipeline.py @@ -100,6 +100,5 @@ async def async_migrate_cloud_pipeline_engine( pipelines = async_get_pipelines(hass) for pipeline in pipelines: - if pipeline.stt_engine != DOMAIN or pipeline.tts_engine != DOMAIN: - continue - await async_update_pipeline(hass, pipeline, **kwargs) + if DOMAIN in (pipeline.stt_engine, pipeline.tts_engine): + await async_update_pipeline(hass, pipeline, **kwargs) From e287a134be62c2d1a3c2204e7c7d4f49a9312689 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Fri, 19 Jan 2024 10:46:30 +0100 Subject: [PATCH 4/6] Update tests --- tests/components/cloud/__init__.py | 48 ++++++++++++++++++ tests/components/cloud/test_stt.py | 64 ++++------------------- tests/components/cloud/test_tts.py | 81 +++++++++++++++++++++++++++++- 3 files changed, 136 insertions(+), 57 deletions(-) diff --git a/tests/components/cloud/__init__.py b/tests/components/cloud/__init__.py index 22b84f032f6de5..e6e793ed106fdf 100644 --- a/tests/components/cloud/__init__.py +++ b/tests/components/cloud/__init__.py @@ -7,6 +7,54 @@ from homeassistant.components.cloud import const, prefs as cloud_prefs from homeassistant.setup import async_setup_component +PIPELINE_DATA = { + "items": [ + { + "conversation_engine": "conversation_engine_1", + "conversation_language": "language_1", + "id": "01GX8ZWBAQYWNB1XV3EXEZ75DY", + "language": "language_1", + "name": "Home Assistant Cloud", + "stt_engine": "cloud", + "stt_language": "language_1", + "tts_engine": "cloud", + "tts_language": "language_1", + "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": None, + "wake_word_id": None, + }, + { + "conversation_engine": "conversation_engine_2", + "conversation_language": "language_2", + "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX", + "language": "language_2", + "name": "name_2", + "stt_engine": "stt_engine_2", + "stt_language": "language_2", + "tts_engine": "tts_engine_2", + "tts_language": "language_2", + "tts_voice": "The Voice", + "wake_word_entity": None, + "wake_word_id": None, + }, + { + "conversation_engine": "conversation_engine_3", + "conversation_language": "language_3", + "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J", + "language": "language_3", + "name": "name_3", + "stt_engine": None, + "stt_language": None, + "tts_engine": None, + "tts_language": None, + "tts_voice": None, + "wake_word_entity": None, + "wake_word_id": None, + }, + ], + "preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY", +} + async def mock_cloud(hass, config=None): """Mock cloud.""" diff --git a/tests/components/cloud/test_stt.py b/tests/components/cloud/test_stt.py index e3b8326116a5cc..305780e33e11ca 100644 --- a/tests/components/cloud/test_stt.py +++ b/tests/components/cloud/test_stt.py @@ -14,55 +14,9 @@ from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component -from tests.typing import ClientSessionGenerator +from . import PIPELINE_DATA -PIPELINE_DATA = { - "items": [ - { - "conversation_engine": "conversation_engine_1", - "conversation_language": "language_1", - "id": "01GX8ZWBAQYWNB1XV3EXEZ75DY", - "language": "language_1", - "name": "Home Assistant Cloud", - "stt_engine": "cloud", - "stt_language": "language_1", - "tts_engine": "cloud", - "tts_language": "language_1", - "tts_voice": "Arnold Schwarzenegger", - "wake_word_entity": None, - "wake_word_id": None, - }, - { - "conversation_engine": "conversation_engine_2", - "conversation_language": "language_2", - "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX", - "language": "language_2", - "name": "name_2", - "stt_engine": "stt_engine_2", - "stt_language": "language_2", - "tts_engine": "tts_engine_2", - "tts_language": "language_2", - "tts_voice": "The Voice", - "wake_word_entity": None, - "wake_word_id": None, - }, - { - "conversation_engine": "conversation_engine_3", - "conversation_language": "language_3", - "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J", - "language": "language_3", - "name": "name_3", - "stt_engine": None, - "stt_language": None, - "tts_engine": None, - "tts_language": None, - "tts_voice": None, - "wake_word_entity": None, - "wake_word_id": None, - }, - ], - "preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY", -} +from tests.typing import ClientSessionGenerator @pytest.fixture(autouse=True) @@ -137,6 +91,7 @@ async def test_migrating_pipelines( hass_storage: dict[str, Any], ) -> None: """Test migrating pipelines when cloud stt entity is added.""" + entity_id = "stt.home_assistant_cloud" cloud.voice.process_stt = AsyncMock( return_value=STTResponse(True, "Turn the Kitchen Lights on") ) @@ -151,18 +106,18 @@ async def test_migrating_pipelines( assert await async_setup_component(hass, DOMAIN, {"cloud": {}}) await hass.async_block_till_done() - on_start_callback = cloud.register_on_start.call_args[0][0] - await on_start_callback() + await cloud.login("test-user", "test-pass") await hass.async_block_till_done() - state = hass.states.get("stt.home_assistant_cloud") + state = hass.states.get(entity_id) assert state assert state.state == STATE_UNKNOWN - # The stt engine should be updated to the new cloud stt engine id. + # The stt/tts engines should have been updated to the new cloud engine ids. + assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"] == entity_id assert ( - hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"] - == "stt.home_assistant_cloud" + hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"] + == "tts.home_assistant_cloud" ) # The other items should stay the same. @@ -183,7 +138,6 @@ async def test_migrating_pipelines( hass_storage[STORAGE_KEY]["data"]["items"][0]["name"] == "Home Assistant Cloud" ) assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_language"] == "language_1" - assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"] == "cloud" assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_language"] == "language_1" assert ( hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_voice"] diff --git a/tests/components/cloud/test_tts.py b/tests/components/cloud/test_tts.py index 4d29b6d47dabc6..b75d2361070dc6 100644 --- a/tests/components/cloud/test_tts.py +++ b/tests/components/cloud/test_tts.py @@ -1,13 +1,15 @@ """Tests for cloud tts.""" -from collections.abc import Callable, Coroutine +from collections.abc import AsyncGenerator, Callable, Coroutine +from copy import deepcopy from http import HTTPStatus from typing import Any -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch from hass_nabucasa.voice import MAP_VOICE, VoiceError, VoiceTokenError import pytest import voluptuous as vol +from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY from homeassistant.components.cloud import DOMAIN, const, tts from homeassistant.components.tts import DOMAIN as TTS_DOMAIN from homeassistant.components.tts.helper import get_engine_instance @@ -17,9 +19,18 @@ from homeassistant.helpers.entity_registry import EntityRegistry from homeassistant.setup import async_setup_component +from . import PIPELINE_DATA + from tests.typing import ClientSessionGenerator +@pytest.fixture(autouse=True) +async def delay_save_fixture() -> AsyncGenerator[None, None]: + """Load the homeassistant integration.""" + with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0): + yield + + @pytest.fixture(autouse=True) async def internal_url_mock(hass: HomeAssistant) -> None: """Mock internal URL of the instance.""" @@ -331,3 +342,69 @@ async def test_tts_entity( state = hass.states.get(entity_id) assert state is None + + +async def test_migrating_pipelines( + hass: HomeAssistant, + cloud: MagicMock, + hass_client: ClientSessionGenerator, + hass_storage: dict[str, Any], +) -> None: + """Test migrating pipelines when cloud tts entity is added.""" + entity_id = "tts.home_assistant_cloud" + mock_process_tts = AsyncMock( + return_value=b"", + ) + cloud.voice.process_tts = mock_process_tts + hass_storage[STORAGE_KEY] = { + "version": 1, + "minor_version": 1, + "key": "assist_pipeline.pipelines", + "data": deepcopy(PIPELINE_DATA), + } + + assert await async_setup_component(hass, "assist_pipeline", {}) + assert await async_setup_component(hass, DOMAIN, {"cloud": {}}) + await hass.async_block_till_done() + + await cloud.login("test-user", "test-pass") + await hass.async_block_till_done() + + state = hass.states.get(entity_id) + assert state + assert state.state == STATE_UNKNOWN + + # The stt/tts engines should have been updated to the new cloud engine ids. + assert ( + hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"] + == "stt.home_assistant_cloud" + ) + assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"] == entity_id + + # The other items should stay the same. + assert ( + hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_engine"] + == "conversation_engine_1" + ) + assert ( + hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_language"] + == "language_1" + ) + assert ( + hass_storage[STORAGE_KEY]["data"]["items"][0]["id"] + == "01GX8ZWBAQYWNB1XV3EXEZ75DY" + ) + assert hass_storage[STORAGE_KEY]["data"]["items"][0]["language"] == "language_1" + assert ( + hass_storage[STORAGE_KEY]["data"]["items"][0]["name"] == "Home Assistant Cloud" + ) + assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_language"] == "language_1" + assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_language"] == "language_1" + assert ( + hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_voice"] + == "Arnold Schwarzenegger" + ) + assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_entity"] is None + assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_id"] is None + assert hass_storage[STORAGE_KEY]["data"]["items"][1] == PIPELINE_DATA["items"][1] + assert hass_storage[STORAGE_KEY]["data"]["items"][2] == PIPELINE_DATA["items"][2] From e4a9f616b6bf67e43f8c0fc6f0fc0d5810a50ecb Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Mon, 22 Jan 2024 12:08:54 +0100 Subject: [PATCH 5/6] Make migration more strict --- homeassistant/components/cloud/assist_pipeline.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/homeassistant/components/cloud/assist_pipeline.py b/homeassistant/components/cloud/assist_pipeline.py index 0d4ec057e65571..4ababed7e54f7e 100644 --- a/homeassistant/components/cloud/assist_pipeline.py +++ b/homeassistant/components/cloud/assist_pipeline.py @@ -81,13 +81,12 @@ async def async_migrate_cloud_pipeline_engine( # We need to make sure that that both stt and tts are loaded before this migration. # Assist pipeline will call default engine when setting up the store. # Wait for the stt or tts platform loaded event here. - kwargs: dict[str, str] = {} if platform == Platform.STT: wait_for_platform = Platform.TTS - kwargs["stt_engine"] = engine_id + pipeline_attribute = "stt_engine" elif platform == Platform.TTS: wait_for_platform = Platform.STT - kwargs["tts_engine"] = engine_id + pipeline_attribute = "tts_engine" else: raise ValueError(f"Invalid platform {platform}") @@ -98,7 +97,8 @@ async def async_migrate_cloud_pipeline_engine( # is an after dependency of cloud await async_setup_pipeline_store(hass) + kwargs: dict[str, str] = {pipeline_attribute: engine_id} pipelines = async_get_pipelines(hass) for pipeline in pipelines: - if DOMAIN in (pipeline.stt_engine, pipeline.tts_engine): + if getattr(pipeline, pipeline_attribute) == DOMAIN: await async_update_pipeline(hass, pipeline, **kwargs) From a71aa3710d667fa44ca94e16de84243a22f82f4e Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Mon, 22 Jan 2024 12:13:00 +0100 Subject: [PATCH 6/6] Fix docstring --- homeassistant/components/cloud/assist_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/homeassistant/components/cloud/assist_pipeline.py b/homeassistant/components/cloud/assist_pipeline.py index 4ababed7e54f7e..2c381dd0ac0636 100644 --- a/homeassistant/components/cloud/assist_pipeline.py +++ b/homeassistant/components/cloud/assist_pipeline.py @@ -78,7 +78,7 @@ async def async_migrate_cloud_pipeline_engine( # Migrate existing pipelines with cloud stt or tts to use new cloud engine id. # Added in 2024.02.0. Can be removed in 2025.02.0. - # We need to make sure that that both stt and tts are loaded before this migration. + # We need to make sure that both stt and tts are loaded before this migration. # Assist pipeline will call default engine when setting up the store. # Wait for the stt or tts platform loaded event here. if platform == Platform.STT: