Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"pyjwt[crypto]>=2.10.1",
"typing-extensions>=4.13.0",
"typing-inspection>=0.4.1",
"opentelemetry-api>=1.23.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -71,6 +72,7 @@ dev = [
"coverage[toml]>=7.10.7,<=7.13",
"pillow>=12.0",
"strict-no-cover",
"opentelemetry-sdk>=1.23.0",
]
docs = [
"mkdocs>=1.6.1",
Expand Down
10 changes: 9 additions & 1 deletion src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for session_message in write_stream_reader:

async def handle_message(session_message: SessionMessage) -> None:
logger.debug(f"Sending client message: {session_message}")
response = await client.post(
endpoint_url,
Expand All @@ -144,6 +145,13 @@ async def post_writer(endpoint_url: str):
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")

async for session_message in write_stream_reader:
async with anyio.create_task_group() as tg_local:
session_message.context.run(
tg_local.start_soon, handle_message, session_message
)

except Exception: # pragma: lax no cover
logger.exception("Error in post_writer")
finally:
Expand Down
11 changes: 8 additions & 3 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,8 @@ async def post_writer(
"""Handle writing requests to the server."""
try:
async with write_stream_reader:
async for session_message in write_stream_reader:

async def handle_message(session_message: SessionMessage) -> None:
message = session_message.message
metadata = (
session_message.metadata
Expand Down Expand Up @@ -478,8 +479,12 @@ async def handle_request_async():
else:
await handle_request_async()

except Exception: # pragma: lax no cover
logger.exception("Error in post_writer")
async for session_message in write_stream_reader:
async with anyio.create_task_group() as tg_local:
session_message.context.run(tg_local.start_soon, handle_message, session_message)

except Exception:
logger.exception("Error in post_writer") # pragma: no cover
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
Expand Down
8 changes: 7 additions & 1 deletion src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,13 @@ async def run(
async for message in session.incoming_messages:
logger.debug("Received message: %s", message)

tg.start_soon(
if isinstance(message, RequestResponder) and message.context is not None:
context = message.context
else:
context = contextvars.copy_context()

context.run(
tg.start_soon,
self._handle_message,
message,
session,
Expand Down
4 changes: 3 additions & 1 deletion src/mcp/shared/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
to support transport-specific features like resumability.
"""

import contextvars
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any

from mcp.types import JSONRPCMessage, RequestId
Expand Down Expand Up @@ -49,4 +50,5 @@ class SessionMessage:
"""A message with specific metadata for transport-specific features."""

message: JSONRPCMessage
context: contextvars.Context = field(default_factory=contextvars.copy_context)
metadata: MessageMetadata = None
66 changes: 60 additions & 6 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextvars
import logging
from collections.abc import Callable
from contextlib import AsyncExitStack
Expand All @@ -8,6 +9,8 @@

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from opentelemetry import context as otel_context
from opentelemetry.propagate import extract, inject
from pydantic import BaseModel, TypeAdapter
from typing_extensions import Self

Expand Down Expand Up @@ -79,11 +82,13 @@ def __init__(
session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any],
message_metadata: MessageMetadata = None,
context: contextvars.Context | None = None,
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self.message_metadata = message_metadata
self.context = context
self._session = session
self._completed = False
self._cancel_scope = anyio.CancelScope()
Expand Down Expand Up @@ -251,6 +256,9 @@ async def send_request(
response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1)
self._response_streams[request_id] = response_stream

# Propagate opentelemetry trace context
self._inject_otel_context(request)

# Set up progress token if progress callback is provided
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
if progress_callback is not None:
Expand Down Expand Up @@ -295,6 +303,10 @@ async def send_notification(
related_request_id: RequestId | None = None,
) -> None:
"""Emits a notification, which is a one-way message that does not expect a response."""

# Propagate opentelemetry trace context
self._inject_otel_context(notification)

# Some transport implementations may need to set the related_request_id
# to attribute to the notifications to the request that triggered them.
jsonrpc_notification = JSONRPCNotification(
Expand All @@ -307,6 +319,28 @@ async def send_notification(
)
await self._write_stream.send(session_message)

def _inject_otel_context(self, request: SendRequestT | SendNotificationT) -> None:
    """Embed the active OpenTelemetry trace context into the message's `_meta`.

    See
    - SEP414 https://github.com/modelcontextprotocol/modelcontextprotocol/pull/414
    - OpenTelemetry semantic conventions
      https://github.com/open-telemetry/semantic-conventions/blob/v1.39.0/docs/gen-ai/mcp.md

    Messages without params are left untouched, as are messages sent while no
    trace context is active (the injected carrier comes back empty).
    """
    params = request.params
    if params is None:
        # No params object to hang `_meta` off of — nothing to propagate.
        return

    carrier: RequestParamsMeta = {}
    inject(carrier)  # fills the carrier from the current OTel context, if any
    if carrier:
        if params.meta is None:
            params.meta = {}
        params.meta.update(carrier)

async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
Expand All @@ -333,10 +367,9 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
async def _receive_loop(self) -> None:
async with self._read_stream, self._write_stream:
try:
async for message in self._read_stream:
if isinstance(message, Exception): # pragma: no cover
await self._handle_incoming(message)
elif isinstance(message.message, JSONRPCRequest):

async def handle_message(message: SessionMessage) -> None:
if isinstance(message.message, JSONRPCRequest):
try:
validated_request = self._receive_request_adapter.validate_python(
message.message.model_dump(by_alias=True, mode="json", exclude_none=True),
Expand All @@ -349,6 +382,7 @@ async def _receive_loop(self) -> None:
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
context=message.context,
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)
Expand Down Expand Up @@ -397,15 +431,35 @@ async def _receive_loop(self) -> None:
logging.exception("Progress callback raised an exception")
await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception:
except Exception: # pragma: lax no cover
# For other validation errors, log and continue
logging.warning( # pragma: no cover
logging.warning(
f"Failed to validate notification:. Message was: {message.message}",
exc_info=True,
)
else: # Response or error
await self._handle_response(message)

async def _handle_message_with_otel(message: SessionMessage) -> None:
    """Run `handle_message` with any OpenTelemetry context carried in `_meta` attached.

    Requests and notifications may carry a trace-context carrier under
    ``params["_meta"]`` (SEP414). When one is present it is extracted and
    attached for the duration of the handler, and detached again in a
    ``finally`` so the ambient context is always restored.
    """
    jsonrpc_msg = message.message
    carrier = None
    if isinstance(jsonrpc_msg, (JSONRPCRequest, JSONRPCNotification)):
        if jsonrpc_msg.params:
            carrier = jsonrpc_msg.params.get("_meta")

    token = None
    if carrier:
        remote_ctx = extract(carrier)
        if remote_ctx:
            token = otel_context.attach(remote_ctx)
    try:
        await handle_message(message)
    finally:
        if token:
            otel_context.detach(token)

async for message in self._read_stream:
if isinstance(message, Exception): # pragma: no cover
await self._handle_incoming(message)
else:
async with anyio.create_task_group() as tg:
message.context.run(tg.start_soon, _handle_message_with_otel, message)

except anyio.ClosedResourceError:
# This is expected when the client disconnects abruptly.
# Without this handler, the exception would propagate up and
Expand Down
Loading