Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,15 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
self.invocation_exception_count += 1
raise

async def _invoke_function(self, call_kwargs: Mapping[str, Any]) -> Any:
"""Run sync tools off the event loop during async invocation."""
func = self.func.func if isinstance(self.func, FunctionTool) else self.func
if inspect.iscoroutinefunction(func):
return await self.__call__(**call_kwargs)

res = await asyncio.to_thread(self.__call__, **call_kwargs)
return await res if inspect.isawaitable(res) else res
Comment on lines +540 to +547

@overload
async def invoke(
self,
Expand Down Expand Up @@ -679,8 +688,7 @@ async def invoke(
if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined]
logger.info(f"Function name: {self.name}")
logger.debug(f"Function arguments: {observable_kwargs}")
res = self.__call__(**call_kwargs)
result = await res if inspect.isawaitable(res) else res
result = await self._invoke_function(call_kwargs)
if skip_parsing:
logger.info(f"Function {self.name} succeeded.")
logger.debug(f"Function result: {type(result).__name__}")
Expand Down Expand Up @@ -730,8 +738,7 @@ async def invoke(
start_time_stamp = perf_counter()
end_time_stamp: float | None = None
try:
res = self.__call__(**call_kwargs)
result = await res if inspect.isawaitable(res) else res
result = await self._invoke_function(call_kwargs)
end_time_stamp = perf_counter()
except Exception as exception:
end_time_stamp = perf_counter()
Expand Down
25 changes: 25 additions & 0 deletions python/packages/core/tests/core/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import threading
from typing import Annotated, Any, Literal, get_args, get_origin
from unittest.mock import Mock

Expand Down Expand Up @@ -1346,6 +1348,29 @@ async def slow(x: int) -> int:
assert raw == 42


async def test_invoke_sync_tool_does_not_block_event_loop() -> None:
release_tool = threading.Event()
tool_thread_ids: list[int] = []
event_loop_thread_id = threading.get_ident()

@tool
def wait_for_release() -> str:
tool_thread_ids.append(threading.get_ident())
return "released" if release_tool.wait(timeout=0.2) else "timed out"

async def release_soon() -> None:
await asyncio.sleep(0.01)
release_tool.set()

tool_task = asyncio.create_task(wait_for_release.invoke(skip_parsing=True))
release_task = asyncio.create_task(release_soon())

assert await asyncio.wait_for(tool_task, timeout=1) == "released"
await release_task
assert tool_thread_ids
assert tool_thread_ids[0] != event_loop_thread_id


async def test_invoke_skip_parsing_bypasses_configured_result_parser() -> None:
"""The tool's own result_parser is bypassed when skip_parsing=True is requested."""
parser_calls: list[Any] = []
Expand Down