Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion livekit-agents/livekit/agents/vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class VADEvent:
@dataclass
class VADCapabilities:
    """Feature flags advertised by a VAD implementation."""

    update_interval: float  # interval between VAD updates — presumably seconds; confirm against concrete VADs
    supports_reset: bool = False
    """Whether the VAD Stream supports mid-session reset"""


class VAD(ABC, rtc.EventEmitter[Literal["metrics_collected"]]):
Expand Down Expand Up @@ -99,10 +101,15 @@ class VADStream(ABC):
class _FlushSentinel:
    # Marker object pushed into the input channel by flush() to request that
    # buffered audio be flushed.
    pass

class _ResetSentinel:
    # Marker object pushed into the input channel by reset() to request a
    # mid-session state reset without closing the stream.
    pass

def __init__(self, vad: VAD) -> None:
self._vad = vad
self._last_activity_time = time.perf_counter()
self._input_ch = aio.Chan[rtc.AudioFrame | VADStream._FlushSentinel]()
self._input_ch = aio.Chan[
rtc.AudioFrame | VADStream._FlushSentinel | VADStream._ResetSentinel
]()
self._event_ch = aio.Chan[VADEvent]()

self._tee_aiter = aio.itertools.tee(self._event_ch, 2)
Expand Down Expand Up @@ -158,6 +165,19 @@ def flush(self) -> None:
self._check_not_closed()
self._input_ch.send_nowait(self._FlushSentinel())

def reset(self) -> None:
    """Reset vad state without closing the stream.

    Raises:
        RuntimeError: if the underlying VAD does not advertise
            ``VADCapabilities.supports_reset``.
    """
    if self._vad.capabilities.supports_reset:
        self._check_input_not_ended()
        self._check_not_closed()
        self._input_ch.send_nowait(self._ResetSentinel())
        return

    raise RuntimeError(
        f"{self._vad._label} does not support mid-session reset "
        f"({type(self).__module__}.{type(self).__name__}.reset requires "
        "VADCapabilities.supports_reset=True); create a new VAD stream instead."
    )

def end_input(self) -> None:
"""Mark the end of input, no more audio will be pushed"""
self.flush()
Expand Down
19 changes: 17 additions & 2 deletions livekit-agents/livekit/agents/voice/audio_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ..telemetry import trace_types, tracer
from ..types import NOT_GIVEN, NotGivenOr
from ..utils import aio, is_given
from ..vad import VADStream
from . import io
from ._utils import _set_participant_attributes
from .endpointing import BaseEndpointing
Expand Down Expand Up @@ -171,6 +172,7 @@ def __init__(

self._stt_pipeline: _STTPipeline | None = None
self._vad_ch: aio.Chan[rtc.AudioFrame] | None = None
self._vad_stream: VADStream | None = None

self._tasks: set[asyncio.Task[Any]] = set()

Expand Down Expand Up @@ -590,6 +592,7 @@ def update_stt(self, stt: io.STTNode | None, *, pipeline: _STTPipeline | None =
def update_vad(self, vad: vad.VAD | None) -> None:
self._vad = vad
if vad:
self._vad_stream = None
self._vad_ch = aio.Chan[rtc.AudioFrame]()
self._vad_atask = asyncio.create_task(
self._vad_task(vad, self._vad_ch, self._vad_atask)
Expand All @@ -600,6 +603,7 @@ def update_vad(self, vad: vad.VAD | None) -> None:
self._tasks.add(task)
self._vad_atask = None
self._vad_ch = None
self._vad_stream = None

self._interruption_enabled = (
self._interruption_detection is not None and self._vad is not None
Expand Down Expand Up @@ -945,10 +949,18 @@ async def _on_stt_event(self, ev: stt.SpeechEvent) -> None:
# if user is still speaking (an immediate VAD SOS will interrupt the agent)
if self._vad:
if self._speaking:
_start_time = time.perf_counter()
if self._vad_stream is not None and self._vad.capabilities.supports_reset:
self._vad_stream.reset()
else:
self.update_vad(self._vad)

logger.warning(
"stt end of speech received while user is speaking, resetting vad"
"stt end of speech received while user is speaking, resetting vad",
extra={
"reset_duration_ms": (time.perf_counter() - _start_time) * 1000,
},
)
self.update_vad(self._vad)

self._speaking = False
self._user_turn_committed = True
Expand Down Expand Up @@ -1215,6 +1227,7 @@ async def _vad_task(
await aio.cancel_and_wait(task)

stream = vad.stream()
self._vad_stream = stream

@utils.log_exceptions(logger=logger)
async def _forward() -> None:
Expand All @@ -1229,6 +1242,8 @@ async def _forward() -> None:
finally:
await aio.cancel_and_wait(forward_task)
await stream.aclose()
if self._vad_stream is stream:
self._vad_stream = None

# reset the speaking state to prevent stuck user speaking state during handoff
if self._speaking:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def window_size_samples(self) -> int:
def context_size(self) -> int:
return self._context_size

def reset(self) -> None:
    """Zero out the rolling context, recurrent state, and staging buffer."""
    for buffer in (self._context, self._rnn_state, self._input_buffer):
        buffer.fill(0)

def __call__(self, x: np.ndarray) -> float:
self._input_buffer[:, : self._context_size] = self._context
self._input_buffer[:, self._context_size :] = x
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def __init__(
session: onnxruntime.InferenceSession,
opts: _VADOptions,
) -> None:
super().__init__(capabilities=agents.vad.VADCapabilities(update_interval=0.032))
super().__init__(
capabilities=agents.vad.VADCapabilities(update_interval=0.032, supports_reset=True)
)
self._onnx_session = session
self._opts = opts
self._streams = weakref.WeakSet[VADStream]()
Expand Down Expand Up @@ -311,7 +313,49 @@ async def _main_task(self) -> None:

extra_inference_time = 0.0

def _reset_state() -> None:
    # Restore all per-stream inference state to its initial values so the
    # stream behaves as if freshly created; invoked when a _ResetSentinel
    # arrives on the input channel.
    nonlocal speech_buffer_index
    nonlocal pub_speaking, pub_speech_duration, pub_silence_duration
    nonlocal pub_current_sample, pub_timestamp
    nonlocal speech_threshold_duration, silence_threshold_duration
    nonlocal input_frames, inference_frames, resampler
    nonlocal input_copy_remaining_fract, extra_inference_time

    # Clear the ONNX model's recurrent state and restart probability smoothing.
    self._model.reset()
    self._exp_filter = utils.ExpFilter(alpha=0.35)

    # Drop any partially-accumulated speech audio.
    speech_buffer_index = 0
    self._speech_buffer_max_reached = False
    if self._speech_buffer is not None:
        self._speech_buffer.fill(0)

    # Reset the published speaking state and counters.
    pub_speaking = False
    pub_speech_duration = 0.0
    pub_silence_duration = 0.0
    pub_current_sample = 0
    pub_timestamp = 0.0
    speech_threshold_duration = 0.0
    silence_threshold_duration = 0.0

    # Discard frames buffered but not yet run through inference.
    input_frames = []
    inference_frames = []
    input_copy_remaining_fract = 0.0
    extra_inference_time = 0.0

    # Recreate the resampler so no stale filter state survives the reset;
    # only needed when the input rate differs from the model's rate.
    if self._input_sample_rate and self._input_sample_rate != self._opts.sample_rate:
        resampler = rtc.AudioResampler(
            input_rate=self._input_sample_rate,
            output_rate=self._opts.sample_rate,
            quality=rtc.AudioResamplerQuality.QUICK,
        )
    else:
        resampler = None

async for input_frame in self._input_ch:
if isinstance(input_frame, self._ResetSentinel):
_reset_state()
continue

if not isinstance(input_frame, rtc.AudioFrame):
continue # ignore flush sentinel for now

Expand Down
8 changes: 7 additions & 1 deletion tests/fake_vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import time

from livekit import rtc
from livekit.agents.vad import VAD, VADCapabilities, VADEvent, VADEventType, VADStream

from .fake_stt import FakeUserSpeech
Expand Down Expand Up @@ -41,7 +42,12 @@ async def _main_task(self) -> None:
if not self._vad._fake_user_speeches:
return

await self._input_ch.recv()
async for input_frame in self._input_ch:
if isinstance(input_frame, rtc.AudioFrame):
break
else:
return

start_time = time.perf_counter()

def current_time() -> float:
Expand Down
75 changes: 75 additions & 0 deletions tests/test_agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import time
from unittest.mock import MagicMock, patch

import pytest

Expand Down Expand Up @@ -746,6 +747,80 @@ async def test_backchannel_boundary_suppresses_start_boundary_interruption() ->
await _close_test_session(session)


def _make_fake_vad(*, supports_reset: bool) -> MagicMock:
fake_vad = MagicMock()
fake_vad.capabilities = MagicMock(supports_reset=supports_reset)
return fake_vad


async def _make_stt_eos_recognition() -> AudioRecognition:
    """Create an AudioRecognition wired for STT-driven end-of-speech tests."""
    session = create_session(FakeActions())
    endpointing = BaseEndpointing(min_delay=0.0, max_delay=0.0)
    return AudioRecognition(
        session,
        hooks=_TestRecognitionHooks(),
        endpointing=endpointing,
        stt=None,
        vad=None,
        interruption_detection=None,
        turn_detection="stt",
    )


async def test_stt_eos_resets_active_vad_stream_without_restarting_vad() -> None:
    """STT end-of-speech should reset the live VAD stream in place when supported."""
    recognition = await _make_stt_eos_recognition()
    recognition._speaking = True
    recognition._vad = _make_fake_vad(supports_reset=True)
    active_stream = MagicMock()
    recognition._vad_stream = active_stream

    try:
        with patch.object(recognition, "update_vad") as mocked_update_vad:
            event = SpeechEvent(type=SpeechEventType.END_OF_SPEECH)
            await recognition._on_stt_event(event)

        active_stream.reset.assert_called_once_with()
        mocked_update_vad.assert_not_called()
        assert recognition._vad_stream is active_stream
    finally:
        end_of_turn = recognition._end_of_turn_task
        if end_of_turn is not None:
            await aio.cancel_and_wait(end_of_turn)
        await _close_test_session(recognition._session)


async def test_stt_eos_falls_back_to_update_vad_when_no_active_stream() -> None:
    """Without a live VAD stream, STT end-of-speech restarts the VAD instead."""
    recognition = await _make_stt_eos_recognition()
    recognition._speaking = True
    recognition._vad = _make_fake_vad(supports_reset=True)
    recognition._vad_stream = None

    try:
        with patch.object(recognition, "update_vad") as mocked_update_vad:
            event = SpeechEvent(type=SpeechEventType.END_OF_SPEECH)
            await recognition._on_stt_event(event)

        mocked_update_vad.assert_called_once_with(recognition._vad)
    finally:
        end_of_turn = recognition._end_of_turn_task
        if end_of_turn is not None:
            await aio.cancel_and_wait(end_of_turn)
        await _close_test_session(recognition._session)


async def test_stt_eos_falls_back_to_update_vad_when_reset_unsupported() -> None:
    """A stream that cannot reset must never be reset; the VAD is restarted."""
    recognition = await _make_stt_eos_recognition()
    recognition._speaking = True
    recognition._vad = _make_fake_vad(supports_reset=False)
    active_stream = MagicMock()
    recognition._vad_stream = active_stream

    try:
        with patch.object(recognition, "update_vad") as mocked_update_vad:
            event = SpeechEvent(type=SpeechEventType.END_OF_SPEECH)
            await recognition._on_stt_event(event)

        mocked_update_vad.assert_called_once_with(recognition._vad)
        active_stream.reset.assert_not_called()
    finally:
        end_of_turn = recognition._end_of_turn_task
        if end_of_turn is not None:
            await aio.cancel_and_wait(end_of_turn)
        await _close_test_session(recognition._session)


async def test_backchannel_boundary_releases_end_boundary_transcript() -> None:
actions = FakeActions()
session = create_session(
Expand Down
68 changes: 68 additions & 0 deletions tests/test_vad.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

import pytest

from livekit.agents import vad
Expand All @@ -14,6 +16,15 @@
)


async def _next_reset_duration(caplog: pytest.LogCaptureFixture) -> float:
    """Poll captured log records until the VAD-reset log line appears and
    return its reported duration in milliseconds.

    NOTE(review): this matches the literal message "reset vad stream" — the
    log call visible in audio_recognition says "stt end of speech received
    while user is speaking, resetting vad"; confirm the two stay in sync.
    """
    while True:
        for record in caplog.records:
            if record.getMessage() == "reset vad stream":
                return float(record.__dict__["reset_duration_ms"])

        # Busy-poll with a short sleep so the producing task can run and log.
        await asyncio.sleep(0.01)


@pytest.mark.parametrize("sample_rate", SAMPLE_RATES)
async def test_chunks_vad(sample_rate) -> None:
frames, *_ = await utils.make_test_speech(chunk_duration_ms=10, sample_rate=sample_rate)
Expand Down Expand Up @@ -60,6 +71,63 @@ async def test_chunks_vad(sample_rate) -> None:
f.write(utils.make_wav_file(inference_frames))


async def _drain_speech_segment(
    stream: vad.VADStream, frames: list, *, timeout: float = 30.0
) -> tuple[vad.VADEvent, vad.VADEvent]:
    """Push *frames* until both START_OF_SPEECH and END_OF_SPEECH have fired."""

    # Signals the pump to stop feeding frames once the segment completed.
    done = asyncio.Event()

    async def _pump() -> None:
        # Feed one frame at a time, yielding to the event loop after each so
        # the consumer gets a chance to observe emitted events.
        for frame in frames:
            if done.is_set():
                return
            stream.push_frame(frame)
            await asyncio.sleep(0)

    async def _consume() -> tuple[vad.VADEvent, vad.VADEvent]:
        # Record the first SOS, then return on the first EOS that follows it.
        sos_event: vad.VADEvent | None = None
        async for ev in stream:
            if ev.type == vad.VADEventType.START_OF_SPEECH and sos_event is None:
                sos_event = ev
            elif ev.type == vad.VADEventType.END_OF_SPEECH and sos_event is not None:
                return sos_event, ev

        raise AssertionError("stream ended before END_OF_SPEECH")

    pump_task = asyncio.create_task(_pump())
    try:
        return await asyncio.wait_for(_consume(), timeout=timeout)
    finally:
        # Stop the pump on success or timeout; swallow only the cancellation
        # that we triggered ourselves.
        done.set()
        pump_task.cancel()
        try:
            await pump_task
        except asyncio.CancelledError:
            pass


async def test_reset_recovers_full_speech_segment() -> None:
    """Real speech audio should still produce a complete SOS + EOS cycle after reset."""

    frames, *_ = await utils.make_test_speech(chunk_duration_ms=10, sample_rate=16000)
    assert len(frames) > 1, "frames aren't chunked"

    stream = VAD.stream()
    try:
        # Run two full speech segments, resetting the stream between them.
        for attempt in range(2):
            if attempt:
                stream.reset()

            sos_event, eos_event = await _drain_speech_segment(stream, frames)
            assert sos_event.type == vad.VADEventType.START_OF_SPEECH
            assert eos_event.type == vad.VADEventType.END_OF_SPEECH
    finally:
        await stream.aclose()


@pytest.mark.parametrize("sample_rate", SAMPLE_RATES)
async def test_file_vad(sample_rate):
frames, *_ = await utils.make_test_speech(sample_rate=sample_rate)
Expand Down
Loading