Skip to content
Merged
2 changes: 1 addition & 1 deletion python/packages/azure-ai/tests/test_azure_ai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def mock_project_client() -> MagicMock:
mock_client.telemetry.get_application_insights_connection_string = AsyncMock()

# Mock get_openai_client method
mock_client.get_openai_client = AsyncMock()
mock_client.get_openai_client = MagicMock()

# Mock close method
mock_client.close = AsyncMock()
Expand Down
2 changes: 1 addition & 1 deletion python/packages/azure-ai/tests/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def mock_project_client() -> MagicMock:
mock_client.telemetry.get_application_insights_connection_string = AsyncMock()

# Mock get_openai_client method
mock_client.get_openai_client = AsyncMock()
mock_client.get_openai_client = MagicMock()

# Mock close method
mock_client.close = AsyncMock()
Expand Down
94 changes: 87 additions & 7 deletions python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
import base64
import contextvars
import json
import logging
import re
Expand Down Expand Up @@ -38,6 +39,7 @@
from mcp.shared.session import RequestResponder

from ._clients import SupportsChatGetResponse
from ._middleware import FunctionInvocationContext


logger = logging.getLogger(__name__)
Expand All @@ -59,6 +61,9 @@ class MCPSpecificApproval(TypedDict, total=False):

_MCP_REMOTE_NAME_KEY = "_mcp_remote_name"
_MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name"
_mcp_call_headers: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar("_mcp_call_headers")
MCP_DEFAULT_TIMEOUT = 30
MCP_DEFAULT_SSE_READ_TIMEOUT = 60 * 5

# region: Helpers

Expand Down Expand Up @@ -137,6 +142,22 @@ def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str,
return meta


def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]:
"""Lazily import the MCP streamable HTTP transport."""
try:
from mcp.client.streamable_http import streamable_http_client as _streamable_http_client
except ModuleNotFoundError as ex:
missing_name = ex.name or str(ex)
if missing_name == "mcp" or missing_name.startswith("mcp.") or "mcp" in missing_name:
raise ModuleNotFoundError("`MCPStreamableHTTPTool` requires `mcp`. Please install `mcp`.") from ex
raise ModuleNotFoundError(
f"`MCPStreamableHTTPTool` requires streamable HTTP transport support. "
f"The optional dependency `{missing_name}` is not installed. Please update your dependencies."
) from ex

return _streamable_http_client(*args, **kwargs) # type: ignore[return-value]


# region: MCP Plugin


Expand Down Expand Up @@ -951,9 +972,20 @@ async def load_tools(self) -> None:
input_schema = dict(tool.inputSchema or {})
if input_schema.get("type") == "object" and "properties" not in input_schema:
input_schema["properties"] = {}

async def _call_tool_with_runtime_kwargs(
ctx: FunctionInvocationContext,
*,
_remote_tool_name: str = tool.name,
**kwargs: Any,
) -> str | list[Content]:
call_kwargs = dict(ctx.kwargs)
call_kwargs.update(kwargs)
return await self.call_tool(_remote_tool_name, **call_kwargs)

# Create FunctionTools out of each tool
func: FunctionTool = FunctionTool(
func=partial(self.call_tool, tool.name),
func=_call_tool_with_runtime_kwargs,
name=local_name,
description=tool.description or "",
approval_mode=approval_mode,
Expand Down Expand Up @@ -1386,6 +1418,7 @@ def __init__(
client: SupportsChatGetResponse | None = None,
additional_properties: dict[str, Any] | None = None,
http_client: AsyncClient | None = None,
header_provider: Callable[[dict[str, Any]], dict[str, str]] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP streamable HTTP tool.
Expand Down Expand Up @@ -1433,6 +1466,11 @@ def __init__(
``streamable_http_client`` API will create and manage a default client.
To configure headers, timeouts, or other HTTP client settings, create
and pass your own ``asyncClient`` instance.
header_provider: Optional callable that receives the runtime keyword arguments
(from ``FunctionInvocationContext.kwargs``) and returns a ``dict[str, str]``
of HTTP headers to inject into every outbound request to the MCP server.
Use this to forward per-request context (e.g. authentication tokens set in
agent middleware) without creating a separate ``httpx.AsyncClient``.
kwargs: Additional keyword arguments (accepted for backward compatibility but not used).
"""
super().__init__(
Expand All @@ -1453,25 +1491,67 @@ def __init__(
self.url = url
self.terminate_on_close = terminate_on_close
self._httpx_client: AsyncClient | None = http_client
self._header_provider = header_provider

def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
"""Get an MCP streamable HTTP client.

Returns:
An async context manager for the streamable HTTP client transport.
"""
try:
from mcp.client.streamable_http import streamable_http_client
except ModuleNotFoundError as ex:
raise ModuleNotFoundError("`mcp` is required to use `MCPStreamableHTTPTool`. Please install `mcp`.") from ex
from httpx import AsyncClient, Request, Timeout

http_client = self._httpx_client
if self._header_provider is not None:
if http_client is None:
http_client = AsyncClient(
follow_redirects=True,
timeout=Timeout(MCP_DEFAULT_TIMEOUT, read=MCP_DEFAULT_SSE_READ_TIMEOUT),
)
Comment thread
eavanvalkenburg marked this conversation as resolved.
self._httpx_client = http_client

if not hasattr(self, "_inject_headers_hook"):

async def _inject_headers(request: Request) -> None: # noqa: RUF029
headers = _mcp_call_headers.get({})
for key, value in headers.items():
request.headers[key] = value

self._inject_headers_hook = _inject_headers # type: ignore[attr-defined]
http_client.event_hooks["request"].append(self._inject_headers_hook) # type: ignore[attr-defined]

# Pass the http_client (which may be None) to streamable_http_client
return streamable_http_client(
url=self.url,
http_client=self._httpx_client,
http_client=http_client,
terminate_on_close=self.terminate_on_close if self.terminate_on_close is not None else True,
)

async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
"""Call a tool, injecting headers from the header_provider if configured.

When a ``header_provider`` was supplied at construction time, the runtime
*kwargs* (originating from ``FunctionInvocationContext.kwargs``) are passed
to the provider. The returned headers are attached to every HTTP request
made during this tool call via a ``contextvars.ContextVar``.

Args:
tool_name: The name of the tool to call.

Keyword Args:
kwargs: Arguments to pass to the tool.

Returns:
A list of Content items representing the tool output.
"""
if self._header_provider is not None:
headers = self._header_provider(kwargs)
token = _mcp_call_headers.set(headers)
try:
return await super().call_tool(tool_name, **kwargs)
finally:
_mcp_call_headers.reset(token)
return await super().call_tool(tool_name, **kwargs)


class MCPWebsocketTool(MCPTool):
"""MCP tool for connecting to WebSocket-based MCP servers.
Expand Down
Loading
Loading