diff --git a/src/agents/voice/result.py b/src/agents/voice/result.py index fea79902e..3161da0fb 100644 --- a/src/agents/voice/result.py +++ b/src/agents/voice/result.py @@ -243,24 +243,54 @@ async def _wait_for_completion(self): tasks.append(self._dispatcher_task) await asyncio.gather(*tasks) - def _cleanup_tasks(self): + async def _cleanup_tasks(self): + """Cancel all pending tasks and wait for them to complete. + + This ensures that any exceptions raised by the tasks are properly handled + and prevents warnings about unhandled task exceptions. + """ self._finish_turn() + tasks = [] for task in self._tasks: if not task.done(): task.cancel() + if isinstance(task, asyncio.Task): + tasks.append(task) if self._dispatcher_task and not self._dispatcher_task.done(): self._dispatcher_task.cancel() + if isinstance(self._dispatcher_task, asyncio.Task): + tasks.append(self._dispatcher_task) if self.text_generation_task and not self.text_generation_task.done(): self.text_generation_task.cancel() + if isinstance(self.text_generation_task, asyncio.Task): + tasks.append(self.text_generation_task) + + # Wait for all cancelled tasks to complete and collect exceptions + if tasks: + results = await asyncio.gather(*tasks, return_exceptions=True) + # Check if any task failed with a real exception (not CancelledError) + # This catches exceptions that occurred after _check_errors() but before cancellation + if not self._stored_exception: + for result in results: + is_exception = isinstance(result, Exception) + is_not_cancelled = not isinstance(result, asyncio.CancelledError) + if is_exception and is_not_cancelled: + self._stored_exception = result + break def _check_errors(self): + """Check for exceptions in completed tasks. + + Note: task.cancelled() check ensures CancelledError is never raised. + """ for task in self._tasks: - if task.done(): - if task.exception(): - self._stored_exception = task.exception() + if task.done() and not task.cancelled(): + exc = task.exception() + if exc: + self._stored_exception = exc break async def stream(self) -> AsyncIterator[VoiceStreamEvent]: @@ -281,7 +311,7 @@ async def stream(self) -> AsyncIterator[VoiceStreamEvent]: break self._check_errors() - self._cleanup_tasks() + await self._cleanup_tasks() if self._stored_exception: raise self._stored_exception