Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically convert TTS audio to MP3 on demand #102814

Merged
merged 16 commits into from Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions homeassistant/components/assist_pipeline/pipeline.py
Expand Up @@ -971,12 +971,16 @@ async def prepare_text_to_speech(self) -> None:
# pipeline.tts_engine can't be None or this function is not called
engine = cast(str, self.pipeline.tts_engine)

tts_options = {}
tts_options: dict[str, Any] = {}
if self.pipeline.tts_voice is not None:
tts_options[tts.ATTR_VOICE] = self.pipeline.tts_voice

if self.tts_audio_output is not None:
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output
tts_options[tts.ATTR_PREFERRED_FORMAT] = self.tts_audio_output
if self.tts_audio_output == "wav":
# 16 Khz, 16-bit mono
tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = 16000
tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = 1

try:
options_supported = await tts.async_support_options(
balloob marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/cloud/tts.py
Expand Up @@ -150,4 +150,4 @@
_LOGGER.error("Voice error: %s", err)
return (None, None)

return (str(options[ATTR_AUDIO_OUTPUT]), data)
return (str(options[ATTR_AUDIO_OUTPUT].value), data)

Check warning on line 153 in homeassistant/components/cloud/tts.py

View check run for this annotation

Codecov / codecov/patch

homeassistant/components/cloud/tts.py#L153

Added line #L153 was not covered by tests
31 changes: 28 additions & 3 deletions homeassistant/components/esphome/voice_assistant.py
Expand Up @@ -3,9 +3,11 @@

import asyncio
from collections.abc import AsyncIterable, Callable
import io
import logging
import socket
from typing import cast
import wave

from aioesphomeapi import (
VoiceAssistantAudioSettings,
Expand Down Expand Up @@ -88,6 +90,7 @@ def __init__(
self.handle_event = handle_event
self.handle_finished = handle_finished
self._tts_done = asyncio.Event()
self._tts_task: asyncio.Task | None = None

async def start_server(self) -> int:
"""Start accepting connections."""
Expand Down Expand Up @@ -189,7 +192,7 @@ def _event_callback(self, event: PipelineEvent) -> None:

if self.device_info.voice_assistant_version >= 2:
media_id = event.data["tts_output"]["media_id"]
self.hass.async_create_background_task(
self._tts_task = self.hass.async_create_background_task(
self._send_tts(media_id), "esphome_voice_assistant_tts"
)
else:
Expand Down Expand Up @@ -228,7 +231,7 @@ async def run_pipeline(
audio_settings = VoiceAssistantAudioSettings()

tts_audio_output = (
"raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
"wav" if self.device_info.voice_assistant_version >= 2 else "mp3"
)

_LOGGER.debug("Starting pipeline")
Expand Down Expand Up @@ -302,11 +305,32 @@ async def _send_tts(self, media_id: str) -> None:
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
)

_extension, audio_bytes = await tts.async_get_media_source_audio(
extension, data = await tts.async_get_media_source_audio(
self.hass,
media_id,
)

if extension != "wav":
raise ValueError(f"Only WAV audio can be streamed, got {extension}")

with io.BytesIO(data) as wav_io:
with wave.open(wav_io, "rb") as wav_file:
balloob marked this conversation as resolved.
Show resolved Hide resolved
sample_rate = wav_file.getframerate()
sample_width = wav_file.getsampwidth()
sample_channels = wav_file.getnchannels()

if (
(sample_rate != 16000)
or (sample_width != 2)
or (sample_channels != 1)
):
raise ValueError(
"Expected rate/width/channels as 16000/2/1,"
" got {sample_rate}/{sample_width}/{sample_channels}}"
)

audio_bytes = wav_file.readframes(wav_file.getnframes())

_LOGGER.debug("Sending %d bytes of audio", len(audio_bytes))

bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8
Expand All @@ -330,4 +354,5 @@ async def _send_tts(self, media_id: str) -> None:
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
)
self._tts_task = None
self._tts_done.set()
148 changes: 129 additions & 19 deletions homeassistant/components/tts/__init__.py
Expand Up @@ -13,14 +13,16 @@
import mimetypes
import os
import re
import subprocess
import tempfile
from typing import Any, TypedDict, final

from aiohttp import web
import mutagen
from mutagen.id3 import ID3, TextFrame as ID3Text
import voluptuous as vol

from homeassistant.components import websocket_api
from homeassistant.components import ffmpeg, websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE,
Expand Down Expand Up @@ -72,11 +74,15 @@
"async_get_media_source_audio",
"async_support_options",
"ATTR_AUDIO_OUTPUT",
"ATTR_PREFERRED_FORMAT",
"ATTR_PREFERRED_SAMPLE_RATE",
"ATTR_PREFERRED_SAMPLE_CHANNELS",
"CONF_LANG",
"DEFAULT_CACHE_DIR",
"generate_media_source_id",
"PLATFORM_SCHEMA_BASE",
"PLATFORM_SCHEMA",
"SampleFormat",
"Provider",
"TtsAudioType",
"Voice",
Expand All @@ -86,6 +92,9 @@

ATTR_PLATFORM = "platform"
ATTR_AUDIO_OUTPUT = "audio_output"
ATTR_PREFERRED_FORMAT = "preferred_format"
ATTR_PREFERRED_SAMPLE_RATE = "preferred_sample_rate"
ATTR_PREFERRED_SAMPLE_CHANNELS = "preferred_sample_channels"
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
ATTR_VOICE = "voice"

Expand Down Expand Up @@ -199,6 +208,83 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
return languages


async def async_convert_audio(
hass: HomeAssistant,
from_extension: str,
audio_bytes: bytes,
to_extension: str,
to_sample_rate: int | None = None,
to_sample_channels: int | None = None,
) -> bytes:
"""Convert audio to a preferred format using ffmpeg."""
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
return await hass.async_add_executor_job(
lambda: _convert_audio(
ffmpeg_manager.binary,
from_extension,
audio_bytes,
to_extension,
to_sample_rate=to_sample_rate,
to_sample_channels=to_sample_channels,
)
)


def _convert_audio(
ffmpeg_binary: str,
from_extension: str,
audio_bytes: bytes,
to_extension: str,
to_sample_rate: int | None = None,
to_sample_channels: int | None = None,
) -> bytes:
"""Convert audio to a preferred format using ffmpeg."""

# We have to use a temporary file here because some formats like WAV store
# the length of the file in the header, and therefore cannot be written in a
# streaming fashion.
with tempfile.NamedTemporaryFile(
synesthesiam marked this conversation as resolved.
Show resolved Hide resolved
mode="wb+", suffix=f".{to_extension}"
) as output_file:
# input
command = [
ffmpeg_binary,
"-y", # overwrite temp file
"-f",
from_extension,
"-i",
"pipe:", # input from stdin
]

# output
command.extend(["-f", to_extension])

if to_sample_rate is not None:
command.extend(["-ar", str(to_sample_rate)])

if to_sample_channels is not None:
command.extend(["-ac", str(to_sample_channels)])

if to_extension == "mp3":
synesthesiam marked this conversation as resolved.
Show resolved Hide resolved
# Max quality for MP3
command.extend(["-q:a", "0"])

command.append(output_file.name)

with subprocess.Popen(
command, stdin=subprocess.PIPE, stderr=subprocess.PIPE
) as proc:
_stdout, stderr = proc.communicate(input=audio_bytes)
if proc.returncode != 0:
_LOGGER.error(stderr.decode())
raise RuntimeError(
f"Unexpected error while running ffmpeg with arguments: {command}. See log for details."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please break long strings around max 88 characters per line.

)

output_file.seek(0)
return output_file.read()


async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up TTS."""
websocket_api.async_register_command(hass, websocket_list_engines)
Expand Down Expand Up @@ -482,7 +568,18 @@ def process_options(
merged_options = dict(engine_instance.default_options or {})
merged_options.update(options or {})

supported_options = engine_instance.supported_options or []
supported_options = list(engine_instance.supported_options or [])

# ATTR_PREFERRED_* options are always "supported" since they're used to
# convert audio after the TTS has run (if necessary).
supported_options.extend(
(
ATTR_PREFERRED_FORMAT,
ATTR_PREFERRED_SAMPLE_RATE,
ATTR_PREFERRED_SAMPLE_CHANNELS,
)
)

invalid_opts = [
opt_name for opt_name in merged_options if opt_name not in supported_options
]
Expand Down Expand Up @@ -520,12 +617,7 @@ async def async_get_url_path(
# Load speech from engine into memory
else:
filename = await self._async_get_tts_audio(
engine_instance,
cache_key,
message,
use_cache,
language,
options,
engine_instance, cache_key, message, use_cache, language, options
)

return f"/api/tts_proxy/{filename}"
Expand Down Expand Up @@ -590,10 +682,10 @@ async def _async_get_tts_audio(

This method is a coroutine.
"""
if options is not None and ATTR_AUDIO_OUTPUT in options:
expected_extension = options[ATTR_AUDIO_OUTPUT]
else:
expected_extension = None
options = options or {}

# Default to MP3 unless a different format is preferred
final_extension = options.get(ATTR_PREFERRED_FORMAT, "mp3")

async def get_tts_data() -> str:
"""Handle data available."""
Expand All @@ -614,8 +706,27 @@ async def get_tts_data() -> str:
f"No TTS from {engine_instance.name} for '{message}'"
)

# Only convert if we have a preferred format different than the
# expected format from the TTS system, or if a specific sample
# rate/format/channel count is requested.
needs_conversion = (
(final_extension != extension)
or (ATTR_PREFERRED_SAMPLE_RATE in options)
or (ATTR_PREFERRED_SAMPLE_CHANNELS in options)
)

if needs_conversion:
data = await async_convert_audio(
self.hass,
extension,
data,
to_extension=final_extension,
to_sample_rate=options.get(ATTR_PREFERRED_SAMPLE_RATE),
to_sample_channels=options.get(ATTR_PREFERRED_SAMPLE_CHANNELS),
)

# Create file infos
filename = f"{cache_key}.{extension}".lower()
filename = f"{cache_key}.{final_extension}".lower()

# Validate filename
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
Expand All @@ -626,10 +737,11 @@ async def get_tts_data() -> str:
)

# Save to memory
if extension == "mp3":
if final_extension == "mp3":
data = self.write_tags(
filename, data, engine_instance.name, message, language, options
)

self._async_store_to_memcache(cache_key, filename, data)

if cache:
Expand All @@ -641,17 +753,14 @@ async def get_tts_data() -> str:

audio_task = self.hass.async_create_task(get_tts_data())

if expected_extension is None:
return await audio_task

def handle_error(_future: asyncio.Future) -> None:
"""Handle error."""
if audio_task.exception():
self.mem_cache.pop(cache_key, None)

audio_task.add_done_callback(handle_error)

filename = f"{cache_key}.{expected_extension}".lower()
filename = f"{cache_key}.{final_extension}".lower()
self.mem_cache[cache_key] = {
"filename": filename,
"voice": b"",
Expand Down Expand Up @@ -747,11 +856,12 @@ async def async_read_tts(self, filename: str) -> tuple[str | None, bytes]:
raise HomeAssistantError(f"{cache_key} not in cache!")
await self._async_file_to_mem(cache_key)

content, _ = mimetypes.guess_type(filename)
cached = self.mem_cache[cache_key]
if pending := cached.get("pending"):
await pending
cached = self.mem_cache[cache_key]

content, _ = mimetypes.guess_type(filename)
return content, cached["voice"]

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/tts/manifest.json
Expand Up @@ -3,7 +3,7 @@
"name": "Text-to-speech (TTS)",
"after_dependencies": ["media_player"],
"codeowners": ["@home-assistant/core", "@pvizeli"],
"dependencies": ["http"],
"dependencies": ["http", "ffmpeg"],
"documentation": "https://www.home-assistant.io/integrations/tts",
"integration_type": "entity",
"loggers": ["mutagen"],
Expand Down