diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 57df64705..6d8a0658a 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -288,9 +288,11 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: elif content_type.startswith(SSE): await self._handle_sse_response(response, ctx, is_initialization) else: + # Propagate an error bound to the originating request id so callers get McpError await self._handle_unexpected_content_type( content_type, ctx.read_stream_writer, + message.root.id, ) async def _handle_json_response( @@ -343,11 +345,22 @@ async def _handle_unexpected_content_type( self, content_type: str, read_stream_writer: StreamWriter, + request_id: RequestId | None, ) -> None: """Handle unexpected content type in response.""" error_msg = f"Unexpected content type: {content_type}" logger.error(error_msg) - await read_stream_writer.send(ValueError(error_msg)) + if request_id is not None: + jsonrpc_error = JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=32600, message=error_msg), + ) + session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) + await read_stream_writer.send(session_message) + else: + # Fallback: send as exception if we somehow lack a request id + await read_stream_writer.send(ValueError(error_msg)) async def _send_session_terminated_error( self, diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 55800da33..4bea609dd 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -11,6 +11,11 @@ from collections.abc import Generator from typing import Any +try: + from builtins import ExceptionGroup # type: ignore +except ImportError: + from exceptiongroup import ExceptionGroup # type: ignore + import anyio import httpx import pytest @@ -1597,3 +1602,81 @@ async def bad_client(): assert isinstance(result, InitializeResult) tools = await session.list_tools() assert tools.tools + + +@pytest.mark.anyio +async def test_client_unexpected_content_type_raises_mcp_error(): + """Test that unexpected content types raise McpError instead of just printing. + + This test verifies that when a server returns HTML instead of MCP JSON, + the client properly raises McpError wrapped in ExceptionGroup. + """ + # Use a real server that returns HTML to test the actual behavior + from starlette.responses import HTMLResponse + from starlette.routing import Route + + # Create a simple server that returns HTML instead of MCP JSON + async def html_endpoint(request: Request): + return HTMLResponse("Not an MCP server") + + app = Starlette( + routes=[ + Route("/mcp", html_endpoint, methods=["GET", "POST"]), + ] + ) + + # Start server on a random port using a simpler approach + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + + # Use a thread instead of multiprocessing to avoid pickle issues + import asyncio + import threading + + def run_server(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + uvicorn.run(app, host="127.0.0.1", port=port, log_level="error") + + server_thread = threading.Thread(target=run_server, daemon=True) + server_thread.start() + + try: + # Give server time to start + await asyncio.sleep(0.5) + + server_url = f"http://127.0.0.1:{port}" + + # Test that the client raises McpError when server returns HTML + with pytest.raises(ExceptionGroup) as exc_info: # type: ignore + async with streamablehttp_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # Extract the McpError from the ExceptionGroup (handle nested groups) + mcp_error = None + + def find_mcp_error(exc_group: ExceptionGroup) -> McpError | None: # type: ignore + for exc in exc_group.exceptions: # type: ignore + if isinstance(exc, McpError): + return exc + elif isinstance(exc, ExceptionGroup): # type: ignore + result = find_mcp_error(exc) + if result: + return result + return None + + mcp_error = find_mcp_error(exc_info.value) + + assert mcp_error is not None, "Expected McpError in ExceptionGroup hierarchy" + assert "Unexpected content type" in str(mcp_error) + assert "text/html" in str(mcp_error) + + finally: + # Server thread will be cleaned up automatically as daemon + pass