Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions livekit-agents/livekit/agents/inference/interruption.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
aio,
http_context,
is_given,
log_exceptions,
shortuuid,
)
from ._utils import (
Expand Down Expand Up @@ -470,7 +469,6 @@ def __init__(
@abstractmethod
async def _run(self) -> None: ...

@log_exceptions(logger=logger)
async def _main_task(self) -> None:
max_retries = self._conn_options.max_retry

Expand Down Expand Up @@ -760,7 +758,6 @@ async def _send_task(input_ch: aio.Chan[npt.NDArray[np.int16]]) -> None:
finally:
await aio.cancel_and_wait(*tasks)

@log_exceptions(logger=logger)
async def predict(self, waveform: np.ndarray) -> InterruptionResponse:
created_at = perf_counter_ns()
try:
Expand Down
23 changes: 12 additions & 11 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,15 +1497,9 @@ def _on_error(
error_event = ErrorEvent(error=error, source=self.tts)
self._session.emit("error", error_event)
elif isinstance(error, inference.InterruptionDetectionError):
error_event = ErrorEvent(error=error, source=self._interruption_detector)
self._session.emit("error", error_event)

if not error.recoverable:
# redundant no op, but keeping it for clarity
self._session._on_error(error)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also remove the `inference.InterruptionDetectionError` annotation from `session._on_error`?


self._fallback_to_vad_interruption()
return
self._fallback_to_vad_interruption(error)
return

self._session._on_error(error)

Expand Down Expand Up @@ -3592,7 +3586,9 @@ def _restore_interruption_by_audio_activity(self) -> None:
self._default_interruption_by_audio_activity_enabled
)

def _fallback_to_vad_interruption(self) -> None:
def _fallback_to_vad_interruption(
self, error: inference.InterruptionDetectionError | None = None
) -> None:
"""Degrade gracefully from adaptive interruption to VAD-based interruption.

Called when the adaptive interruption detector encounters an unrecoverable error.
Expand All @@ -3611,11 +3607,16 @@ def _fallback_to_vad_interruption(self) -> None:
self._interruption_detector.off("overlapping_speech", self._on_overlap_speech_ended)

if self._audio_recognition:
# this also releases any held transcripts
self._audio_recognition.update_interruption_detection(None)

logger.warning(
logger.info(
"adaptive interruption disabled due to unrecoverable error, "
"falling back to VAD-based interruption"
"falling back to VAD-based interruption",
extra={
"error": str(error.error) if error is not None else None,
"label": error.label if error is not None else None,
},
)

def _init_metrics_from_end_of_turn(self, info: _EndOfTurnInfo) -> llm.MetricsReport:
Expand Down
11 changes: 1 addition & 10 deletions livekit-agents/livekit/agents/voice/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,12 +1370,7 @@ async def _update_activity_task(
await self._update_activity(agent)

def _on_error(
self,
error: llm.LLMError
| stt.STTError
| tts.TTSError
| llm.RealtimeModelError
| inference.InterruptionDetectionError,
self, error: llm.LLMError | stt.STTError | tts.TTSError | llm.RealtimeModelError
) -> None:
if self._closing_task or error.recoverable:
return
Expand All @@ -1388,10 +1383,6 @@ def _on_error(
self._tts_error_counts += 1
if self._tts_error_counts <= self.conn_options.max_unrecoverable_errors:
return
elif error.type == "interruption_detection_error":
# interruption detection errors are handled by AgentActivity via VAD fallback,
# they should never close the session
return

if isinstance(error.error, APIError):
logger.error(f"AgentSession is closing due to unrecoverable error: {error.error}")
Expand Down
36 changes: 30 additions & 6 deletions livekit-agents/livekit/agents/voice/audio_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from livekit import rtc

from .. import inference, llm, stt, utils, vad
from .._exceptions import APIError
from ..inference.interruption import (
_AgentSpeechEndedSentinel,
_AgentSpeechStartedSentinel,
Expand Down Expand Up @@ -390,15 +391,33 @@ def on_end_of_overlap_speech(
_OverlapSpeechEndedSentinel(ended_at=ended_at or time.time())
)

async def _flush_held_transcripts(self, cooldown: float) -> None:
"""Flush held transcripts whose *end time* is after the ignore_user_transcript_until - cooldown timestamp.
@utils.log_exceptions(logger=logger)
async def _flush_held_transcripts(self, cooldown: float, force: bool = False) -> None:
"""Flush held transcripts.

When ``force`` is True, all buffered events are emitted unconditionally; this
is used during interruption-detector teardown when the ignore-window gating
can no longer be trusted.

If the event has no timestamps, we assume it is the same as the next valid event.
Otherwise, drop transcripts whose *end time* falls before
``ignore_user_transcript_until - cooldown`` and re-emit the rest. Events
without timestamps are treated as the next valid event.
"""
if not self._transcript_buffer:
self._reset_interruption_detection()
return

if force:
events_to_emit = list(self._transcript_buffer)
# reset before emitting to avoid recursive calls
self._reset_interruption_detection()
for ev in events_to_emit:
await self._on_stt_event(ev)
return

if (
not self._interruption_enabled
or not is_given(self._ignore_user_transcript_until)
or not self._transcript_buffer
or self._input_started_at is None
):
self._reset_interruption_detection()
Expand Down Expand Up @@ -428,14 +447,13 @@ async def _flush_held_transcripts(self, cooldown: float) -> None:
should_flush = True
break

# extract events to emit and reset BEFORE iterating
# to prevent recursive calls
events_to_emit = (
list(self._transcript_buffer)[int(emit_from_index) :]
if emit_from_index is not None and should_flush
else []
)
_ignore_user_transcript_until = self._ignore_user_transcript_until
# reset before emitting to avoid recursive calls
self._reset_interruption_detection()

for ev in events_to_emit:
Expand Down Expand Up @@ -643,6 +661,9 @@ def update_interruption_detection(
self._interruption_atask = None
self._interruption_ch = None
self._cancel_backchannel_boundary()
flush_task = asyncio.create_task(self._flush_held_transcripts(cooldown=0.0, force=True))
flush_task.add_done_callback(lambda _: self._tasks.discard(flush_task))
self._tasks.add(flush_task)

self._interruption_enabled = (
self._interruption_detection is not None and self._vad is not None
Expand Down Expand Up @@ -1259,6 +1280,9 @@ async def _forward() -> None:
try:
async for ev in stream:
await self._on_overlap_speech_event(ev)
except APIError:
# avoid already emitted error from the stream
return
finally:
await aio.cancel_and_wait(forward_task)
await stream.aclose()
Expand Down
129 changes: 128 additions & 1 deletion tests/test_agent_session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import asyncio
import logging
import time
from unittest.mock import MagicMock, Mock

import pytest

Expand All @@ -16,13 +18,15 @@
UserStateChangedEvent,
function_tool,
inference,
vad,
)
from livekit.agents.llm import (
FunctionToolCall,
)
from livekit.agents.llm.chat_context import ChatContext, ChatMessage
from livekit.agents.stt import SpeechData, SpeechEvent, SpeechEventType
from livekit.agents.utils import aio
from livekit.agents.voice.agent_activity import AgentActivity
from livekit.agents.voice.audio_recognition import AudioRecognition, _EndOfTurnInfo
from livekit.agents.voice.endpointing import BaseEndpointing
from livekit.agents.voice.events import FunctionToolsExecutedEvent
Expand Down Expand Up @@ -790,6 +794,128 @@ async def test_backchannel_boundary_releases_end_boundary_transcript() -> None:
await _close_test_session(session)


async def test_interruption_detection_error_is_not_session_error() -> None:
    """Interruption-detection errors must not be emitted as session ``error`` events.

    Feeds one recoverable and one unrecoverable ``InterruptionDetectionError``
    through ``AgentActivity._on_error`` and checks that no ``"error"`` event is
    emitted on the session, and that the (mocked) VAD fallback is invoked
    exactly once, with the unrecoverable error.
    """
    session = create_session(FakeActions())
    activity = AgentActivity(MyAgent(), session)

    # Replace the real fallback so only the dispatch in _on_error is observed.
    vad_fallback = Mock()
    activity._fallback_to_vad_interruption = vad_fallback

    emitted: list[object] = []
    session.on("error", emitted.append)

    def _detector_error(message: str, *, recoverable: bool) -> inference.InterruptionDetectionError:
        return inference.InterruptionDetectionError(
            label="test",
            error=RuntimeError(message),
            recoverable=recoverable,
        )

    try:
        transient = _detector_error("temporary failure", recoverable=True)
        activity._on_error(transient)

        fatal = _detector_error("adaptive unavailable", recoverable=False)
        activity._on_error(fatal)

        assert emitted == []
        vad_fallback.assert_called_once_with(fatal)
    finally:
        await _close_test_session(session)


async def test_vad_fallback_uses_next_vad_inference_event(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """After falling back to VAD, the next qualifying VAD event interrupts speech.

    Calls ``_fallback_to_vad_interruption`` on an activity whose adaptive
    detection is enabled, then checks that:

    * the audio recognition has its interruption detection cleared (``None``),
    * the activity flags flip to VAD-based interruption,
    * a VAD inference event just under ``min_duration`` does not interrupt,
      while one at ``min_duration`` does,
    * the fallback is logged at INFO and nothing is logged at WARNING or above.
    """
    session = create_session(FakeActions())
    activity = AgentActivity(MyAgent(), session)
    detector_error = inference.InterruptionDetectionError(
        label="test",
        error=RuntimeError("adaptive unavailable"),
        recoverable=False,
    )

    recognition_mock = MagicMock()
    speech_mock = MagicMock()
    speech_mock.interrupted = False
    speech_mock.allow_interruptions = True

    # Start in "adaptive interruption" mode with audio-activity interruption off.
    activity._audio_recognition = recognition_mock
    activity._current_speech = speech_mock
    activity._interruption_detection_enabled = True
    activity._interruption_by_audio_activity_enabled = False
    activity._default_interruption_by_audio_activity_enabled = True

    caplog.set_level(logging.INFO, logger="livekit.agents")

    def _speaking_event(speech_duration: float) -> vad.VADEvent:
        return vad.VADEvent(
            type=vad.VADEventType.INFERENCE_DONE,
            samples_index=0,
            timestamp=time.time(),
            speech_duration=speech_duration,
            silence_duration=0.0,
            speaking=True,
        )

    try:
        activity._fallback_to_vad_interruption(detector_error)

        recognition_mock.update_interruption_detection.assert_called_once_with(None)
        speech_mock.interrupt.assert_not_called()
        assert activity._interruption_detection_enabled is False
        assert activity._interruption_by_audio_activity_enabled is True

        # Just below the threshold: speech must keep playing.
        activity.on_vad_inference_done(
            _speaking_event(session.options.interruption["min_duration"] - 0.01)
        )
        speech_mock.interrupt.assert_not_called()

        # Exactly at the threshold: speech is interrupted.
        activity.on_vad_inference_done(
            _speaking_event(session.options.interruption["min_duration"])
        )
        speech_mock.interrupt.assert_called_once_with()

        info_messages = [
            record.message for record in caplog.records if record.levelno == logging.INFO
        ]
        assert any("falling back to VAD-based interruption" in msg for msg in info_messages)
        assert all(record.levelno < logging.WARNING for record in caplog.records)
    finally:
        await _close_test_session(session)


async def test_force_flush_held_transcripts_emits_buffered_events() -> None:
    """A forced flush re-emits buffered transcripts and empties the buffer.

    Builds an ``AudioRecognition`` with no STT, VAD or interruption detector,
    places one final transcript in its hold buffer, and verifies that
    ``_flush_held_transcripts(force=True)`` delivers it to the hooks' final
    transcript callback and leaves the buffer empty.
    """
    session = create_session(FakeActions())
    recorder = _TestRecognitionHooks()
    recognition = AudioRecognition(
        session,
        hooks=recorder,
        endpointing=BaseEndpointing(min_delay=0.1, max_delay=1.0),
        stt=None,
        vad=None,
        interruption_detection=None,
        turn_detection="manual",
    )

    held_event = _final_transcript_event(text="held transcript", start_time=0.0, end_time=1.0)
    recognition._transcript_buffer.append(held_event)

    try:
        await recognition._flush_held_transcripts(cooldown=0.0, force=True)

        assert recorder.final_transcripts == ["held transcript"]
        assert not recognition._transcript_buffer
    finally:
        await _close_test_session(session)


@pytest.mark.parametrize(
"preemptive_generation, expected_latency",
[
Expand Down Expand Up @@ -956,6 +1082,7 @@ async def test_unknown_function_call() -> None:
class _TestRecognitionHooks:
def __init__(self) -> None:
self.interruptions: list[inference.OverlappingSpeechEvent] = []
self.final_transcripts: list[str] = []

def on_interruption(self, ev: inference.OverlappingSpeechEvent) -> None:
self.interruptions.append(ev)
Expand All @@ -973,7 +1100,7 @@ def on_interim_transcript(self, ev: SpeechEvent, *, speaking: bool | None) -> No
pass

def on_final_transcript(self, ev: SpeechEvent, *, speaking: bool | None = None) -> None:
pass
self.final_transcripts.append(ev.alternatives[0].text)

def on_end_of_turn(self, info: _EndOfTurnInfo) -> bool:
return True
Expand Down
5 changes: 4 additions & 1 deletion tests/test_interruption/test_interruption_failover.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import asyncio
import logging
import time
from unittest.mock import AsyncMock, MagicMock, Mock

Expand Down Expand Up @@ -114,7 +115,8 @@ def _mock_request_info() -> MagicMock:

class TestHttpTimeout:
@pytest.mark.asyncio
async def test_retries_then_emits_unrecoverable(self) -> None:
async def test_retries_then_emits_unrecoverable(self, caplog: pytest.LogCaptureFixture) -> None:
caplog.set_level(logging.WARNING, logger="livekit.agents")
mock_session = AsyncMock(spec=aiohttp.ClientSession)
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(side_effect=asyncio.TimeoutError("test timeout"))
Expand All @@ -133,6 +135,7 @@ async def test_retries_then_emits_unrecoverable(self) -> None:
unrecoverable_errors = [e for e in errors if not e.recoverable]
assert len(recoverable_errors) == 0
assert len(unrecoverable_errors) == 1
assert not [record for record in caplog.records if record.levelno >= logging.ERROR]


# there is no 429 in HTTP when hosted on LiveKit Cloud, so this is actually redundant
Expand Down
Loading