Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class STTOptions:
min_turn_silence: NotGivenOr[int] = NOT_GIVEN
max_turn_silence: NotGivenOr[int] = NOT_GIVEN
format_turns: NotGivenOr[bool] = NOT_GIVEN
continuous_partials: NotGivenOr[bool] = NOT_GIVEN
interruption_delay: NotGivenOr[int] = NOT_GIVEN
keyterms_prompt: NotGivenOr[list[str]] = NOT_GIVEN
prompt: NotGivenOr[str] = NOT_GIVEN
vad_threshold: NotGivenOr[float] = NOT_GIVEN
Expand All @@ -84,6 +86,8 @@ def __init__(
min_turn_silence: NotGivenOr[int] = NOT_GIVEN,
max_turn_silence: NotGivenOr[int] = NOT_GIVEN,
format_turns: NotGivenOr[bool] = NOT_GIVEN,
continuous_partials: NotGivenOr[bool] = NOT_GIVEN,
interruption_delay: NotGivenOr[int] = NOT_GIVEN,
keyterms_prompt: NotGivenOr[list[str]] = NOT_GIVEN,
prompt: NotGivenOr[str] = NOT_GIVEN,
vad_threshold: NotGivenOr[float] = NOT_GIVEN,
Expand All @@ -108,6 +112,19 @@ def __init__(
Defaults to 0.4.
min_turn_silence: Minimum silence in ms before a confident end-of-turn is finalized.
min_end_of_turn_silence_when_confident: Deprecated. Use min_turn_silence instead.
continuous_partials: Whether to emit additional partial transcripts during long
turns at a steady ~3 second cadence. By default, partials are emitted at
two points: one at 750 ms after turn start (configurable via
`interruption_delay`), and one each time silence exceeds
`min_turn_silence` without ending the turn. When enabled (default in
LiveKit; AssemblyAI server defaults to False), additional partials covering
the full turn transcript are emitted approximately every 3 seconds while
speech continues, on top of those baseline partials. Only supported with
the 'u3-rt-pro' model.
interruption_delay: How soon the first early partial is emitted, in ms.
Range 0–1000, default 500. Lower values produce faster time-to-first-token
for barge-in; higher values produce more confident first partials. Only
supported with the 'u3-rt-pro' model.
"""
super().__init__(
capabilities=stt.STTCapabilities(
Expand All @@ -125,6 +142,22 @@ def __init__(
if is_given(prompt) and model != "u3-rt-pro":
raise ValueError("The 'prompt' parameter is only supported with the 'u3-rt-pro' model.")

if is_given(continuous_partials) and model != "u3-rt-pro":
raise ValueError(
"The 'continuous_partials' parameter is only supported with the 'u3-rt-pro' model."
)

if is_given(interruption_delay) and model != "u3-rt-pro":
raise ValueError(
"The 'interruption_delay' parameter is only supported with the 'u3-rt-pro' model."
)

# LiveKit defaults continuous_partials to True (vs. AssemblyAI's server default of
# False) for steady-cadence partials. This parameter is only supported for
# u3-rt-pro, enforced by the validation above.
if not is_given(continuous_partials) and model == "u3-rt-pro":
continuous_partials = True

self._base_url = base_url
assemblyai_api_key = api_key if is_given(api_key) else os.environ.get("ASSEMBLYAI_API_KEY")
if not assemblyai_api_key:
Expand Down Expand Up @@ -159,6 +192,8 @@ def __init__(
min_turn_silence=min_turn_silence,
max_turn_silence=max_turn_silence,
format_turns=format_turns,
continuous_partials=continuous_partials,
interruption_delay=interruption_delay,
keyterms_prompt=keyterms_prompt,
prompt=prompt,
vad_threshold=vad_threshold,
Expand Down Expand Up @@ -220,6 +255,8 @@ def update_options(
prompt: NotGivenOr[str] = NOT_GIVEN,
keyterms_prompt: NotGivenOr[list[str]] = NOT_GIVEN,
vad_threshold: NotGivenOr[float] = NOT_GIVEN,
continuous_partials: NotGivenOr[bool] = NOT_GIVEN,
interruption_delay: NotGivenOr[int] = NOT_GIVEN,
# Deprecated — use min_turn_silence instead
min_end_of_turn_silence_when_confident: NotGivenOr[int] = NOT_GIVEN,
) -> None:
Expand All @@ -245,6 +282,10 @@ def update_options(
self._opts.keyterms_prompt = keyterms_prompt
if is_given(vad_threshold):
self._opts.vad_threshold = vad_threshold
if is_given(continuous_partials):
self._opts.continuous_partials = continuous_partials
if is_given(interruption_delay):
self._opts.interruption_delay = interruption_delay

for stream in self._streams:
stream.update_options(
Expand All @@ -255,6 +296,8 @@ def update_options(
prompt=prompt,
keyterms_prompt=keyterms_prompt,
vad_threshold=vad_threshold,
continuous_partials=continuous_partials,
interruption_delay=interruption_delay,
)


Expand Down Expand Up @@ -308,6 +351,8 @@ def update_options(
prompt: NotGivenOr[str] = NOT_GIVEN,
keyterms_prompt: NotGivenOr[list[str]] = NOT_GIVEN,
vad_threshold: NotGivenOr[float] = NOT_GIVEN,
continuous_partials: NotGivenOr[bool] = NOT_GIVEN,
interruption_delay: NotGivenOr[int] = NOT_GIVEN,
# Deprecated — use min_turn_silence instead
min_end_of_turn_silence_when_confident: NotGivenOr[int] = NOT_GIVEN,
) -> None:
Expand All @@ -333,6 +378,10 @@ def update_options(
self._opts.keyterms_prompt = keyterms_prompt
if is_given(vad_threshold):
self._opts.vad_threshold = vad_threshold
if is_given(continuous_partials):
self._opts.continuous_partials = continuous_partials
if is_given(interruption_delay):
self._opts.interruption_delay = interruption_delay

# Send UpdateConfiguration message over the active websocket
config_msg: dict = {"type": "UpdateConfiguration"}
Expand All @@ -346,6 +395,10 @@ def update_options(
config_msg["min_turn_silence"] = min_turn_silence
if is_given(end_of_turn_confidence_threshold):
config_msg["end_of_turn_confidence_threshold"] = end_of_turn_confidence_threshold
if is_given(continuous_partials):
config_msg["continuous_partials"] = continuous_partials
if is_given(interruption_delay):
config_msg["interruption_delay"] = interruption_delay
if is_given(vad_threshold):
config_msg["vad_threshold"] = vad_threshold

Expand Down Expand Up @@ -514,6 +567,12 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
"encoding": self._opts.encoding,
"speech_model": self._opts.speech_model,
"format_turns": self._opts.format_turns if is_given(self._opts.format_turns) else None,
"continuous_partials": self._opts.continuous_partials
if is_given(self._opts.continuous_partials)
else None,
"interruption_delay": self._opts.interruption_delay
if is_given(self._opts.interruption_delay)
else None,
"end_of_turn_confidence_threshold": self._opts.end_of_turn_confidence_threshold
if is_given(self._opts.end_of_turn_confidence_threshold)
else None,
Expand Down
117 changes: 117 additions & 0 deletions tests/test_plugin_assemblyai_stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import time
from unittest.mock import MagicMock, patch

import pytest

from livekit.agents.stt import SpeechEventType
from livekit.agents.types import NOT_GIVEN

Expand Down Expand Up @@ -184,3 +186,118 @@ async def test_start_time_has_default_before_plugin_override():
# start_time should already be a recent wall-clock value from the base
# class __init__, without any explicit override.
assert time.time() - stream.start_time < 5.0


async def test_continuous_partials_default():
"""Test continuous_partials is not set by default."""
from livekit.plugins.assemblyai import STT

stt = STT(api_key="test-key")
assert stt._opts.continuous_partials is NOT_GIVEN


async def test_continuous_partials_set():
"""Test continuous_partials can be set in constructor with u3-rt-pro."""
from livekit.plugins.assemblyai import STT

stt = STT(api_key="test-key", model="u3-rt-pro", continuous_partials=True)
assert stt._opts.continuous_partials is True


async def test_continuous_partials_requires_u3_rt_pro():
"""Test continuous_partials raises ValueError when used with a non-u3-rt-pro model."""
from livekit.plugins.assemblyai import STT

with pytest.raises(ValueError, match="continuous_partials"):
STT(api_key="test-key", continuous_partials=True)


async def test_continuous_partials_with_u3_pro_alias():
"""Test continuous_partials works with the deprecated 'u3-pro' alias (rewritten to u3-rt-pro)."""
from livekit.plugins.assemblyai import STT

stt = STT(api_key="test-key", model="u3-pro", continuous_partials=True)
assert stt._opts.continuous_partials is True
assert stt._opts.speech_model == "u3-rt-pro"


async def test_continuous_partials_update():
"""Test continuous_partials can be updated dynamically via update_options."""
from livekit.plugins.assemblyai import STT

stt = STT(api_key="test-key", model="u3-rt-pro", continuous_partials=False)
stt.update_options(continuous_partials=True)
assert stt._opts.continuous_partials is True


async def test_continuous_partials_defaults_to_true_for_u3_rt_pro():
"""Test continuous_partials defaults to True when model is u3-rt-pro (LiveKit-only
default; AssemblyAI server defaults to False)."""
from livekit.plugins.assemblyai import STT

stt = STT(api_key="test-key", model="u3-rt-pro")
assert stt._opts.continuous_partials is True


async def test_continuous_partials_explicit_false_overrides_livekit_default():
"""Test explicit continuous_partials=False overrides the LiveKit-only True default."""
from livekit.plugins.assemblyai import STT

stt = STT(api_key="test-key", model="u3-rt-pro", continuous_partials=False)
assert stt._opts.continuous_partials is False


async def test_continuous_partials_update_from_default():
"""Test continuous_partials can be updated via update_options away from LiveKit default."""
from livekit.plugins.assemblyai import STT

# LiveKit defaults this to True for u3-rt-pro
stt = STT(api_key="test-key", model="u3-rt-pro")
assert stt._opts.continuous_partials is True

stt.update_options(continuous_partials=False)
assert stt._opts.continuous_partials is False


async def test_interruption_delay_update():
"""Test interruption_delay can be updated dynamically via update_options."""
from livekit.plugins.assemblyai import STT

stt = STT(api_key="test-key", model="u3-rt-pro", interruption_delay=200)
stt.update_options(interruption_delay=750)
assert stt._opts.interruption_delay == 750


async def test_interruption_delay_update_from_default():
"""Test interruption_delay can be set via update_options when not initially set."""
from livekit.plugins.assemblyai import STT

stt = STT(api_key="test-key", model="u3-rt-pro")
assert stt._opts.interruption_delay is NOT_GIVEN

stt.update_options(interruption_delay=300)
assert stt._opts.interruption_delay == 300


async def test_interruption_delay_default():
"""Test interruption_delay is not set by default."""
from livekit.plugins.assemblyai import STT

stt = STT(api_key="test-key")
assert stt._opts.interruption_delay is NOT_GIVEN


async def test_interruption_delay_set():
"""Test interruption_delay can be set in constructor with u3-rt-pro."""
from livekit.plugins.assemblyai import STT

stt = STT(api_key="test-key", model="u3-rt-pro", interruption_delay=200)
assert stt._opts.interruption_delay == 200


async def test_interruption_delay_requires_u3_rt_pro():
"""Test interruption_delay raises ValueError when used with a non-u3-rt-pro model."""
from livekit.plugins.assemblyai import STT

with pytest.raises(ValueError, match="interruption_delay"):
STT(api_key="test-key", interruption_delay=200)
Loading