diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index aa3e50e07e..1e75f22b5e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -79,10 +79,12 @@ def __init__(self, url: str) -> None: url: The endpoint URL. """ self.url = url + parsed_url = httpx.URL(url) + self.origin = f"{parsed_url.scheme}://{parsed_url.netloc.decode()}" if parsed_url.netloc else None self.session_id: str | None = None self.protocol_version: str | None = None - def _prepare_headers(self) -> dict[str, str]: + def _prepare_headers(self, client: httpx.AsyncClient | None = None) -> dict[str, str]: """Build MCP-specific request headers. These headers will be merged with the httpx.AsyncClient's default headers, @@ -92,6 +94,8 @@ def _prepare_headers(self) -> dict[str, str]: "accept": "application/json, text/event-stream", "content-type": "application/json", } + if self.origin and (client is None or "origin" not in client.headers): + headers["origin"] = self.origin # Add session headers if available if self.session_id: headers[MCP_SESSION_ID] = self.session_id @@ -189,7 +193,7 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: if not self.session_id: return - headers = self._prepare_headers() + headers = self._prepare_headers(client) if last_event_id: headers[LAST_EVENT_ID] = last_event_id @@ -225,7 +229,7 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: async def _handle_resumption_request(self, ctx: RequestContext) -> None: """Handle a resumption request using GET with SSE.""" - headers = self._prepare_headers() + headers = self._prepare_headers(ctx.client) if ctx.metadata and ctx.metadata.resumption_token: headers[LAST_EVENT_ID] = ctx.metadata.resumption_token else: @@ -253,7 +257,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" - headers = self._prepare_headers() + headers = self._prepare_headers(ctx.client) message = ctx.session_message.message is_initialization = self._is_initialization_request(message) @@ -388,7 +392,7 @@ async def _handle_reconnection( delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS await anyio.sleep(delay_ms / 1000.0) - headers = self._prepare_headers() + headers = self._prepare_headers(ctx.client) headers[LAST_EVENT_ID] = last_event_id # Extract original request ID to map responses @@ -496,7 +500,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None: return # pragma: no cover try: - headers = self._prepare_headers() + headers = self._prepare_headers(client) response = await client.delete(self.url, headers=headers) if response.status_code == 405: diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb61..3f6a5ee73c 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -2318,3 +2318,39 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers( assert "content-type" in headers_data assert headers_data["content-type"] == "application/json" + + +@pytest.mark.anyio +async def test_streamable_http_client_adds_origin_header(context_aware_server: None, basic_server_url: str) -> None: + async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + + tool_result = await session.call_tool("echo_headers", {}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + headers_data = json.loads(tool_result.content[0].text) + + assert headers_data["origin"] == basic_server_url + + +@pytest.mark.anyio +async def test_streamable_http_client_preserves_custom_origin_header( + context_aware_server: None, basic_server_url: str +) -> None: + custom_origin = "https://proxy.example" + + async with create_mcp_http_client(headers={"Origin": custom_origin}) as httpx_client: + async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + read_stream, + write_stream, + ): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + + tool_result = await session.call_tool("echo_headers", {}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + headers_data = json.loads(tool_result.content[0].text) + + assert headers_data["origin"] == custom_origin