Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 76 additions & 57 deletions src/agents/realtime/openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from typing import Any, Callable

import websockets
from openai.types.beta.realtime.conversation_item import ConversationItem
from openai.types.beta.realtime.realtime_server_event import (
RealtimeServerEvent as OpenAIRealtimeServerEvent,
)
from openai.types.beta.realtime.response_audio_delta_event import ResponseAudioDeltaEvent
from pydantic import TypeAdapter
from websockets.asyncio.client import ClientConnection

Expand Down Expand Up @@ -233,6 +235,77 @@ async def interrupt(self) -> None:
self._audio_length_ms = 0.0
self._current_audio_content_index = None

async def _handle_audio_delta(self, parsed: ResponseAudioDeltaEvent) -> None:
    """Record playback bookkeeping for an audio delta and forward the audio.

    Stores the item id and content index carried by ``parsed``, stamps the
    audio start time if none is set, adds the decoded chunk's duration to the
    running audio length, and emits a ``RealtimeModelAudioEvent`` holding the
    decoded bytes.

    Args:
        parsed: The ``response.audio.delta`` server event; ``parsed.delta`` is
            base64-encoded audio.
    """
    # Remember which item / content part the audio belongs to.
    self._current_item_id = parsed.item_id
    self._current_audio_content_index = parsed.content_index

    # No start time recorded yet: stamp it and restart the running length.
    if self._audio_start_time is None:
        self._audio_start_time = datetime.now()
        self._audio_length_ms = 0.0

    pcm_chunk = base64.b64decode(parsed.delta)
    # Duration is derived from the byte count (24KHz pcm16le).
    self._audio_length_ms += self._calculate_audio_length_ms(pcm_chunk)

    audio_event = RealtimeModelAudioEvent(data=pcm_chunk, response_id=parsed.response_id)
    await self._emit_event(audio_event)

def _calculate_audio_length_ms(self, audio_bytes: bytes) -> float:
    """Return the duration in milliseconds of a 24KHz PCM16LE audio buffer.

    Args:
        audio_bytes: Raw audio bytes.

    Returns:
        The byte count divided by 24 (samples per millisecond at 24KHz) and
        then by 2 (bytes per 16-bit sample).
    """
    samples_per_ms = 24
    bytes_per_sample = 2
    return len(audio_bytes) / samples_per_ms / bytes_per_sample

async def _handle_output_item(self, item: ConversationItem) -> None:
    """Translate a response output item into model events.

    A ``message`` item is validated into a ``RealtimeMessageItem`` and emitted
    as an item update. A ``function_call`` item is acted on only when its
    status is ``completed``: an item update is emitted first, then a tool-call
    event. Every other item is ignored.

    Args:
        item: The conversation item carried by the output-item event.
    """
    if item.type == "message":
        # Output-item events supply no previous_item_id, so that key is omitted.
        message_payload = {
            "item_id": item.id or "",
            "type": item.type,
            "role": item.role,
            "content": item.content,
            "status": "in_progress",
        }
        message_item: RealtimeMessageItem = TypeAdapter(RealtimeMessageItem).validate_python(
            message_payload
        )
        await self._emit_event(RealtimeModelItemUpdatedEvent(item=message_item))
        return

    if item.type != "function_call" or item.status != "completed":
        return

    # Missing fields fall back to empty strings.
    item_id = item.id or ""
    function_name = item.name or ""
    call_arguments = item.arguments or ""

    # The same item represents both the tool call and its output, so it stays
    # "in_progress" here and is completed when the output is added.
    tool_call = RealtimeToolCallItem(
        item_id=item_id,
        previous_item_id=None,
        type="function_call",
        status="in_progress",
        arguments=call_arguments,
        name=function_name,
        output=None,
    )
    await self._emit_event(RealtimeModelItemUpdatedEvent(item=tool_call))

    # NOTE(review): call_id is filled from the item id, not from a separate
    # call id on the item -- confirm this is what the consumer expects.
    tool_call_event = RealtimeModelToolCallEvent(
        call_id=item_id,
        name=function_name,
        arguments=call_arguments,
        id=item_id,
    )
    await self._emit_event(tool_call_event)

async def _handle_conversation_item(
    self, item: ConversationItem, previous_item_id: str | None
) -> None:
    """Emit an item-updated event for a conversation item.

    The item is validated into a ``RealtimeMessageItem`` with status
    ``in_progress`` and wrapped in a ``RealtimeModelItemUpdatedEvent``.

    Args:
        item: The conversation item to convert.
        previous_item_id: Id of the preceding item, or ``None``; passed
            through unchanged into the validated payload.
    """
    item_payload = {
        "item_id": item.id or "",
        "previous_item_id": previous_item_id,
        "type": item.type,
        "role": item.role,
        "content": item.content,
        "status": "in_progress",
    }
    message_item: RealtimeMessageItem = TypeAdapter(RealtimeMessageItem).validate_python(
        item_payload
    )
    await self._emit_event(RealtimeModelItemUpdatedEvent(item=message_item))

async def close(self) -> None:
"""Close the session."""
if self._websocket:
Expand All @@ -258,18 +331,7 @@ async def _handle_ws_event(self, event: dict[str, Any]):
return

if parsed.type == "response.audio.delta":
self._current_audio_content_index = parsed.content_index
self._current_item_id = parsed.item_id
if self._audio_start_time is None:
self._audio_start_time = datetime.now()
self._audio_length_ms = 0.0

audio_bytes = base64.b64decode(parsed.delta)
# Calculate audio length in ms using 24KHz pcm16le
self._audio_length_ms += len(audio_bytes) / 24 / 2
await self._emit_event(
RealtimeModelAudioEvent(data=audio_bytes, response_id=parsed.response_id)
)
await self._handle_audio_delta(parsed)
elif parsed.type == "response.audio.done":
await self._emit_event(RealtimeModelAudioDoneEvent())
elif parsed.type == "input_audio_buffer.speech_started":
Expand All @@ -291,21 +353,10 @@ async def _handle_ws_event(self, event: dict[str, Any]):
parsed.type == "conversation.item.created"
or parsed.type == "conversation.item.retrieved"
):
item = parsed.item
previous_item_id = (
parsed.previous_item_id if parsed.type == "conversation.item.created" else None
)
message_item: RealtimeMessageItem = TypeAdapter(RealtimeMessageItem).validate_python(
{
"item_id": item.id or "",
"previous_item_id": previous_item_id,
"type": item.type,
"role": item.role,
"content": item.content,
"status": "in_progress",
}
)
await self._emit_event(RealtimeModelItemUpdatedEvent(item=message_item))
await self._handle_conversation_item(parsed.item, previous_item_id)
elif (
parsed.type == "conversation.item.input_audio_transcription.completed"
or parsed.type == "conversation.item.truncated"
Expand Down Expand Up @@ -341,36 +392,4 @@ async def _handle_ws_event(self, event: dict[str, Any]):
parsed.type == "response.output_item.added"
or parsed.type == "response.output_item.done"
):
item = parsed.item
if item.type == "function_call" and item.status == "completed":
tool_call = RealtimeToolCallItem(
item_id=item.id or "",
previous_item_id=None,
type="function_call",
# We use the same item for tool call and output, so it will be completed by the
# output being added
status="in_progress",
arguments=item.arguments or "",
name=item.name or "",
output=None,
)
await self._emit_event(RealtimeModelItemUpdatedEvent(item=tool_call))
await self._emit_event(
RealtimeModelToolCallEvent(
call_id=item.id or "",
name=item.name or "",
arguments=item.arguments or "",
id=item.id or "",
)
)
elif item.type == "message":
message_item = TypeAdapter(RealtimeMessageItem).validate_python(
{
"item_id": item.id or "",
"type": item.type,
"role": item.role,
"content": item.content,
"status": "in_progress",
}
)
await self._emit_event(RealtimeModelItemUpdatedEvent(item=message_item))
await self._handle_output_item(parsed.item)
Loading