Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .tts import (
TTS,
ChunkedStream,
DeliveryMode,
Encoding,
SynthesizeStream,
TextNormalization,
Expand All @@ -36,6 +37,7 @@
"TTS",
"ChunkedStream",
"SynthesizeStream",
"DeliveryMode",
"Encoding",
"TTSModels",
"TextNormalization",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
# String spellings accepted for the text-normalization setting.
_TextNormalizationStr = Literal["APPLY_TEXT_NORMALIZATION_UNSPECIFIED", "ON", "OFF"]
# Text normalization may be supplied either as one of the strings above or as a bool.
TextNormalization = _TextNormalizationStr | bool
# How timestamp info is transported relative to audio data ("SYNC" or "ASYNC").
TimestampTransportStrategy = Literal["TIMESTAMP_TRANSPORT_STRATEGY_UNSPECIFIED", "SYNC", "ASYNC"]
# Output-variation setting; the option docs describe it as applying to ``inworld-tts-2`` only.
DeliveryMode = Literal["DELIVERY_MODE_UNSPECIFIED", "STABLE", "BALANCED", "CREATIVE"]

# Strategy used when the caller does not pass ``timestamp_transport_strategy``.
DEFAULT_TIMESTAMP_TRANSPORT_STRATEGY: TimestampTransportStrategy = "ASYNC"

Expand Down Expand Up @@ -98,6 +99,7 @@ class _TTSOptions:
language: NotGivenOr[str] = NOT_GIVEN
timestamp_type: NotGivenOr[TimestampType] = NOT_GIVEN
text_normalization: NotGivenOr[TextNormalization] = NOT_GIVEN
delivery_mode: NotGivenOr[DeliveryMode] = NOT_GIVEN
timestamp_transport_strategy: TimestampTransportStrategy = DEFAULT_TIMESTAMP_TRANSPORT_STRATEGY
buffer_char_threshold: int = DEFAULT_BUFFER_CHAR_THRESHOLD
max_buffer_delay_ms: int = DEFAULT_MAX_BUFFER_DELAY_MS
Expand Down Expand Up @@ -408,6 +410,8 @@ async def _send_loop(self) -> None:
pkt["create"]["timestampType"] = opts.timestamp_type
if is_given(opts.text_normalization):
pkt["create"]["applyTextNormalization"] = opts.text_normalization
if is_given(opts.delivery_mode):
pkt["create"]["deliveryMode"] = opts.delivery_mode
# Always enable auto_mode since we always use SentenceTokenizer
pkt["create"]["autoMode"] = True
await self._ws.send_str(json.dumps(pkt))
Expand Down Expand Up @@ -831,6 +835,7 @@ def __init__(
language: NotGivenOr[str] = NOT_GIVEN,
timestamp_type: NotGivenOr[TimestampType] = NOT_GIVEN,
text_normalization: NotGivenOr[TextNormalization] = NOT_GIVEN,
delivery_mode: NotGivenOr[DeliveryMode] = NOT_GIVEN,
timestamp_transport_strategy: NotGivenOr[TimestampTransportStrategy] = NOT_GIVEN,
buffer_char_threshold: NotGivenOr[int] = NOT_GIVEN,
max_buffer_delay_ms: NotGivenOr[int] = NOT_GIVEN,
Expand Down Expand Up @@ -866,6 +871,11 @@ def __init__(
text_normalization (str, optional): Controls text normalization. When "ON", numbers,
dates, and abbreviations are expanded (e.g., "Dr." -> "Doctor"). When "OFF",
text is read exactly as written. Defaults to automatic.
delivery_mode (str, optional): Controls output variation on ``inworld-tts-2`` only.
One of "DELIVERY_MODE_UNSPECIFIED", "STABLE", "BALANCED", or "CREATIVE".
The Inworld API ignores ``temperature`` on ``inworld-tts-2`` — use
``delivery_mode`` to steer output variation on that model instead.
Defaults to the server-side default ("BALANCED").
timestamp_transport_strategy (str, optional): Controls how timestamp info is
transported relative to audio data. "SYNC" returns timestamps in the same
message as audio data. "ASYNC" allows timestamps to return in trailing
Expand Down Expand Up @@ -918,6 +928,8 @@ def __init__(
_validate_str_param(timestamp_type, "timestamp_type", TimestampType)
if is_given(text_normalization):
text_normalization = _resolve_text_normalization(text_normalization)
if is_given(delivery_mode):
_validate_str_param(delivery_mode, "delivery_mode", DeliveryMode)
if is_given(timestamp_transport_strategy):
_validate_str_param(
timestamp_transport_strategy,
Expand All @@ -936,6 +948,7 @@ def __init__(
language=language,
timestamp_type=timestamp_type,
text_normalization=text_normalization,
delivery_mode=delivery_mode,
timestamp_transport_strategy=timestamp_transport_strategy
if is_given(timestamp_transport_strategy)
else DEFAULT_TIMESTAMP_TRANSPORT_STRATEGY,
Expand Down Expand Up @@ -994,6 +1007,7 @@ def update_options(
language: NotGivenOr[str] = NOT_GIVEN,
timestamp_type: NotGivenOr[TimestampType] = NOT_GIVEN,
text_normalization: NotGivenOr[TextNormalization] = NOT_GIVEN,
delivery_mode: NotGivenOr[DeliveryMode] = NOT_GIVEN,
timestamp_transport_strategy: NotGivenOr[TimestampTransportStrategy] = NOT_GIVEN,
buffer_char_threshold: NotGivenOr[int] = NOT_GIVEN,
max_buffer_delay_ms: NotGivenOr[int] = NOT_GIVEN,
Expand All @@ -1013,6 +1027,9 @@ def update_options(
language (str, optional): BCP-47 language tag (e.g., "en-US", "fr-FR").
timestamp_type (str, optional): Controls timestamp metadata ("WORD" or "CHARACTER").
text_normalization (str, optional): Controls text normalization ("ON" or "OFF").
delivery_mode (str, optional): Controls output variation on ``inworld-tts-2`` only
("DELIVERY_MODE_UNSPECIFIED", "STABLE", "BALANCED", or "CREATIVE"). The Inworld
API ignores ``temperature`` on TTS-2; use this instead.
timestamp_transport_strategy (str, optional): Controls timestamp transport strategy
("SYNC" or "ASYNC").
buffer_char_threshold (int, optional): For streaming, min characters before triggering.
Expand Down Expand Up @@ -1040,6 +1057,9 @@ def update_options(
self._opts.timestamp_type = timestamp_type
if is_given(text_normalization):
self._opts.text_normalization = _resolve_text_normalization(text_normalization)
if is_given(delivery_mode):
_validate_str_param(delivery_mode, "delivery_mode", DeliveryMode)
self._opts.delivery_mode = delivery_mode
if is_given(timestamp_transport_strategy):
_validate_str_param(
timestamp_transport_strategy,
Expand Down Expand Up @@ -1151,6 +1171,8 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
body_params["timestampType"] = self._opts.timestamp_type
if utils.is_given(self._opts.text_normalization):
body_params["applyTextNormalization"] = self._opts.text_normalization
if utils.is_given(self._opts.delivery_mode):
body_params["deliveryMode"] = self._opts.delivery_mode
body_params["timestampTransportStrategy"] = self._opts.timestamp_transport_strategy

x_request_id = str(uuid.uuid4())
Expand Down
194 changes: 194 additions & 0 deletions tests/test_plugin_inworld_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""Tests for Inworld TTS plugin configuration options and wire payloads."""

from __future__ import annotations

import asyncio
import json
from unittest.mock import AsyncMock, MagicMock

import pytest

from livekit.agents import APIConnectionError
from livekit.agents.types import NOT_GIVEN
from livekit.agents.utils import aio


async def test_delivery_mode_default_not_given():
    """A TTS built without ``delivery_mode`` keeps the option as ``NOT_GIVEN``."""
    from livekit.plugins.inworld import TTS

    engine = TTS(api_key="test-key", model="inworld-tts-2")
    stored_mode = engine._opts.delivery_mode
    assert stored_mode is NOT_GIVEN


async def test_delivery_mode_set_on_init():
    """A ``delivery_mode`` passed to the constructor is stored on the options."""
    from livekit.plugins.inworld import TTS

    engine = TTS(api_key="test-key", model="inworld-tts-2", delivery_mode="STABLE")
    stored_mode = engine._opts.delivery_mode
    assert stored_mode == "STABLE"


async def test_delivery_mode_accepts_all_documented_values():
    """Every documented ``delivery_mode`` string is accepted and stored verbatim."""
    from livekit.plugins.inworld import TTS

    documented_modes = ("DELIVERY_MODE_UNSPECIFIED", "STABLE", "BALANCED", "CREATIVE")
    for mode in documented_modes:
        engine = TTS(api_key="test-key", model="inworld-tts-2", delivery_mode=mode)
        assert engine._opts.delivery_mode == mode


async def test_delivery_mode_rejects_unknown_value():
    """Constructing a TTS with an undocumented ``delivery_mode`` raises ``ValueError``."""
    from livekit.plugins.inworld import TTS

    ctor_kwargs = {
        "api_key": "test-key",
        "model": "inworld-tts-2",
        "delivery_mode": "EXPRESSIVE",
    }
    with pytest.raises(ValueError):
        TTS(**ctor_kwargs)


async def test_delivery_mode_update_options():
    """``update_options`` stores a valid ``delivery_mode`` and rejects an unknown one."""
    from livekit.plugins.inworld import TTS

    engine = TTS(api_key="test-key", model="inworld-tts-2")
    options = engine._opts

    # Nothing was passed to the constructor, so the option starts unset.
    assert options.delivery_mode is NOT_GIVEN

    engine.update_options(delivery_mode="CREATIVE")
    assert engine._opts.delivery_mode == "CREATIVE"

    # An undocumented value must be refused by update_options as well.
    with pytest.raises(ValueError):
        engine.update_options(delivery_mode="LOUD")


async def _capture_first_ws_create_packet(opts) -> dict:
    """Run ``_InworldConnection._send_loop`` against a stub websocket and
    return the first payload it sends, decoded from JSON.

    The stub's ``send_str`` records each payload and sets an
    ``asyncio.Event``, so the caller resumes as soon as something has been
    sent instead of polling.
    """
    from livekit.plugins.inworld.tts import _CreateContextMsg, _InworldConnection

    payloads: list[str] = []
    first_sent = asyncio.Event()

    def _record(payload: str) -> None:
        payloads.append(payload)
        first_sent.set()

    stub_ws = MagicMock(send_str=AsyncMock(side_effect=_record))

    connection = _InworldConnection(
        session=MagicMock(),
        ws_url="wss://example.invalid/",
        authorization="Basic test",
    )
    # No real connect() happens; the stub websocket is injected directly so
    # `_send_loop` writes to it.
    connection._ws = stub_ws

    create_msg = _CreateContextMsg(context_id="ctx-1", opts=opts)
    await connection._outbound_queue.put(create_msg)

    loop_task = asyncio.create_task(connection._send_loop())
    try:
        await asyncio.wait_for(first_sent.wait(), timeout=2.0)
    finally:
        # Mark the connection closed and cancel the loop whether or not a
        # payload arrived, so the task never outlives this helper.
        connection._closed = True
        await aio.cancel_and_wait(loop_task)

    first_payload = payloads[0]
    return json.loads(first_payload)


async def test_ws_create_packet_includes_delivery_mode():
    """With ``delivery_mode`` set on the options, the WebSocket ``create``
    packet sent by ``_send_loop`` carries ``deliveryMode`` directly inside
    the ``create`` object."""
    from livekit.plugins.inworld.tts import _TTSOptions

    option_fields = {
        "model": "inworld-tts-2",
        "encoding": "PCM",
        "voice": "Ashley",
        "sample_rate": 24000,
        "bit_rate": 64000,
        "speaking_rate": 1.0,
        "temperature": 1.0,
        "delivery_mode": "STABLE",
    }
    opts = _TTSOptions(**option_fields)

    packet = await _capture_first_ws_create_packet(opts)
    create_section = packet["create"]
    assert create_section["deliveryMode"] == "STABLE"
    assert create_section["modelId"] == "inworld-tts-2"


async def test_ws_create_packet_omits_delivery_mode_when_not_given():
    """Without ``delivery_mode``, the WS ``create`` packet must leave the
    ``deliveryMode`` key out entirely (Inworld treats absence as the server
    default)."""
    from livekit.plugins.inworld.tts import _TTSOptions

    option_fields = {
        "model": "inworld-tts-2",
        "encoding": "PCM",
        "voice": "Ashley",
        "sample_rate": 24000,
        "bit_rate": 64000,
        "speaking_rate": 1.0,
        "temperature": 1.0,
    }
    opts = _TTSOptions(**option_fields)

    packet = await _capture_first_ws_create_packet(opts)
    create_section = packet["create"]
    assert "deliveryMode" not in create_section


def _patch_session_to_capture_post(tts, captured: dict[str, object]) -> None:
    """Swap the TTS's session for a stub whose ``post()`` records the URL and
    JSON body into ``captured`` and then aborts the request.

    The abort happens when the returned context manager is entered, i.e.
    after the call arguments have already been stored.
    """

    class _AbortingPost:
        async def __aenter__(self):
            # Entering the context fails on purpose so the caller unwinds
            # right away; `_record_post` has stored the body by this point.
            raise RuntimeError("short-circuit")

        async def __aexit__(self, *exc):
            return None

    def _record_post(url, *, json=None, **kwargs):
        captured.update(url=url, json=json)
        return _AbortingPost()

    stub_session = MagicMock()
    stub_session.post = _record_post
    tts._session = stub_session


async def test_http_body_includes_delivery_mode():
    """With ``delivery_mode`` set on the TTS, the HTTP ``synthesize`` request
    body carries ``deliveryMode`` as a top-level key."""
    from livekit.plugins.inworld import TTS

    engine = TTS(api_key="test-key", model="inworld-tts-2", delivery_mode="BALANCED")
    recorded: dict[str, object] = {}
    _patch_session_to_capture_post(engine, recorded)

    # The stub session aborts the request; this test expects that to surface
    # as APIConnectionError.
    with pytest.raises(APIConnectionError):
        async for _ in engine.synthesize("hello"):
            pass

    body = recorded.get("json")
    assert isinstance(body, dict), f"expected JSON dict, got {body!r}"
    assert body["deliveryMode"] == "BALANCED"
    assert body["modelId"] == "inworld-tts-2"


async def test_http_body_omits_delivery_mode_when_not_given():
    """Without ``delivery_mode``, the HTTP ``synthesize`` request body must
    not contain a ``deliveryMode`` key."""
    from livekit.plugins.inworld import TTS

    engine = TTS(api_key="test-key", model="inworld-tts-2")
    recorded: dict[str, object] = {}
    _patch_session_to_capture_post(engine, recorded)

    # The stub session aborts the request; this test expects that to surface
    # as APIConnectionError.
    with pytest.raises(APIConnectionError):
        async for _ in engine.synthesize("hello"):
            pass

    body = recorded.get("json")
    assert isinstance(body, dict)
    assert "deliveryMode" not in body
Loading