From 37330b770323a4ab2b83d48dae5f3db0414e69a4 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Tue, 12 May 2026 16:38:49 +0800 Subject: [PATCH] fix: run sync tools off event loop --- .../packages/core/agent_framework/_tools.py | 15 ++++++++--- python/packages/core/tests/core/test_tools.py | 25 +++++++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 93722a8987..d18c13f7e3 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -537,6 +537,15 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: self.invocation_exception_count += 1 raise + async def _invoke_function(self, call_kwargs: Mapping[str, Any]) -> Any: + """Run sync tools off the event loop during async invocation.""" + func = self.func.func if isinstance(self.func, FunctionTool) else self.func + if inspect.iscoroutinefunction(func): + return await self.__call__(**call_kwargs) + + res = await asyncio.to_thread(self.__call__, **call_kwargs) + return await res if inspect.isawaitable(res) else res + @overload async def invoke( self, @@ -679,8 +688,7 @@ async def invoke( if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined] logger.info(f"Function name: {self.name}") logger.debug(f"Function arguments: {observable_kwargs}") - res = self.__call__(**call_kwargs) - result = await res if inspect.isawaitable(res) else res + result = await self._invoke_function(call_kwargs) if skip_parsing: logger.info(f"Function {self.name} succeeded.") logger.debug(f"Function result: {type(result).__name__}") @@ -730,8 +738,7 @@ async def invoke( start_time_stamp = perf_counter() end_time_stamp: float | None = None try: - res = self.__call__(**call_kwargs) - result = await res if inspect.isawaitable(res) else res + result = await self._invoke_function(call_kwargs) end_time_stamp = perf_counter() except Exception as exception: end_time_stamp = perf_counter() diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index b3762bf4ef..ee15e8d16f 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -1,4 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +import asyncio +import threading from typing import Annotated, Any, Literal, get_args, get_origin from unittest.mock import Mock @@ -1346,6 +1348,29 @@ async def slow(x: int) -> int: assert raw == 42 +async def test_invoke_sync_tool_does_not_block_event_loop() -> None: + release_tool = threading.Event() + tool_thread_ids: list[int] = [] + event_loop_thread_id = threading.get_ident() + + @tool + def wait_for_release() -> str: + tool_thread_ids.append(threading.get_ident()) + return "released" if release_tool.wait(timeout=0.2) else "timed out" + + async def release_soon() -> None: + await asyncio.sleep(0.01) + release_tool.set() + + tool_task = asyncio.create_task(wait_for_release.invoke(skip_parsing=True)) + release_task = asyncio.create_task(release_soon()) + + assert await asyncio.wait_for(tool_task, timeout=1) == "released" + await release_task + assert tool_thread_ids + assert tool_thread_ids[0] != event_loop_thread_id + + async def test_invoke_skip_parsing_bypasses_configured_result_parser() -> None: """The tool's own result_parser is bypassed when skip_parsing=True is requested.""" parser_calls: list[Any] = []