Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add websocket command to capture audio from a device #103936

Merged
merged 7 commits into from Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 8 additions & 1 deletion homeassistant/components/assist_pipeline/__init__.py
Expand Up @@ -9,7 +9,13 @@
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers.typing import ConfigType

from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DATA_LAST_WAKE_UP, DOMAIN
from .const import (
CONF_DEBUG_RECORDING_DIR,
DATA_CONFIG,
DATA_LAST_WAKE_UP,
DOMAIN,
EVENT_RECORDING,
)
from .error import PipelineNotFound
from .pipeline import (
AudioSettings,
Expand Down Expand Up @@ -40,6 +46,7 @@
"PipelineEventType",
"PipelineNotFound",
"WakeWordSettings",
"EVENT_RECORDING",
)

CONFIG_SCHEMA = vol.Schema(
Expand Down
2 changes: 2 additions & 0 deletions homeassistant/components/assist_pipeline/const.py
Expand Up @@ -11,3 +11,5 @@

DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds

EVENT_RECORDING = f"{DOMAIN}_recording"
40 changes: 40 additions & 0 deletions homeassistant/components/assist_pipeline/logbook.py
@@ -0,0 +1,40 @@
"""Describe assist_pipeline logbook events."""
from __future__ import annotations

from collections.abc import Callable

from homeassistant.components.logbook import LOGBOOK_ENTRY_MESSAGE, LOGBOOK_ENTRY_NAME
from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS
from homeassistant.core import Event, HomeAssistant, callback
import homeassistant.helpers.device_registry as dr

from .const import DOMAIN, EVENT_RECORDING


@callback
def async_describe_events(
hass: HomeAssistant,
async_describe_event: Callable[[str, str, Callable[[Event], dict[str, str]]], None],
) -> None:
"""Describe logbook events."""
device_registry = dr.async_get(hass)

@callback
def async_describe_logbook_event(event: Event) -> dict[str, str]:
"""Describe logbook event."""
device: dr.DeviceEntry | None = None
device_name: str = "Unknown device"

device = device_registry.devices[event.data[ATTR_DEVICE_ID]]
if device:
device_name = device.name_by_user or device.name or "Unknown device"

timeout_seconds = event.data[ATTR_SECONDS]
message = f"{device_name} will record audio for {timeout_seconds} second(s)"
synesthesiam marked this conversation as resolved.
Show resolved Hide resolved

return {
LOGBOOK_ENTRY_NAME: device_name,
LOGBOOK_ENTRY_MESSAGE: message,
}

async_describe_event(DOMAIN, EVENT_RECORDING, async_describe_logbook_event)
62 changes: 51 additions & 11 deletions homeassistant/components/assist_pipeline/pipeline.py
Expand Up @@ -503,6 +503,9 @@ class PipelineRun:
audio_processor_buffer: AudioBuffer = field(init=False, repr=False)
"""Buffer used when splitting audio into chunks for audio processing"""

_device_id: str | None = None
"""Optional device id set during run start."""

def __post_init__(self) -> None:
"""Set language for pipeline."""
self.language = self.pipeline.language or self.hass.config.language
Expand Down Expand Up @@ -554,7 +557,8 @@ def process_event(self, event: PipelineEvent) -> None:

def start(self, device_id: str | None) -> None:
"""Emit run start event."""
self._start_debug_recording_thread(device_id)
self._device_id = device_id
self._start_debug_recording_thread()

data = {
"pipeline": self.pipeline.id,
Expand All @@ -567,6 +571,9 @@ def start(self, device_id: str | None) -> None:

async def end(self) -> None:
"""Emit run end event."""
# Signal end of stream to listeners
self._capture_chunk(None)

# Stop the recording thread before emitting run-end.
# This ensures that files are properly closed if the event handler reads them.
await self._stop_debug_recording_thread()
Expand Down Expand Up @@ -746,9 +753,7 @@ async def _wake_word_audio_stream(
if self.abort_wake_word_detection:
raise WakeWordDetectionAborted

if self.debug_recording_queue is not None:
self.debug_recording_queue.put_nowait(chunk.audio)

self._capture_chunk(chunk.audio)
yield chunk.audio, chunk.timestamp_ms

# Wake-word-detection occurs *after* the wake word was actually
Expand Down Expand Up @@ -870,8 +875,7 @@ async def _speech_to_text_stream(
chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
sent_vad_start = False
async for chunk in audio_stream:
if self.debug_recording_queue is not None:
self.debug_recording_queue.put_nowait(chunk.audio)
self._capture_chunk(chunk.audio)

if stt_vad is not None:
if not stt_vad.process(chunk_seconds, chunk.is_speech):
Expand Down Expand Up @@ -1057,7 +1061,28 @@ async def text_to_speech(self, tts_input: str) -> str:

return tts_media.url

def _start_debug_recording_thread(self, device_id: str | None) -> None:
def _capture_chunk(self, audio_bytes: bytes | None) -> None:
"""Forward audio chunk to various capturing mechanisms."""
if self.debug_recording_queue is not None:
# Forward to debug WAV file recording
self.debug_recording_queue.put_nowait(audio_bytes)

if self._device_id is None:
return

# Forward to device audio capture
pipeline_data: PipelineData = self.hass.data[DOMAIN]
audio_queue = pipeline_data.device_audio_queues.get(self._device_id)
if audio_queue is None:
return

try:
audio_queue.queue.put_nowait(audio_bytes)
except asyncio.QueueFull:
audio_queue.overflow = True
_LOGGER.warning("Audio queue full for device %s", self._device_id)

def _start_debug_recording_thread(self) -> None:
"""Start thread to record wake/stt audio if debug_recording_dir is set."""
if self.debug_recording_thread is not None:
# Already started
Expand All @@ -1068,7 +1093,7 @@ def _start_debug_recording_thread(self, device_id: str | None) -> None:
if debug_recording_dir := self.hass.data[DATA_CONFIG].get(
CONF_DEBUG_RECORDING_DIR
):
if device_id is None:
if self._device_id is None:
# <debug_recording_dir>/<pipeline.name>/<run.id>
run_recording_dir = (
Path(debug_recording_dir)
Expand All @@ -1079,7 +1104,7 @@ def _start_debug_recording_thread(self, device_id: str | None) -> None:
# <debug_recording_dir>/<device_id>/<pipeline.name>/<run.id>
run_recording_dir = (
Path(debug_recording_dir)
/ device_id
/ self._device_id
/ self.pipeline.name
/ str(time.monotonic_ns())
)
Expand All @@ -1100,8 +1125,8 @@ async def _stop_debug_recording_thread(self) -> None:
# Not running
return

# Signal thread to stop gracefully
self.debug_recording_queue.put(None)
# NOTE: Expecting a None to have been put in self.debug_recording_queue
# in self.end() to signal the thread to stop.

# Wait until the thread has finished to ensure that files are fully written
await self.hass.async_add_executor_job(self.debug_recording_thread.join)
Expand Down Expand Up @@ -1632,6 +1657,20 @@ async def _change_listener(
pipeline_run.abort_wake_word_detection = True


@dataclass
class DeviceAudioQueue:
"""Audio capture queue for a satellite device."""

queue: asyncio.Queue[bytes | None]
"""Queue of audio chunks (None = stop signal)"""

id: str = field(default_factory=ulid_util.ulid)
"""Unique id to ensure the correct audio queue is cleaned up in websocket API."""

overflow: bool = False
"""Flag to be set if audio samples were dropped because the queue was full."""


class PipelineData:
"""Store and debug data stored in hass.data."""

Expand All @@ -1641,6 +1680,7 @@ def __init__(self, pipeline_store: PipelineStorageCollection) -> None:
self.pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]] = {}
self.pipeline_devices: set[str] = set()
self.pipeline_runs = PipelineRuns(pipeline_store)
self.device_audio_queues: dict[str, DeviceAudioQueue] = {}


@dataclass
Expand Down
119 changes: 116 additions & 3 deletions homeassistant/components/assist_pipeline/websocket_api.py
Expand Up @@ -3,22 +3,31 @@

# Suppressing disable=deprecated-module is needed for Python 3.11
import audioop # pylint: disable=deprecated-module
import base64
from collections.abc import AsyncGenerator, Callable
import contextlib
import logging
from typing import Any
import math
from typing import Any, Final

import voluptuous as vol

from homeassistant.components import conversation, stt, tts, websocket_api
from homeassistant.const import MATCH_ALL
from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.util import language as language_util

from .const import DEFAULT_PIPELINE_TIMEOUT, DEFAULT_WAKE_WORD_TIMEOUT, DOMAIN
from .const import (
DEFAULT_PIPELINE_TIMEOUT,
DEFAULT_WAKE_WORD_TIMEOUT,
DOMAIN,
EVENT_RECORDING,
)
from .error import PipelineNotFound
from .pipeline import (
AudioSettings,
DeviceAudioQueue,
PipelineData,
PipelineError,
PipelineEvent,
Expand All @@ -32,6 +41,11 @@

_LOGGER = logging.getLogger(__name__)

CAPTURE_RATE: Final = 16000
CAPTURE_WIDTH: Final = 2
CAPTURE_CHANNELS: Final = 1
MAX_CAPTURE_TIMEOUT: Final = 60.0


@callback
def async_register_websocket_api(hass: HomeAssistant) -> None:
Expand All @@ -40,6 +54,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, websocket_list_languages)
websocket_api.async_register_command(hass, websocket_list_runs)
websocket_api.async_register_command(hass, websocket_get_run)
websocket_api.async_register_command(hass, websocket_device_capture)


@websocket_api.websocket_command(
Expand Down Expand Up @@ -371,3 +386,101 @@ async def websocket_list_languages(
else pipeline_languages
},
)


@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_pipeline/device/capture",
vol.Required("device_id"): str,
vol.Required("timeout"): vol.All(
# 0 < timeout <= MAX_CAPTURE_TIMEOUT
vol.Coerce(float),
vol.Range(min=0, min_included=False, max=MAX_CAPTURE_TIMEOUT),
),
}
)
@websocket_api.async_response
async def websocket_device_capture(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Capture raw audio from a satellite device and forward to client."""
pipeline_data: PipelineData = hass.data[DOMAIN]
device_id = msg["device_id"]

# Number of seconds to record audio in wall clock time
timeout_seconds = msg["timeout"]

# We don't know the chunk size, so the upper bound is calculated assuming a
# single sample (16 bits) per queue item.
max_queue_items = (
# +1 for None to signal end
int(math.ceil(timeout_seconds * CAPTURE_RATE))
+ 1
)

audio_queue = DeviceAudioQueue(queue=asyncio.Queue(maxsize=max_queue_items))

# Running simultaneous captures for a single device will not work by design.
# The new capture will cause the old capture to stop.
if (
old_audio_queue := pipeline_data.device_audio_queues.pop(device_id, None)
) is not None:
with contextlib.suppress(asyncio.QueueFull):
# Signal other websocket command that we're taking over
old_audio_queue.queue.put_nowait(None)

# Only one client can be capturing audio at a time
pipeline_data.device_audio_queues[device_id] = audio_queue
synesthesiam marked this conversation as resolved.
Show resolved Hide resolved

def clean_up_queue() -> None:
# Clean up our audio queue
maybe_audio_queue = pipeline_data.device_audio_queues.get(device_id)
if (maybe_audio_queue is not None) and (maybe_audio_queue.id == audio_queue.id):
# Only pop if this is our queue
pipeline_data.device_audio_queues.pop(device_id)

# Unsubscribe cleans up queue
connection.subscriptions[msg["id"]] = clean_up_queue

# Audio will follow as events
connection.send_result(msg["id"])

# Record to logbook
hass.bus.async_fire(
EVENT_RECORDING,
{
ATTR_DEVICE_ID: device_id,
ATTR_SECONDS: timeout_seconds,
},
)

try:
with contextlib.suppress(asyncio.TimeoutError):
async with asyncio.timeout(timeout_seconds):
while True:
# Send audio chunks encoded as base64
audio_bytes = await audio_queue.queue.get()
if audio_bytes is None:
# Signal to stop
break

connection.send_event(
msg["id"],
{
"type": "audio",
"rate": CAPTURE_RATE, # hertz
"width": CAPTURE_WIDTH, # bytes
"channels": CAPTURE_CHANNELS,
"audio": base64.b64encode(audio_bytes).decode("ascii"),
},
)

# Capture has ended
connection.send_event(
msg["id"], {"type": "end", "overflow": audio_queue.overflow}
)
finally:
clean_up_queue()