Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cloud tts entity #108293

Merged
merged 6 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions homeassistant/components/cloud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@

DEFAULT_MODE = MODE_PROD

PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT]
PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT, Platform.TTS]

SERVICE_REMOTE_CONNECT = "remote_connect"
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
Expand Down Expand Up @@ -288,9 +288,11 @@ async def async_startup_repairs(_: datetime) -> None:
loaded = False
stt_platform_loaded = asyncio.Event()
tts_platform_loaded = asyncio.Event()
stt_tts_entities_added = asyncio.Event()
hass.data[DATA_PLATFORMS_SETUP] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should really be a dataclass instance instead. OK to do in a follow-up if it adds too much noise to this PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'll clean it up in the next PR.

Platform.STT: stt_platform_loaded,
Platform.TTS: tts_platform_loaded,
"stt_tts_entities_added": stt_tts_entities_added,
}

async def _on_start() -> None:
Expand Down Expand Up @@ -330,6 +332,7 @@ async def _on_initialized() -> None:

account_link.async_setup(hass)

# Load legacy tts platform for backwards compatibility.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How long do we need to maintain this compatibility? If it can be removed in 2025.2, it makes sense to mention that in this comment as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We haven't decided if and when we want to deprecate it.

hass.async_create_task(
async_load_platform(
hass,
Expand Down Expand Up @@ -377,8 +380,10 @@ async def remote_prefs_updated(prefs: CloudPreferences) -> None:
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
stt_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.STT]
stt_platform_loaded.set()
stt_tts_entities_added: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][
"stt_tts_entities_added"
]
stt_tts_entities_added.set()
MartinHjelmare marked this conversation as resolved.
Show resolved Hide resolved

return True

Expand Down
55 changes: 37 additions & 18 deletions homeassistant/components/cloud/assist_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,23 @@
)
from homeassistant.components.conversation import HOME_ASSISTANT_AGENT
from homeassistant.components.stt import DOMAIN as STT_DOMAIN
from homeassistant.components.tts import DOMAIN as TTS_DOMAIN
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
import homeassistant.helpers.entity_registry as er

from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID
from .const import (
DATA_PLATFORMS_SETUP,
DOMAIN,
STT_ENTITY_UNIQUE_ID,
TTS_ENTITY_UNIQUE_ID,
)


async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
"""Create a cloud assist pipeline."""
# Wait for stt and tts platforms to set up before creating the pipeline.
# Wait for stt and tts platforms to set up and entities to be added
# before creating the pipeline.
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
await asyncio.gather(*(event.wait() for event in platforms_setup.values()))
# Make sure the pipeline store is loaded, needed because assist_pipeline
Expand All @@ -29,8 +36,11 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
new_stt_engine_id = entity_registry.async_get_entity_id(
STT_DOMAIN, DOMAIN, STT_ENTITY_UNIQUE_ID
)
if new_stt_engine_id is None:
# If there's no cloud stt entity, we can't create a cloud pipeline.
new_tts_engine_id = entity_registry.async_get_entity_id(
TTS_DOMAIN, DOMAIN, TTS_ENTITY_UNIQUE_ID
)
if new_stt_engine_id is None or new_tts_engine_id is None:
# If there's no cloud stt or tts entity, we can't create a cloud pipeline.
return None

def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
Expand All @@ -43,7 +53,7 @@ def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
if (
pipeline.conversation_engine == HOME_ASSISTANT_AGENT
and pipeline.stt_engine in (DOMAIN, new_stt_engine_id)
and pipeline.tts_engine == DOMAIN
and pipeline.tts_engine in (DOMAIN, new_tts_engine_id)
):
return pipeline.id
return None
Expand All @@ -52,7 +62,7 @@ def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
cloud_pipeline := await async_create_default_pipeline(
hass,
stt_engine_id=new_stt_engine_id,
tts_engine_id=DOMAIN,
tts_engine_id=new_tts_engine_id,
pipeline_name="Home Assistant Cloud",
)
) is None:
Expand All @@ -61,25 +71,34 @@ def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
return cloud_pipeline.id


async def async_migrate_cloud_pipeline_stt_engine(
hass: HomeAssistant, stt_engine_id: str
async def async_migrate_cloud_pipeline_engine(
hass: HomeAssistant, platform: Platform, engine_id: str
) -> None:
"""Migrate the speech-to-text engine in the cloud assist pipeline."""
# Migrate existing pipelines with cloud stt to use new cloud stt engine id.
# Added in 2024.01.0. Can be removed in 2025.01.0.
"""Migrate the pipeline engines in the cloud assist pipeline."""
# Migrate existing pipelines with cloud stt or tts to use new cloud engine id.
# Added in 2024.02.0. Can be removed in 2025.02.0.

# We need to make sure that both stt and tts are loaded before this migration.
# Assist pipeline will call default engine when setting up the store.
# Wait for the stt or tts platform loaded event here.
if platform == Platform.STT:
wait_for_platform = Platform.TTS
pipeline_attribute = "stt_engine"
elif platform == Platform.TTS:
wait_for_platform = Platform.STT
pipeline_attribute = "tts_engine"
else:
raise ValueError(f"Invalid platform {platform}")

# We need to make sure that tts is loaded before this migration.
# Assist pipeline will call default engine of tts when setting up the store.
# Wait for the tts platform loaded event here.
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
await platforms_setup[Platform.TTS].wait()
await platforms_setup[wait_for_platform].wait()

# Make sure the pipeline store is loaded, needed because assist_pipeline
# is an after dependency of cloud
await async_setup_pipeline_store(hass)

kwargs: dict[str, str] = {pipeline_attribute: engine_id}
pipelines = async_get_pipelines(hass)
for pipeline in pipelines:
if pipeline.stt_engine != DOMAIN:
continue
await async_update_pipeline(hass, pipeline, stt_engine=stt_engine_id)
if getattr(pipeline, pipeline_attribute) == DOMAIN:
await async_update_pipeline(hass, pipeline, **kwargs)
1 change: 1 addition & 0 deletions homeassistant/components/cloud/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,4 @@
DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")

STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
10 changes: 9 additions & 1 deletion homeassistant/components/cloud/prefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,18 @@ async def async_initialize(self) -> None:
@callback
def async_listen_updates(
self, listener: Callable[[CloudPreferences], Coroutine[Any, Any, None]]
) -> None:
) -> Callable[[], None]:
"""Listen for updates to the preferences."""

@callback
def unsubscribe() -> None:
"""Remove the listener."""
self._listeners.remove(listener)

self._listeners.append(listener)

return unsubscribe

async def async_update(
self,
*,
Expand Down
16 changes: 11 additions & 5 deletions homeassistant/components/cloud/stt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Support for the cloud for speech to text service."""
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterable
import logging

Expand All @@ -19,12 +20,13 @@
SpeechToTextEntity,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback

from .assist_pipeline import async_migrate_cloud_pipeline_stt_engine
from .assist_pipeline import async_migrate_cloud_pipeline_engine
from .client import CloudClient
from .const import DOMAIN, STT_ENTITY_UNIQUE_ID
from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID

_LOGGER = logging.getLogger(__name__)

Expand All @@ -35,18 +37,20 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up Home Assistant Cloud speech platform via config entry."""
stt_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.STT]
stt_platform_loaded.set()
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
async_add_entities([CloudProviderEntity(cloud)])


class CloudProviderEntity(SpeechToTextEntity):
"""NabuCasa speech API provider."""
"""Home Assistant Cloud speech API provider."""

_attr_name = "Home Assistant Cloud"
_attr_unique_id = STT_ENTITY_UNIQUE_ID

def __init__(self, cloud: Cloud[CloudClient]) -> None:
"""Home Assistant NabuCasa Speech to text."""
"""Initialize cloud Speech to text entity."""
self.cloud = cloud

@property
Expand Down Expand Up @@ -81,7 +85,9 @@ def supported_channels(self) -> list[AudioChannels]:

async def async_added_to_hass(self) -> None:
"""Run when entity is about to be added to hass."""
await async_migrate_cloud_pipeline_stt_engine(self.hass, self.entity_id)
await async_migrate_cloud_pipeline_engine(
self.hass, platform=Platform.STT, engine_id=self.entity_id
)

async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
Expand Down
103 changes: 98 additions & 5 deletions homeassistant/components/cloud/tts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Support for the cloud for text-to-speech service."""
from __future__ import annotations

import asyncio
import logging
from typing import Any

Expand All @@ -12,16 +13,21 @@
ATTR_AUDIO_OUTPUT,
ATTR_VOICE,
CONF_LANG,
PLATFORM_SCHEMA,
PLATFORM_SCHEMA as TTS_PLATFORM_SCHEMA,
Provider,
TextToSpeechEntity,
TtsAudioType,
Voice,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType

from .assist_pipeline import async_migrate_cloud_pipeline_engine
from .client import CloudClient
from .const import DOMAIN
from .const import DATA_PLATFORMS_SETUP, DOMAIN, TTS_ENTITY_UNIQUE_ID
from .prefs import CloudPreferences

ATTR_GENDER = "gender"
Expand All @@ -48,7 +54,7 @@ def validate_lang(value: dict[str, Any]) -> dict[str, Any]:


PLATFORM_SCHEMA = vol.All(
PLATFORM_SCHEMA.extend(
TTS_PLATFORM_SCHEMA.extend(
{
vol.Optional(CONF_LANG): str,
vol.Optional(ATTR_GENDER): str,
Expand Down Expand Up @@ -81,8 +87,95 @@ async def async_get_engine(
return cloud_provider


async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Set up the Home Assistant Cloud text-to-speech entity from a config entry.

    Sets the tts "platform loaded" event stored under DATA_PLATFORMS_SETUP and
    then adds a single CloudTTSEntity built from the cloud object in hass.data.
    """
    # Set the tts platform event first so anything waiting on it can proceed.
    platform_events = hass.data[DATA_PLATFORMS_SETUP]
    loaded_event: asyncio.Event = platform_events[Platform.TTS]
    loaded_event.set()

    cloud_instance: Cloud[CloudClient] = hass.data[DOMAIN]
    entity = CloudTTSEntity(cloud_instance)
    async_add_entities([entity])


class CloudTTSEntity(TextToSpeechEntity):
    """Text-to-speech entity backed by Home Assistant Cloud."""

    _attr_name = "Home Assistant Cloud"
    _attr_unique_id = TTS_ENTITY_UNIQUE_ID

    def __init__(self, cloud: Cloud[CloudClient]) -> None:
        """Keep the cloud handle and read the default voice from its preferences."""
        self.cloud = cloud
        # tts_default_voice unpacks into a (language, gender) pair.
        default_voice = cloud.client.prefs.tts_default_voice
        self._language, self._gender = default_voice

    async def _sync_prefs(self, prefs: CloudPreferences) -> None:
        """Refresh the cached default language and gender from updated preferences."""
        default_voice = prefs.tts_default_voice
        self._language, self._gender = default_voice

    @property
    def default_language(self) -> str:
        """Return the language currently cached from the cloud preferences."""
        return self._language

    @property
    def default_options(self) -> dict[str, Any]:
        """Return the default options: the cached gender and MP3 audio output."""
        defaults: dict[str, Any] = {ATTR_GENDER: self._gender}
        defaults[ATTR_AUDIO_OUTPUT] = AudioOutput.MP3
        return defaults

    @property
    def supported_languages(self) -> list[str]:
        """Return the list of supported languages."""
        return SUPPORT_LANGUAGES

    @property
    def supported_options(self) -> list[str]:
        """Return the option keys accepted by this entity."""
        option_keys = [ATTR_GENDER, ATTR_VOICE, ATTR_AUDIO_OUTPUT]
        return option_keys

    async def async_added_to_hass(self) -> None:
        """Run the pipeline migration, then subscribe to preference updates."""
        await super().async_added_to_hass()
        # Point pipelines that still use the legacy cloud tts engine at this entity.
        await async_migrate_cloud_pipeline_engine(
            self.hass, platform=Platform.TTS, engine_id=self.entity_id
        )
        # Hand the unsubscribe callback to async_on_remove.
        unsubscribe = self.cloud.client.prefs.async_listen_updates(self._sync_prefs)
        self.async_on_remove(unsubscribe)

    @callback
    def async_get_supported_voices(self, language: str) -> list[Voice] | None:
        """Return the voices known for a language, or None when there are none."""
        voice_names = TTS_VOICES.get(language)
        if not voice_names:
            return None
        # The voice name is used for both Voice arguments.
        return [Voice(name, name) for name in voice_names]

    async def async_get_tts_audio(
        self, message: str, language: str, options: dict[str, Any]
    ) -> TtsAudioType:
        """Request speech audio for a message from Home Assistant Cloud.

        Returns a tuple of the audio output value (as a string) and the audio
        data, or (None, None) when the cloud voice call raises VoiceError.
        """
        try:
            audio = await self.cloud.voice.process_tts(
                text=message,
                language=language,
                gender=options.get(ATTR_GENDER),
                voice=options.get(ATTR_VOICE),
                output=options[ATTR_AUDIO_OUTPUT],
            )
        except VoiceError as err:
            _LOGGER.error("Voice error: %s", err)
            return (None, None)
        else:
            extension = str(options[ATTR_AUDIO_OUTPUT].value)
            return (extension, audio)


class CloudProvider(Provider):
"""NabuCasa Cloud speech API provider."""
"""Home Assistant Cloud speech API provider."""

def __init__(
self, cloud: Cloud[CloudClient], language: str | None, gender: str | None
Expand Down Expand Up @@ -136,7 +229,7 @@ def default_options(self) -> dict[str, Any]:
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load TTS from NabuCasa Cloud."""
"""Load TTS from Home Assistant Cloud."""
# Process TTS
try:
data = await self.cloud.voice.process_tts(
Expand Down
Loading