Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/agents/voice/models/openai_stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,18 +321,33 @@ def _check_errors(self) -> None:
if exc and isinstance(exc, Exception):
self._stored_exception = exc

def _cleanup_tasks(self) -> None:
async def _cleanup_tasks(self) -> None:
"""Cancel all pending tasks and wait for them to complete.

This ensures that any exceptions raised by the tasks are properly handled
and prevents warnings about unhandled task exceptions.
"""
tasks = []

if self._listener_task and not self._listener_task.done():
self._listener_task.cancel()
tasks.append(self._listener_task)

if self._process_events_task and not self._process_events_task.done():
self._process_events_task.cancel()
tasks.append(self._process_events_task)

if self._stream_audio_task and not self._stream_audio_task.done():
self._stream_audio_task.cancel()
tasks.append(self._stream_audio_task)

if self._connection_task and not self._connection_task.done():
self._connection_task.cancel()
tasks.append(self._connection_task)

# Wait for all cancelled tasks to complete and collect exceptions
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)

async def transcribe_turns(self) -> AsyncIterator[str]:
self._connection_task = asyncio.create_task(self._process_websocket_connection())
Expand Down Expand Up @@ -367,7 +382,7 @@ async def close(self) -> None:
if self._websocket:
await self._websocket.close()

self._cleanup_tasks()
await self._cleanup_tasks()


class OpenAISTTModel(STTModel):
Expand Down
137 changes: 137 additions & 0 deletions tests/voice/test_openai_stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,140 @@ async def test_inactivity_timeout():
assert len(collected_turns) == 0, "No transcripts expected, but we got something?"

await session.close()


@pytest.mark.asyncio
async def test_cleanup_tasks_cancels_and_awaits_all_tasks():
"""
Test that _cleanup_tasks() properly cancels and awaits all pending tasks.
This ensures proper resource cleanup and prevents unhandled task exceptions.
"""
mock_ws = create_mock_websocket(
[
json.dumps({"type": "transcription_session.created"}),
json.dumps({"type": "transcription_session.updated"}),
]
)

with patch("websockets.connect", return_value=mock_ws):
audio_input = await FakeStreamedAudioInput.get(count=2)
stt_settings = STTModelSettings()

session = OpenAISTTTranscriptionSession(
input=audio_input,
client=AsyncMock(api_key="FAKE_KEY"),
model="whisper-1",
settings=stt_settings,
trace_include_sensitive_data=False,
trace_include_sensitive_audio_data=False,
)

# Create some tasks to simulate active background tasks
async def long_running_task():
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
# Expected when cancelled
raise

session._listener_task = asyncio.create_task(long_running_task())
session._process_events_task = asyncio.create_task(long_running_task())
session._stream_audio_task = asyncio.create_task(long_running_task())
session._connection_task = asyncio.create_task(long_running_task())

# Verify tasks are running
assert not session._listener_task.done()
assert not session._process_events_task.done()
assert not session._stream_audio_task.done()
assert not session._connection_task.done()

# Call cleanup_tasks
await session._cleanup_tasks()

# Verify all tasks were cancelled and completed
assert session._listener_task.cancelled()
assert session._process_events_task.cancelled()
assert session._stream_audio_task.cancelled()
assert session._connection_task.cancelled()


@pytest.mark.asyncio
async def test_cleanup_tasks_handles_exceptions():
"""
Test that _cleanup_tasks() properly handles exceptions from cancelled tasks
without raising them (using return_exceptions=True).
"""
mock_ws = create_mock_websocket(
[
json.dumps({"type": "transcription_session.created"}),
json.dumps({"type": "transcription_session.updated"}),
]
)

with patch("websockets.connect", return_value=mock_ws):
audio_input = await FakeStreamedAudioInput.get(count=2)
stt_settings = STTModelSettings()

session = OpenAISTTTranscriptionSession(
input=audio_input,
client=AsyncMock(api_key="FAKE_KEY"),
model="whisper-1",
settings=stt_settings,
trace_include_sensitive_data=False,
trace_include_sensitive_audio_data=False,
)

# Create tasks that raise exceptions when cancelled
async def task_with_exception():
try:
await asyncio.sleep(10)
except asyncio.CancelledError as e:
raise RuntimeError("Task exception during cancellation") from e

session._listener_task = asyncio.create_task(task_with_exception())
session._process_events_task = asyncio.create_task(task_with_exception())

# cleanup_tasks should not raise despite the exceptions
await session._cleanup_tasks()

# Tasks should be done (cancelled or exception raised)
assert session._listener_task.done()
assert session._process_events_task.done()


@pytest.mark.asyncio
async def test_close_calls_cleanup_tasks():
"""
Test that close() properly calls _cleanup_tasks() to clean up background tasks.
"""
mock_ws = create_mock_websocket(
[
json.dumps({"type": "transcription_session.created"}),
json.dumps({"type": "transcription_session.updated"}),
]
)

with patch("websockets.connect", return_value=mock_ws):
audio_input = await FakeStreamedAudioInput.get(count=2)
stt_settings = STTModelSettings()

session = OpenAISTTTranscriptionSession(
input=audio_input,
client=AsyncMock(api_key="FAKE_KEY"),
model="whisper-1",
settings=stt_settings,
trace_include_sensitive_data=False,
trace_include_sensitive_audio_data=False,
)

# Create a task
async def long_running_task():
await asyncio.sleep(10)

session._listener_task = asyncio.create_task(long_running_task())

# close() should cancel and await the task
await session.close()

# Task should be cancelled
assert session._listener_task.cancelled()