Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/mcp/client/sse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
from collections.abc import Callable
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urljoin, urlparse
from urllib.parse import parse_qs, urljoin, urlparse

import anyio
import httpx
Expand All @@ -21,6 +22,11 @@ def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)


def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
query_params = parse_qs(urlparse(endpoint_url).query)
return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0]


@asynccontextmanager
async def sse_client(
url: str,
Expand All @@ -29,6 +35,7 @@ async def sse_client(
sse_read_timeout: float = 60 * 5,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
on_session_created: Callable[[str], None] | None = None,
):
"""
Client transport for SSE.
Expand All @@ -42,6 +49,7 @@ async def sse_client(
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
auth: Optional HTTPX authentication handler.
on_session_created: Optional callback invoked with the session ID when received.
"""
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
Expand Down Expand Up @@ -89,6 +97,11 @@ async def sse_reader(
logger.error(error_msg) # pragma: no cover
raise ValueError(error_msg) # pragma: no cover

if on_session_created:
session_id = _extract_session_id_from_endpoint(endpoint_url)
if session_id:
on_session_created(session_id)

task_status.started(endpoint_url)

case "message":
Expand Down
55 changes: 54 additions & 1 deletion tests/shared/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from collections.abc import AsyncGenerator, Generator
from typing import Any
from unittest.mock import Mock

import anyio
import httpx
Expand All @@ -16,9 +17,10 @@
from starlette.responses import Response
from starlette.routing import Mount, Route

import mcp.client.sse
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.sse import _extract_session_id_from_endpoint, sse_client
from mcp.server import Server
from mcp.server.sse import SseServerTransport
from mcp.server.transport_security import TransportSecuritySettings
Expand Down Expand Up @@ -184,6 +186,57 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non
assert isinstance(ping_result, EmptyResult)


@pytest.mark.anyio
async def test_sse_client_on_session_created(server: None, server_url: str) -> None:
captured_session_id: str | None = None

def on_session_created(session_id: str) -> None:
nonlocal captured_session_id
captured_session_id = session_id

async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams:
async with ClientSession(*streams) as session:
result = await session.initialize()
assert isinstance(result, InitializeResult)

assert captured_session_id is not None
assert len(captured_session_id) > 0


@pytest.mark.parametrize(
"endpoint_url,expected",
[
("/messages?sessionId=abc123", "abc123"),
("/messages?session_id=def456", "def456"),
("/messages?sessionId=abc&session_id=def", "abc"),
("/messages?other=value", None),
("/messages", None),
("", None),
],
)
def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | None) -> None:
assert _extract_session_id_from_endpoint(endpoint_url) == expected


@pytest.mark.anyio
async def test_sse_client_on_session_created_not_called_when_no_session_id(
server: None, server_url: str, monkeypatch: pytest.MonkeyPatch
) -> None:
callback_mock = Mock()

def mock_extract(url: str) -> None:
return None

monkeypatch.setattr(mcp.client.sse, "_extract_session_id_from_endpoint", mock_extract)

async with sse_client(server_url + "/sse", on_session_created=callback_mock) as streams:
async with ClientSession(*streams) as session:
result = await session.initialize()
assert isinstance(result, InitializeResult)

callback_mock.assert_not_called()


@pytest.fixture
async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]:
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
Expand Down