Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/mcp/client/_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
from types import TracebackType
from typing import Any

import anyio

from mcp.client._transport import TransportStreams
from mcp.server import Server
from mcp.server.mcpserver import MCPServer
from mcp.shared._task_group import create_mcp_task_group
from mcp.shared.memory import create_client_server_memory_streams


Expand Down Expand Up @@ -48,7 +47,7 @@ async def _connect(self) -> AsyncIterator[TransportStreams]:
client_read, client_write = client_streams
server_read, server_write = server_streams

async with anyio.create_task_group() as tg:
async with create_mcp_task_group() as tg:
# Start server in background
tg.start_soon(
lambda: actual_server.run(
Expand Down
4 changes: 2 additions & 2 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from types import TracebackType
from typing import Any, TypeAlias

import anyio
import httpx
from pydantic import BaseModel, Field
from typing_extensions import Self
Expand All @@ -25,6 +24,7 @@
from mcp.client.stdio import StdioServerParameters
from mcp.client.streamable_http import streamable_http_client
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared._task_group import create_mcp_task_group
from mcp.shared.exceptions import MCPError
from mcp.shared.session import ProgressFnT

Expand Down Expand Up @@ -166,7 +166,7 @@ async def __aexit__(
await self._exit_stack.aclose()

# Concurrently close session stacks.
async with anyio.create_task_group() as tg:
async with create_mcp_task_group() as tg:
for exit_stack in self._session_exit_stacks.values():
tg.start_soon(exit_stack.aclose)

Expand Down
3 changes: 2 additions & 1 deletion src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from mcp import types
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared._task_group import create_mcp_task_group
from mcp.shared.message import SessionMessage

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -60,7 +61,7 @@ async def sse_client(
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

async with anyio.create_task_group() as tg:
async with create_mcp_task_group() as tg:
try:
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx_client_factory(
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/client/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
get_windows_executable_command,
terminate_windows_process_tree,
)
from mcp.shared._task_group import create_mcp_task_group
from mcp.shared.message import SessionMessage

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -177,7 +178,7 @@ async def stdin_writer():
except anyio.ClosedResourceError: # pragma: no cover
await anyio.lowlevel.checkpoint()

async with anyio.create_task_group() as tg, process:
async with create_mcp_task_group() as tg, process:
tg.start_soon(stdout_reader)
tg.start_soon(stdin_writer)
try:
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from mcp.client._transport import TransportStreams
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared._task_group import create_mcp_task_group
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
INTERNAL_ERROR,
Expand Down Expand Up @@ -546,7 +547,7 @@ async def streamable_http_client(

transport = StreamableHTTPTransport(url)

async with anyio.create_task_group() as tg:
async with create_mcp_task_group() as tg:
try:
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

Expand Down
3 changes: 2 additions & 1 deletion src/mcp/client/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from websockets.typing import Subprotocol

from mcp import types
from mcp.shared._task_group import create_mcp_task_group
from mcp.shared.message import SessionMessage


Expand Down Expand Up @@ -68,7 +69,7 @@ async def ws_writer():
msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_unset=True)
await ws.send(json.dumps(msg_dict))

async with anyio.create_task_group() as tg:
async with create_mcp_task_group() as tg:
# Start reader and writer tasks
tg.start_soon(ws_reader)
tg.start_soon(ws_writer)
Expand Down
5 changes: 2 additions & 3 deletions src/mcp/server/experimental/task_result_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
import logging
from typing import Any

import anyio

from mcp.server.session import ServerSession
from mcp.shared._task_group import create_mcp_task_group
from mcp.shared.exceptions import MCPError
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal
from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue
Expand Down Expand Up @@ -162,7 +161,7 @@ async def _wait_for_task_update(self, task_id: str) -> None:

Races between store update and queue message - first one wins.
"""
async with anyio.create_task_group() as tg:
async with create_mcp_task_group() as tg:

async def wait_for_store() -> None:
try:
Expand Down
6 changes: 3 additions & 3 deletions src/mcp/server/experimental/task_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from contextlib import asynccontextmanager
from dataclasses import dataclass, field

import anyio
from anyio.abc import TaskGroup

from mcp.server.experimental.task_result_handler import TaskResultHandler
from mcp.server.session import ServerSession
from mcp.shared._task_group import create_mcp_task_group
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue
from mcp.shared.experimental.tasks.store import TaskStore
Expand Down Expand Up @@ -79,8 +79,8 @@ async def run(self) -> AsyncIterator[None]:
# Task group is now available
...
"""
async with anyio.create_task_group() as tg:
self._task_group = tg
async with create_mcp_task_group() as tg:
self._task_group = tg # type: ignore[assignment]
try:
yield
finally:
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ async def main():
from mcp.server.streamable_http import EventStore
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared._task_group import create_mcp_task_group
from mcp.shared.exceptions import MCPError
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder
Expand Down Expand Up @@ -386,7 +387,7 @@ async def run(
task_support.configure_session(session)
await stack.enter_async_context(task_support.run())

async with anyio.create_task_group() as tg:
async with create_mcp_task_group() as tg:
async for message in session.incoming_messages:
logger.debug("Received message: %s", message)

Expand Down
3 changes: 2 additions & 1 deletion src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ async def handle_sse(request):
TransportSecurityMiddleware,
TransportSecuritySettings,
)
from mcp.shared._task_group import create_mcp_task_group
from mcp.shared.message import ServerMessageMetadata, SessionMessage

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -174,7 +175,7 @@ async def sse_writer():
}
)

async with anyio.create_task_group() as tg:
async with create_mcp_task_group() as tg:

async def response_wrapper(scope: Scope, receive: Receive, send: Send):
"""The EventSourceResponse returning signals a client close / disconnect.
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/server/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ async def run_server():
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from mcp import types
from mcp.shared._task_group import create_mcp_task_group
from mcp.shared.message import SessionMessage


Expand Down Expand Up @@ -77,7 +78,7 @@ async def stdout_writer():
except anyio.ClosedResourceError: # pragma: no cover
await anyio.lowlevel.checkpoint()

async with anyio.create_task_group() as tg:
async with create_mcp_task_group() as tg:
tg.start_soon(stdin_reader)
tg.start_soon(stdout_writer)
yield read_stream, write_stream
5 changes: 3 additions & 2 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from starlette.types import Receive, Scope, Send

from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings
from mcp.shared._task_group import create_mcp_task_group
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types import (
Expand Down Expand Up @@ -614,7 +615,7 @@ async def sse_writer(): # pragma: lax no cover
# Start the SSE response (this will send headers immediately)
try:
# First send the response to establish the SSE connection
async with anyio.create_task_group() as tg:
async with create_mcp_task_group() as tg:
tg.start_soon(response, scope, receive, send)
# Then send the message to be processed by the server
session_message = self._create_session_message(message, request, request_id, protocol_version)
Expand Down Expand Up @@ -970,7 +971,7 @@ async def connect(
self._write_stream = write_stream

# Start a task group for message routing
async with anyio.create_task_group() as tg:
async with create_mcp_task_group() as tg:
# Create a message router that distributes messages to request streams
async def message_router():
try:
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
StreamableHTTPServerTransport,
)
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared._task_group import create_mcp_task_group
from mcp.types import INVALID_REQUEST, ErrorData, JSONRPCError

if TYPE_CHECKING:
Expand Down Expand Up @@ -122,7 +123,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
)
self._has_started = True

async with anyio.create_task_group() as tg:
async with create_mcp_task_group() as tg:
# Store the task group for later use
self._task_group = tg
logger.info("StreamableHTTP session manager started")
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/server/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from starlette.websockets import WebSocket

from mcp import types
from mcp.shared._task_group import create_mcp_task_group
from mcp.shared.message import SessionMessage


Expand Down Expand Up @@ -52,7 +53,7 @@ async def ws_writer():
except anyio.ClosedResourceError:
await websocket.close()

async with anyio.create_task_group() as tg:
async with create_mcp_task_group() as tg:
tg.start_soon(ws_reader)
tg.start_soon(ws_writer)
yield (read_stream, write_stream)
93 changes: 93 additions & 0 deletions src/mcp/shared/_task_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Task group wrapper that collapses single-exception ExceptionGroups.

When an anyio task group contains tasks and one fails, the exception is
always wrapped in an ExceptionGroup — even if there is only one real
exception. This makes it impossible for callers to catch specific error
types with ``except SomeError:``.

This module provides a drop-in replacement for ``anyio.create_task_group()``
that automatically unwraps single-exception groups so callers receive the
original exception directly.
"""

from __future__ import annotations

import sys
from types import TracebackType

import anyio
from anyio.abc import TaskGroup

if sys.version_info < (3, 11): # pragma: lax no cover
from exceptiongroup import BaseExceptionGroup # pragma: lax no cover


def collapse_exception_group(exc: BaseExceptionGroup) -> BaseException: # type: ignore[type-arg]
"""Unwrap nested single-exception BaseExceptionGroups.

If the group (and any nested groups) each contain exactly one exception,
return the innermost real exception. Otherwise return *exc* unchanged.
"""
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: # type: ignore[reportUnnecessaryIsInstance]
exc = exc.exceptions[0] # type: ignore[assignment]
return exc


class _CollapsingTaskGroup:
"""A thin wrapper around an anyio ``TaskGroup`` that collapses exceptions.

On ``__aexit__``, if the task group raises a ``BaseExceptionGroup`` that
contains only a single exception, that inner exception is re-raised
directly so callers can ``except`` it by its concrete type.

The wrapper delegates ``start_soon``, ``start``, and ``cancel_scope`` to
the underlying task group.
"""

def __init__(self) -> None:
self._task_group: TaskGroup | None = None

def _tg(self) -> TaskGroup:
if self._task_group is None:
raise RuntimeError("Task group has not been entered")
return self._task_group

@property
def cancel_scope(self) -> anyio.CancelScope:
return self._tg().cancel_scope

def start_soon(self, *args: object, **kwargs: object) -> None:
self._tg().start_soon(*args, **kwargs) # type: ignore[arg-type]

async def start(self, *args: object, **kwargs: object) -> object:
return await self._tg().start(*args, **kwargs) # type: ignore[arg-type]

async def __aenter__(self) -> _CollapsingTaskGroup:
self._task_group = anyio.create_task_group()
await self._task_group.__aenter__()
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
try:
return await self._tg().__aexit__(exc_type, exc_val, exc_tb)
except BaseExceptionGroup as eg:
collapsed = collapse_exception_group(eg)
if collapsed is not eg:
raise collapsed from eg
raise


def create_mcp_task_group() -> _CollapsingTaskGroup:
"""Create an anyio task group that collapses single-exception groups.

Use this as a drop-in replacement for ``anyio.create_task_group()``::

async with create_mcp_task_group() as tg:
tg.start_soon(some_task)
"""
return _CollapsingTaskGroup()
3 changes: 2 additions & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pydantic import BaseModel, TypeAdapter
from typing_extensions import Self

from mcp.shared._task_group import create_mcp_task_group
from mcp.shared.exceptions import MCPError
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.shared.response_router import ResponseRouter
Expand Down Expand Up @@ -212,7 +213,7 @@ def add_response_router(self, router: ResponseRouter) -> None:
self._response_routers.append(router)

async def __aenter__(self) -> Self:
self._task_group = anyio.create_task_group()
self._task_group = create_mcp_task_group()
await self._task_group.__aenter__()
self._task_group.start_soon(self._receive_loop)
return self
Expand Down
Loading
Loading