Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add speech-to-text cooldown for local wake word #108806

Merged
merged 14 commits into from Feb 27, 2024
Merged
2 changes: 2 additions & 0 deletions homeassistant/components/assist_pipeline/__init__.py
Expand Up @@ -83,6 +83,7 @@ async def async_pipeline_from_audio_stream(
event_callback: PipelineEventCallback,
stt_metadata: stt.SpeechMetadata,
stt_stream: AsyncIterable[bytes],
wake_word_phrase: str | None = None,
pipeline_id: str | None = None,
conversation_id: str | None = None,
tts_audio_output: str | None = None,
Expand All @@ -101,6 +102,7 @@ async def async_pipeline_from_audio_stream(
device_id=device_id,
stt_metadata=stt_metadata,
stt_stream=stt_stream,
wake_word_phrase=wake_word_phrase,
run=PipelineRun(
hass,
context=context,
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/assist_pipeline/const.py
Expand Up @@ -10,6 +10,6 @@
CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"

DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds
WAKE_WORD_COOLDOWN = 2 # seconds

EVENT_RECORDING = f"{DOMAIN}_recording"
11 changes: 11 additions & 0 deletions homeassistant/components/assist_pipeline/error.py
Expand Up @@ -38,6 +38,17 @@ class SpeechToTextError(PipelineError):
"""Error in speech-to-text portion of pipeline."""


class DuplicateWakeUpDetectedError(WakeWordDetectionError):
    """Raised when the same wake word wakes several voice assistants at once."""

    def __init__(self, wake_up_phrase: str) -> None:
        """Build the error from the phrase that was detected more than once."""
        code = "duplicate_wake_up_detected"
        message = f"Duplicate wake-up detected for {wake_up_phrase}"
        super().__init__(code, message)


class IntentRecognitionError(PipelineError):
"""Error in intent recognition portion of pipeline."""

Expand Down
46 changes: 36 additions & 10 deletions homeassistant/components/assist_pipeline/pipeline.py
Expand Up @@ -55,10 +55,11 @@
CONF_DEBUG_RECORDING_DIR,
DATA_CONFIG,
DATA_LAST_WAKE_UP,
DEFAULT_WAKE_WORD_COOLDOWN,
DOMAIN,
WAKE_WORD_COOLDOWN,
)
from .error import (
DuplicateWakeUpDetectedError,
IntentRecognitionError,
PipelineError,
PipelineNotFound,
Expand Down Expand Up @@ -453,9 +454,6 @@ class WakeWordSettings:
audio_seconds_to_buffer: float = 0
"""Seconds of audio to buffer before detection and forward to STT."""

cooldown_seconds: float = DEFAULT_WAKE_WORD_COOLDOWN
"""Seconds after a wake word detection where other detections are ignored."""


@dataclass(frozen=True)
class AudioSettings:
Expand Down Expand Up @@ -742,16 +740,22 @@ async def wake_word_detection(
wake_word_output: dict[str, Any] = {}
else:
# Avoid duplicate detections by checking cooldown
wake_up_key = f"{self.wake_word_entity_id}.{result.wake_word_id}"
last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get(wake_up_key)
last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get(
result.wake_word_phrase
)
if last_wake_up is not None:
sec_since_last_wake_up = time.monotonic() - last_wake_up
if sec_since_last_wake_up < wake_word_settings.cooldown_seconds:
_LOGGER.debug("Duplicate wake word detection occurred")
raise WakeWordDetectionAborted
if sec_since_last_wake_up < WAKE_WORD_COOLDOWN:
_LOGGER.debug(
"Duplicate wake word detection occurred for %s",
result.wake_word_phrase,
)
raise DuplicateWakeUpDetectedError(result.wake_word_phrase)

# Record last wake up time to block duplicate detections
self.hass.data[DATA_LAST_WAKE_UP][wake_up_key] = time.monotonic()
self.hass.data[DATA_LAST_WAKE_UP][
result.wake_word_phrase
] = time.monotonic()

if result.queued_audio:
# Add audio that was pending at detection.
Expand Down Expand Up @@ -1308,6 +1312,9 @@ class PipelineInput:
stt_stream: AsyncIterable[bytes] | None = None
"""Input audio for stt. Required when start_stage = stt."""

wake_word_phrase: str | None = None
"""Optional key used to de-duplicate wake-ups for local wake word detection."""

intent_input: str | None = None
"""Input for conversation agent. Required when start_stage = intent."""

Expand Down Expand Up @@ -1352,6 +1359,25 @@ async def execute(self) -> None:
assert self.stt_metadata is not None
assert stt_processed_stream is not None

if self.wake_word_phrase is not None:
# Avoid duplicate wake-ups by checking cooldown
last_wake_up = self.run.hass.data[DATA_LAST_WAKE_UP].get(
self.wake_word_phrase
)
if last_wake_up is not None:
sec_since_last_wake_up = time.monotonic() - last_wake_up
if sec_since_last_wake_up < WAKE_WORD_COOLDOWN:
_LOGGER.debug(
"Speech-to-text cancelled to avoid duplicate wake-up for %s",
self.wake_word_phrase,
)
raise DuplicateWakeUpDetectedError(self.wake_word_phrase)

# Record last wake up time to block duplicate detections
self.run.hass.data[DATA_LAST_WAKE_UP][
self.wake_word_phrase
] = time.monotonic()

stt_input_stream = stt_processed_stream

if stt_audio_buffer:
Expand Down
11 changes: 10 additions & 1 deletion homeassistant/components/assist_pipeline/websocket_api.py
Expand Up @@ -97,7 +97,12 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
extra=vol.ALLOW_EXTRA,
),
PipelineStage.STT: vol.Schema(
{vol.Required("input"): {vol.Required("sample_rate"): int}},
{
vol.Required("input"): {
vol.Required("sample_rate"): int,
vol.Optional("wake_word_phrase"): str,
}
},
extra=vol.ALLOW_EXTRA,
),
PipelineStage.INTENT: vol.Schema(
Expand Down Expand Up @@ -149,12 +154,15 @@ async def websocket_run(
msg_input = msg["input"]
audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
incoming_sample_rate = msg_input["sample_rate"]
wake_word_phrase: str | None = None

if start_stage == PipelineStage.WAKE_WORD:
wake_word_settings = WakeWordSettings(
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0),
)
elif start_stage == PipelineStage.STT:
wake_word_phrase = msg["input"].get("wake_word_phrase")

async def stt_stream() -> AsyncGenerator[bytes, None]:
state = None
Expand Down Expand Up @@ -189,6 +197,7 @@ def handle_binary(
channel=stt.AudioChannels.CHANNEL_MONO,
)
input_args["stt_stream"] = stt_stream()
input_args["wake_word_phrase"] = wake_word_phrase

# Audio settings
audio_settings = AudioSettings(
Expand Down
9 changes: 9 additions & 0 deletions homeassistant/components/wake_word/models.py
Expand Up @@ -7,7 +7,13 @@ class WakeWord:
"""Wake word model."""

id: str
"""Id of wake word model"""

name: str
"""Name of wake word model"""

phrase: str | None = None
"""Wake word phrase used to trigger model"""


@dataclass
Expand All @@ -17,6 +23,9 @@ class DetectionResult:
wake_word_id: str
"""Id of detected wake word"""

wake_word_phrase: str
"""Normalized phrase for the detected wake word"""

timestamp: int | None
"""Timestamp of audio chunk with detected wake word"""

Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/wyoming/manifest.json
Expand Up @@ -6,6 +6,6 @@
"dependencies": ["assist_pipeline"],
"documentation": "https://www.home-assistant.io/integrations/wyoming",
"iot_class": "local_push",
"requirements": ["wyoming==1.5.2"],
"requirements": ["wyoming==1.5.3"],
"zeroconf": ["_wyoming._tcp.local."]
}
47 changes: 44 additions & 3 deletions homeassistant/components/wyoming/satellite.py
@@ -1,4 +1,5 @@
"""Support for Wyoming satellite services."""

import asyncio
from collections.abc import AsyncGenerator
import io
Expand All @@ -10,6 +11,7 @@
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient
from wyoming.error import Error
from wyoming.info import Describe, Info
from wyoming.ping import Ping, Pong
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import PauseSatellite, RunSatellite
Expand Down Expand Up @@ -86,7 +88,9 @@ async def run(self) -> None:
await self._connect_and_loop()
except asyncio.CancelledError:
raise # don't restart
except Exception: # pylint: disable=broad-exception-caught
except Exception as err: # pylint: disable=broad-exception-caught
_LOGGER.debug("%s: %s", err.__class__.__name__, str(err))

# Ensure sensor is off (before restart)
self.device.set_is_active(False)

Expand Down Expand Up @@ -197,6 +201,8 @@ async def _connect_and_loop(self) -> None:
async def _run_pipeline_loop(self) -> None:
"""Run a pipeline one or more times."""
assert self._client is not None
client_info: Info | None = None
wake_word_phrase: str | None = None
run_pipeline: RunPipeline | None = None
send_ping = True

Expand All @@ -209,6 +215,9 @@ async def _run_pipeline_loop(self) -> None:
)
pending = {pipeline_ended_task, client_event_task}

# Update info from satellite
await self._client.write_event(Describe().event())

while self.is_running and (not self.device.is_muted):
if send_ping:
# Ensure satellite is still connected
Expand All @@ -230,6 +239,9 @@ async def _run_pipeline_loop(self) -> None:
)
pending.add(pipeline_ended_task)

# Clear last wake word detection
wake_word_phrase = None

if (run_pipeline is not None) and run_pipeline.restart_on_end:
# Automatically restart pipeline.
# Used with "always on" streaming satellites.
Expand All @@ -253,7 +265,7 @@ async def _run_pipeline_loop(self) -> None:
elif RunPipeline.is_type(client_event.type):
# Satellite requested pipeline run
run_pipeline = RunPipeline.from_event(client_event)
self._run_pipeline_once(run_pipeline)
self._run_pipeline_once(run_pipeline, wake_word_phrase)
elif (
AudioChunk.is_type(client_event.type) and self._is_pipeline_running
):
Expand All @@ -265,6 +277,32 @@ async def _run_pipeline_loop(self) -> None:
# Stop pipeline
_LOGGER.debug("Client requested pipeline to stop")
self._audio_queue.put_nowait(b"")
elif Info.is_type(client_event.type):
client_info = Info.from_event(client_event)
_LOGGER.debug("Updated client info: %s", client_info)
elif Detection.is_type(client_event.type):
detection = Detection.from_event(client_event)
wake_word_phrase = detection.name

# Resolve wake word name/id to phrase if info is available.
#
# This allows us to deconflict multiple satellite wake-ups
# with the same wake word.
if (client_info is not None) and (client_info.wake is not None):
found_phrase = False
for wake_service in client_info.wake:
for wake_model in wake_service.models:
if wake_model.name == detection.name:
wake_word_phrase = (
wake_model.phrase or wake_model.name
)
found_phrase = True
break

if found_phrase:
break

_LOGGER.debug("Client detected wake word: %s", wake_word_phrase)
else:
_LOGGER.debug("Unexpected event from satellite: %s", client_event)

Expand All @@ -274,7 +312,9 @@ async def _run_pipeline_loop(self) -> None:
)
pending.add(client_event_task)

def _run_pipeline_once(self, run_pipeline: RunPipeline) -> None:
def _run_pipeline_once(
self, run_pipeline: RunPipeline, wake_word_phrase: str | None = None
) -> None:
"""Run a pipeline once."""
_LOGGER.debug("Received run information: %s", run_pipeline)

Expand Down Expand Up @@ -332,6 +372,7 @@ def _run_pipeline_once(self, run_pipeline: RunPipeline) -> None:
volume_multiplier=self.device.volume_multiplier,
),
device_id=self.device.device_id,
wake_word_phrase=wake_word_phrase,
),
name="wyoming satellite pipeline",
)
Expand Down
23 changes: 21 additions & 2 deletions homeassistant/components/wyoming/wake_word.py
@@ -1,4 +1,5 @@
"""Support for Wyoming wake-word-detection services."""

import asyncio
from collections.abc import AsyncIterable
import logging
Expand Down Expand Up @@ -49,7 +50,9 @@
wake_service = service.info.wake[0]

self._supported_wake_words = [
wake_word.WakeWord(id=ww.name, name=ww.description or ww.name)
wake_word.WakeWord(
id=ww.name, name=ww.description or ww.name, phrase=ww.phrase
)
for ww in wake_service.models
]
self._attr_name = wake_service.name
Expand All @@ -64,7 +67,11 @@
if info is not None:
wake_service = info.wake[0]
self._supported_wake_words = [
wake_word.WakeWord(id=ww.name, name=ww.description or ww.name)
wake_word.WakeWord(
id=ww.name,
name=ww.description or ww.name,
phrase=ww.phrase,
)
for ww in wake_service.models
]

Expand Down Expand Up @@ -140,6 +147,7 @@

return wake_word.DetectionResult(
wake_word_id=detection.name,
wake_word_phrase=self._get_phrase(detection.name),
timestamp=detection.timestamp,
queued_audio=queued_audio,
)
Expand Down Expand Up @@ -183,3 +191,14 @@
_LOGGER.exception("Error processing audio stream: %s", err)

return None

def _get_phrase(self, model_id: str) -> str:
    """Get wake word phrase for model id.

    Returns the phrase of the first supported wake word whose id equals
    ``model_id`` and whose phrase is non-empty. Falls back to ``model_id``
    itself when no such wake word exists.
    """
    # Models without a phrase are skipped rather than matched, so a later
    # entry with the same id and a real phrase can still be found. Codecov
    # flagged that skip path as not covered by tests when it was added.
    return next(
        (
            ww_model.phrase
            for ww_model in self._supported_wake_words
            if ww_model.phrase and ww_model.id == model_id
        ),
        model_id,
    )
2 changes: 1 addition & 1 deletion requirements_all.txt
Expand Up @@ -2863,7 +2863,7 @@ wled==0.17.0
wolf-comm==0.0.4

# homeassistant.components.wyoming
wyoming==1.5.2
wyoming==1.5.3

# homeassistant.components.xbox
xbox-webapi==2.0.11
Expand Down
2 changes: 1 addition & 1 deletion requirements_test_all.txt
Expand Up @@ -2195,7 +2195,7 @@ wled==0.17.0
wolf-comm==0.0.4

# homeassistant.components.wyoming
wyoming==1.5.2
wyoming==1.5.3

# homeassistant.components.xbox
xbox-webapi==2.0.11
Expand Down