Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't return TTS URL in Assist pipeline #105164

Merged
merged 2 commits into from Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 5 additions & 5 deletions homeassistant/components/assist_pipeline/pipeline.py
Expand Up @@ -9,7 +9,7 @@
from enum import StrEnum
import logging
from pathlib import Path
from queue import Queue
from queue import Empty, Queue
from threading import Thread
import time
from typing import TYPE_CHECKING, Any, Final, cast
Expand Down Expand Up @@ -1010,8 +1010,8 @@ async def prepare_text_to_speech(self) -> None:
self.tts_engine = engine
self.tts_options = tts_options

async def text_to_speech(self, tts_input: str) -> str:
"""Run text-to-speech portion of pipeline. Returns URL of TTS audio."""
async def text_to_speech(self, tts_input: str) -> None:
"""Run text-to-speech portion of pipeline."""
self.process_event(
PipelineEvent(
PipelineEventType.TTS_START,
Expand Down Expand Up @@ -1058,8 +1058,6 @@ async def text_to_speech(self, tts_input: str) -> str:
PipelineEvent(PipelineEventType.TTS_END, {"tts_output": tts_output})
)

return tts_media.url

def _capture_chunk(self, audio_bytes: bytes | None) -> None:
"""Forward audio chunk to various capturing mechanisms."""
if self.debug_recording_queue is not None:
Expand Down Expand Up @@ -1246,6 +1244,8 @@ def _pipeline_debug_recording_thread_proc(
# Chunk of 16-bit mono audio at 16Khz
if wav_writer is not None:
wav_writer.writeframes(message)
except Empty:
pass # occurs when pipeline has unexpected error
except Exception: # pylint: disable=broad-exception-caught
_LOGGER.exception("Unexpected error in debug recording thread")
finally:
Expand Down
64 changes: 64 additions & 0 deletions tests/components/assist_pipeline/test_init.py
@@ -1,4 +1,5 @@
"""Test Voice Assistant init."""
import asyncio
from dataclasses import asdict
import itertools as it
from pathlib import Path
Expand Down Expand Up @@ -569,6 +570,69 @@ async def audio_data():
)


async def test_pipeline_saved_audio_empty_queue(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_supporting_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test that saved audio thread closes WAV file even if there's an empty queue."""
with tempfile.TemporaryDirectory() as temp_dir_str:
# Enable audio recording to temporary directory
temp_dir = Path(temp_dir_str)
assert await async_setup_component(
hass,
DOMAIN,
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
)

def event_callback(event: assist_pipeline.PipelineEvent):
if event.type == "run-end":
# Verify WAV file exists, but contains no data
pipeline_dirs = list(temp_dir.iterdir())
run_dirs = list(pipeline_dirs[0].iterdir())
wav_path = next(run_dirs[0].iterdir())
with wave.open(str(wav_path), "rb") as wav_file:
assert wav_file.getnframes() == 0

async def audio_data():
# Force timeout in _pipeline_debug_recording_thread_proc
await asyncio.sleep(1)
yield b"not used"

# Wrap original function to time out immediately
_pipeline_debug_recording_thread_proc = (
assist_pipeline.pipeline._pipeline_debug_recording_thread_proc
)

def proc_wrapper(run_recording_dir, queue):
_pipeline_debug_recording_thread_proc(
run_recording_dir, queue, message_timeout=0
)

with patch(
"homeassistant.components.assist_pipeline.pipeline._pipeline_debug_recording_thread_proc",
proc_wrapper,
):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=event_callback,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
end_stage=assist_pipeline.PipelineStage.STT,
)


async def test_wake_word_detection_aborted(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
Expand Down