Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions src/agents/voice/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,24 +243,54 @@ async def _wait_for_completion(self):
tasks.append(self._dispatcher_task)
await asyncio.gather(*tasks)

def _cleanup_tasks(self):
async def _cleanup_tasks(self):
"""Cancel all pending tasks and wait for them to complete.

This ensures that any exceptions raised by the tasks are properly handled
and prevents warnings about unhandled task exceptions.
"""
self._finish_turn()

tasks = []
for task in self._tasks:
if not task.done():
task.cancel()
if isinstance(task, asyncio.Task):
tasks.append(task)

if self._dispatcher_task and not self._dispatcher_task.done():
self._dispatcher_task.cancel()
if isinstance(self._dispatcher_task, asyncio.Task):
tasks.append(self._dispatcher_task)

if self.text_generation_task and not self.text_generation_task.done():
self.text_generation_task.cancel()
if isinstance(self.text_generation_task, asyncio.Task):
tasks.append(self.text_generation_task)

# Wait for all cancelled tasks to complete and collect exceptions
if tasks:
results = await asyncio.gather(*tasks, return_exceptions=True)
# Check if any task failed with a real exception (not CancelledError)
# This catches exceptions that occurred after _check_errors() but before cancellation
if not self._stored_exception:
for result in results:
is_exception = isinstance(result, Exception)
is_not_cancelled = not isinstance(result, asyncio.CancelledError)
if is_exception and is_not_cancelled:
self._stored_exception = result
break

def _check_errors(self):
"""Check for exceptions in completed tasks.

Note: task.cancelled() check ensures CancelledError is never raised.
"""
for task in self._tasks:
if task.done():
if task.exception():
self._stored_exception = task.exception()
if task.done() and not task.cancelled():
exc = task.exception()
if exc:
self._stored_exception = exc
break

async def stream(self) -> AsyncIterator[VoiceStreamEvent]:
Expand All @@ -281,7 +311,7 @@ async def stream(self) -> AsyncIterator[VoiceStreamEvent]:
break

self._check_errors()
self._cleanup_tasks()
await self._cleanup_tasks()

if self._stored_exception:
raise self._stored_exception