Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError
from sse_starlette import EventSourceResponse
from starlette.requests import Request
from starlette.requests import ClientDisconnect, Request
from starlette.responses import Response
from starlette.types import Receive, Scope, Send

Expand Down Expand Up @@ -634,16 +634,35 @@ async def sse_writer(): # pragma: lax no cover
await sse_stream_reader.aclose()

except Exception as err: # pragma: no cover
await self._handle_post_error(err, scope, receive, send, writer)
return

async def _handle_post_error( # pragma: no cover
self,
err: Exception,
scope: Scope,
receive: Receive,
send: Send,
writer: MemoryObjectSendStream[SessionMessage | Exception] | None,
) -> None:
"""Handle errors from POST request processing.

ClientDisconnect is logged as a warning since it is a client-side event
(timeout, cancel, network drop), not a server error.
All other exceptions are logged as errors and return HTTP 500.
"""
if isinstance(err, ClientDisconnect):
logger.warning("Client disconnected during POST request")
else: # pragma: no cover
logger.exception("Error handling POST request")
response = self._create_error_response(
f"Error handling POST request: {err}",
HTTPStatus.INTERNAL_SERVER_ERROR,
INTERNAL_ERROR,
)
await response(scope, receive, send)
if writer:
await writer.send(Exception(err))
return
if writer:
await writer.send(Exception(err))

async def _handle_get_request(self, request: Request, send: Send) -> None:
"""Handle GET request to establish SSE.
Expand Down
45 changes: 45 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2316,3 +2316,48 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(

assert "content-type" in headers_data
assert headers_data["content-type"] == "application/json"


def test_client_disconnect_does_not_return_500(basic_server: None, basic_server_url: str):
"""Test that ClientDisconnect is handled gracefully without HTTP 500.

When a client disconnects before the server finishes reading the request body,
the server should log a warning and remain healthy for subsequent requests.
"""
import urllib.parse

parsed = urllib.parse.urlparse(basic_server_url)
host = parsed.hostname
port = parsed.port

# Send a POST request with a large Content-Length, then close the socket
# immediately. This causes Starlette to raise ClientDisconnect when it
# tries to read the body.
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((host, port))
raw_request = (
f"POST /mcp HTTP/1.1\r\n"
f"Host: {host}:{port}\r\n"
f"Content-Type: application/json\r\n"
f"Accept: application/json, text/event-stream\r\n"
f"Content-Length: 100000\r\n"
f"\r\n"
)
sock.sendall(raw_request.encode())
# Close immediately without sending the body
sock.close()

# Give the server a moment to process the disconnect
time.sleep(0.5)

# Verify the server is still healthy — a normal request should succeed
response = requests.post(
f"{basic_server_url}/mcp",
headers={
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
},
json=INIT_REQUEST,
)
# Server should be alive and respond normally (200 for SSE init)
assert response.status_code == 200
Loading