Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/agents/realtime/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@ class RealtimeModelConfig(TypedDict):
the OpenAI Realtime model will use the default OpenAI WebSocket URL.
"""

headers: NotRequired[dict[str, str]]
"""The headers to use when connecting. If unset, the model will use a sane default.
Note that, when you set this, authorization header won't be set under the hood.
e.g., {"api-key": "your api key here"} for Azure OpenAI Realtime WebSocket connections.
"""

initial_model_settings: NotRequired[RealtimeSessionModelSettings]
"""The initial model settings to use when connecting."""

Expand Down
38 changes: 23 additions & 15 deletions src/agents/realtime/openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,23 @@ async def connect(self, options: RealtimeModelConfig) -> None:
else:
self._tracing_config = "auto"

if not api_key:
raise UserError("API key is required but was not provided.")

url = options.get("url", f"wss://api.openai.com/v1/realtime?model={self.model}")

headers = {
"Authorization": f"Bearer {api_key}",
"OpenAI-Beta": "realtime=v1",
}
headers: dict[str, str] = {}
if options.get("headers") is not None:
# For customizing request headers
headers.update(options["headers"])
else:
# OpenAI's Realtime API
if not api_key:
raise UserError("API key is required but was not provided.")

headers.update(
{
"Authorization": f"Bearer {api_key}",
"OpenAI-Beta": "realtime=v1",
}
)
self._websocket = await websockets.connect(
url,
user_agent_header=_USER_AGENT,
Expand Down Expand Up @@ -490,9 +498,7 @@ async def _handle_ws_event(self, event: dict[str, Any]):
try:
if "previous_item_id" in event and event["previous_item_id"] is None:
event["previous_item_id"] = "" # TODO (rm) remove
parsed: AllRealtimeServerEvents = self._server_event_type_adapter.validate_python(
event
)
parsed: AllRealtimeServerEvents = self._server_event_type_adapter.validate_python(event)
except pydantic.ValidationError as e:
logger.error(f"Failed to validate server event: {event}", exc_info=True)
await self._emit_event(
Expand Down Expand Up @@ -583,11 +589,13 @@ async def _handle_ws_event(self, event: dict[str, Any]):
):
await self._handle_output_item(parsed.item)
elif parsed.type == "input_audio_buffer.timeout_triggered":
await self._emit_event(RealtimeModelInputAudioTimeoutTriggeredEvent(
item_id=parsed.item_id,
audio_start_ms=parsed.audio_start_ms,
audio_end_ms=parsed.audio_end_ms,
))
await self._emit_event(
RealtimeModelInputAudioTimeoutTriggeredEvent(
item_id=parsed.item_id,
audio_start_ms=parsed.audio_start_ms,
audio_end_ms=parsed.audio_end_ms,
)
)

def _update_created_session(self, session: OpenAISessionObject) -> None:
self._created_session = session
Expand Down
41 changes: 39 additions & 2 deletions tests/realtime/test_openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,45 @@ def mock_create_task_func(coro):

# Verify internal state
assert model._websocket == mock_websocket
assert model._websocket_task is not None
assert model.model == "gpt-4o-realtime-preview"
assert model._websocket_task is not None
assert model.model == "gpt-4o-realtime-preview"

@pytest.mark.asyncio
async def test_connect_with_custom_headers_overrides_defaults(self, model, mock_websocket):
"""If custom headers are provided, use them verbatim without adding defaults."""
# Even when custom headers are provided, the implementation still requires api_key.
config = {
"api_key": "unused-because-headers-override",
"headers": {"api-key": "azure-key", "x-custom": "1"},
"url": "wss://custom.example.com/realtime?model=custom",
# Use a valid realtime model name for session.update to validate.
"initial_model_settings": {"model_name": "gpt-4o-realtime-preview"},
}

async def async_websocket(*args, **kwargs):
return mock_websocket

with patch("websockets.connect", side_effect=async_websocket) as mock_connect:
with patch("asyncio.create_task") as mock_create_task:
mock_task = AsyncMock()

def mock_create_task_func(coro):
coro.close()
return mock_task

mock_create_task.side_effect = mock_create_task_func

await model.connect(config)

# Verify WebSocket connection used the provided URL
called_url = mock_connect.call_args[0][0]
assert called_url == "wss://custom.example.com/realtime?model=custom"

# Verify headers are exactly as provided and no defaults were injected
headers = mock_connect.call_args.kwargs["additional_headers"]
assert headers == {"api-key": "azure-key", "x-custom": "1"}
assert "Authorization" not in headers
assert "OpenAI-Beta" not in headers

@pytest.mark.asyncio
async def test_connect_with_callable_api_key(self, model, mock_websocket):
Expand Down