Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
import weakref
from collections.abc import Iterator
from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import Any, Literal, overload
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

Expand Down Expand Up @@ -793,6 +793,8 @@ class RealtimeSession(
def __init__(self, realtime_model: RealtimeModel) -> None:
super().__init__(realtime_model)
self._realtime_model: RealtimeModel = realtime_model
# per-session shallow copy of opts, so update_options diffs against (and mutates) this session's own state rather than the model's
self._opts = replace(realtime_model._opts)
self._tools = llm.ToolContext.empty()
self._msg_ch = utils.aio.Chan[RealtimeClientEvent | dict[str, Any]]()
self._input_resampler: rtc.AudioResampler | None = None
Expand Down Expand Up @@ -824,12 +826,12 @@ def send_event(self, event: RealtimeClientEvent | dict[str, Any]) -> None:
@utils.log_exceptions(logger=logger)
async def _main_task(self) -> None:
num_retries: int = 0
max_retries = self._realtime_model._opts.conn_options.max_retry
max_retries = self._opts.conn_options.max_retry

async def _reconnect() -> None:
logger.debug(
f"reconnecting to {self._realtime_model._provider_label}",
extra={"max_session_duration": self._realtime_model._opts.max_session_duration},
extra={"max_session_duration": self._opts.max_session_duration},
)

events: list[RealtimeClientEvent | dict[str, Any]] = []
Expand Down Expand Up @@ -862,10 +864,7 @@ async def _reconnect() -> None:
by_alias=True, exclude_unset=True, exclude_defaults=False
)

if (
self._realtime_model._opts.is_azure
and self._realtime_model._opts.api_version
):
if self._opts.is_azure and self._opts.api_version:
_normalize_azure_client_event(ev)

self.emit("openai_client_event_queued", ev)
Expand Down Expand Up @@ -910,9 +909,7 @@ async def _reconnect() -> None:
else:
self._emit_error(e, recoverable=True)

retry_interval = self._realtime_model._opts.conn_options._interval_for_retry(
num_retries
)
retry_interval = self._opts.conn_options._interval_for_retry(num_retries)
logger.warning(
f"{self._realtime_model._provider_label} connection failed, retrying in {retry_interval}s",
exc_info=e,
Expand All @@ -929,21 +926,21 @@ async def _reconnect() -> None:

async def _create_ws_conn(self) -> aiohttp.ClientWebSocketResponse:
headers = {"User-Agent": "LiveKit Agents"}
if self._realtime_model._opts.is_azure:
if self._realtime_model._opts.entra_token:
headers["Authorization"] = f"Bearer {self._realtime_model._opts.entra_token}"
if self._opts.is_azure:
if self._opts.entra_token:
headers["Authorization"] = f"Bearer {self._opts.entra_token}"

if self._realtime_model._opts.api_key:
headers["api-key"] = self._realtime_model._opts.api_key
if self._opts.api_key:
headers["api-key"] = self._opts.api_key
else:
headers["Authorization"] = f"Bearer {self._realtime_model._opts.api_key}"
headers["Authorization"] = f"Bearer {self._opts.api_key}"

url = process_base_url(
self._realtime_model._opts.base_url,
self._realtime_model._opts.model,
is_azure=self._realtime_model._opts.is_azure,
api_version=self._realtime_model._opts.api_version,
azure_deployment=self._realtime_model._opts.azure_deployment,
self._opts.base_url,
self._opts.model,
is_azure=self._opts.is_azure,
api_version=self._opts.api_version,
azure_deployment=self._opts.azure_deployment,
)

if lk_oai_debug:
Expand All @@ -953,7 +950,7 @@ async def _create_ws_conn(self) -> aiohttp.ClientWebSocketResponse:
try:
ws = await asyncio.wait_for(
self._realtime_model._ensure_http_session().ws_connect(url=url, headers=headers),
self._realtime_model._opts.conn_options.timeout,
self._opts.conn_options.timeout,
)
self._report_connection_acquired(time.perf_counter() - t0)
return ws
Expand Down Expand Up @@ -981,10 +978,7 @@ async def _send_task() -> None:

# Azure uses "text" for assistant content parts, while
# the new API uses "output_text" for assistant content.
if (
self._realtime_model._opts.is_azure
and self._realtime_model._opts.api_version
):
if self._opts.is_azure and self._opts.api_version:
_normalize_azure_client_event(msg)

self.emit("openai_client_event_queued", msg)
Expand Down Expand Up @@ -1027,7 +1021,7 @@ async def _recv_task() -> None:
# Azure OpenAI uses old-style event names from the beta API.
# Normalize them to the current OpenAI event names so the rest
# of the handler code only needs to deal with one set of names.
if self._realtime_model._opts.is_azure:
if self._opts.is_azure:
event_type = event.get("type", "")
normalized = _AZURE_EVENT_MAPPING.get(event_type)
if normalized is not None:
Expand Down Expand Up @@ -1114,9 +1108,9 @@ async def _recv_task() -> None:
asyncio.create_task(_send_task(), name="_send_task"),
]
wait_reconnect_task: asyncio.Task | None = None
if self._realtime_model._opts.max_session_duration is not None:
if self._opts.max_session_duration is not None:
wait_reconnect_task = asyncio.create_task(
asyncio.sleep(self._realtime_model._opts.max_session_duration),
asyncio.sleep(self._opts.max_session_duration),
name="_timeout_task",
)
tasks.append(wait_reconnect_task)
Expand Down Expand Up @@ -1146,7 +1140,7 @@ def _wrap_session_update(
and returns a dict (since AzureSessionUpdateEvent is not part of
the RealtimeClientEvent union).
"""
if self._realtime_model._opts.is_azure and self._realtime_model._opts.api_version:
if self._opts.is_azure and self._opts.api_version:
# legacy Azure API: convert to old flat format
return AzureSessionUpdateEvent(
type="session.update",
Expand All @@ -1163,8 +1157,8 @@ def _wrap_session_update(
def _create_session_update_event(self) -> SessionUpdateEvent | dict[str, Any]:
audio_format = realtime.realtime_audio_formats.AudioPCM(rate=SAMPLE_RATE, type="audio/pcm")
# they do not support both text and audio modalities at once; with audio, it'll respond in audio + transcript
modality = "audio" if "audio" in self._realtime_model._opts.modalities else "text"
opts = self._realtime_model._opts
modality = "audio" if "audio" in self._opts.modalities else "text"
Comment thread
longcw marked this conversation as resolved.
opts = self._opts

session = RealtimeSessionCreateRequest(
type="realtime",
Expand Down Expand Up @@ -1223,69 +1217,66 @@ def update_options(
has_changes = False

if is_given(tool_choice):
current_oai = to_oai_tool_choice(self._realtime_model._opts.tool_choice)
current_oai = to_oai_tool_choice(self._opts.tool_choice)
next_oai = to_oai_tool_choice(tool_choice)
self._realtime_model._opts.tool_choice = tool_choice
self._opts.tool_choice = tool_choice
if current_oai != next_oai:
session.tool_choice = next_oai
has_changes = True

if is_given(max_response_output_tokens):
if self._realtime_model._opts.max_response_output_tokens != max_response_output_tokens:
if self._opts.max_response_output_tokens != max_response_output_tokens:
session.max_output_tokens = max_response_output_tokens
has_changes = True
self._realtime_model._opts.max_response_output_tokens = max_response_output_tokens
self._opts.max_response_output_tokens = max_response_output_tokens

if is_given(tracing):
if self._realtime_model._opts.tracing != tracing:
if self._opts.tracing != tracing:
session.tracing = tracing # type: ignore[assignment]
has_changes = True
self._realtime_model._opts.tracing = tracing
self._opts.tracing = tracing

if is_given(truncation):
if self._realtime_model._opts.truncation != truncation:
if self._opts.truncation != truncation:
session.truncation = truncation
has_changes = True
self._realtime_model._opts.truncation = truncation
self._opts.truncation = truncation

has_audio_config = False
audio_output = RealtimeAudioConfigOutput()
audio_input = RealtimeAudioConfigInput()
audio_config = RealtimeAudioConfig(output=audio_output, input=audio_input)

if is_given(voice):
if self._realtime_model._opts.voice != voice:
if self._opts.voice != voice:
audio_output.voice = voice
has_audio_config = True
self._realtime_model._opts.voice = voice
self._opts.voice = voice

if is_given(turn_detection):
if self._realtime_model._opts.turn_detection != turn_detection:
if self._opts.turn_detection != turn_detection:
audio_input.turn_detection = turn_detection
has_audio_config = True
self._realtime_model._opts.turn_detection = turn_detection
self._opts.turn_detection = turn_detection

if is_given(input_audio_transcription):
if self._realtime_model._opts.input_audio_transcription != input_audio_transcription:
if self._opts.input_audio_transcription != input_audio_transcription:
audio_input.transcription = input_audio_transcription
has_audio_config = True
self._realtime_model._opts.input_audio_transcription = input_audio_transcription
self._opts.input_audio_transcription = input_audio_transcription

if is_given(input_audio_noise_reduction):
input_audio_noise_reduction = to_noise_reduction(input_audio_noise_reduction)
if (
self._realtime_model._opts.input_audio_noise_reduction
!= input_audio_noise_reduction
):
if self._opts.input_audio_noise_reduction != input_audio_noise_reduction:
audio_input.noise_reduction = input_audio_noise_reduction
has_audio_config = True
self._realtime_model._opts.input_audio_noise_reduction = input_audio_noise_reduction
self._opts.input_audio_noise_reduction = input_audio_noise_reduction

if is_given(speed):
if self._realtime_model._opts.speed != speed:
if self._opts.speed != speed:
audio_output.speed = speed
has_audio_config = True
self._realtime_model._opts.speed = speed
self._opts.speed = speed

if has_audio_config:
session.audio = audio_config
Expand Down Expand Up @@ -1451,7 +1442,7 @@ def _create_tools_update_event(self, tools: list[llm.Tool]) -> dict[str, Any]:
event_id=utils.shortuuid("tools_update_"),
session=RealtimeSessionCreateRequest.model_construct(
type="realtime",
model=self._realtime_model._opts.model,
model=self._opts.model,
tools=oai_tools, # type: ignore
),
)
Expand Down Expand Up @@ -1604,7 +1595,7 @@ def _close_current_generation(self, reason: str | None = None) -> None:
generation.text_ch.close()
generation.audio_ch.close()
if not generation.modalities.done():
generation.modalities.set_result(self._realtime_model._opts.modalities)
generation.modalities.set_result(self._opts.modalities)

self._current_generation.function_ch.close()
self._current_generation.message_ch.close()
Expand Down Expand Up @@ -1645,9 +1636,7 @@ def _handle_input_audio_buffer_speech_started(
def _handle_input_audio_buffer_speech_stopped(
self, _: InputAudioBufferSpeechStoppedEvent
) -> None:
user_transcription_enabled = (
self._realtime_model._opts.input_audio_transcription is not None
)
user_transcription_enabled = self._opts.input_audio_transcription is not None
self.emit(
"input_speech_stopped",
llm.InputSpeechStoppedEvent(user_transcription_enabled=user_transcription_enabled),
Expand Down Expand Up @@ -1857,7 +1846,7 @@ def _handle_response_output_item_done(self, event: ResponseOutputItemDoneEvent)
item_generation.audio_ch.close()
if not item_generation.modalities.done():
# fallback in case the message modalities were never set; this shouldn't happen
item_generation.modalities.set_result(self._realtime_model._opts.modalities)
item_generation.modalities.set_result(self._opts.modalities)

def _handle_function_call(self, item: RealtimeConversationItemFunctionCall) -> None:
assert self._current_generation is not None, "current_generation is None"
Expand Down Expand Up @@ -1892,7 +1881,7 @@ def _handle_response_done(self, event: ResponseDoneEvent) -> None:
if not generation.audio_ch.closed:
generation.audio_ch.close()
if not generation.modalities.done():
generation.modalities.set_result(self._realtime_model._opts.modalities)
generation.modalities.set_result(self._opts.modalities)

self._current_generation.function_ch.close()
self._current_generation.message_ch.close()
Expand Down
Loading