From dcca6e25fe66142ec09f8dd4bbe8aea71b91e00a Mon Sep 17 00:00:00 2001 From: Camila Rondinini Date: Mon, 1 Dec 2025 13:44:45 +0000 Subject: [PATCH 1/5] add on_session_created callback option --- src/mcp/client/sse.py | 12 +++++++++++- tests/shared/test_sse.py | 19 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 437a0fa241..4fbfffcc9b 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -1,7 +1,8 @@ import logging +from collections.abc import Callable from contextlib import asynccontextmanager from typing import Any -from urllib.parse import urljoin, urlparse +from urllib.parse import parse_qs, urljoin, urlparse import anyio import httpx @@ -29,6 +30,7 @@ async def sse_client( sse_read_timeout: float = 60 * 5, httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, + on_session_created: Callable[[str], None] | None = None, ): """ Client transport for SSE. @@ -42,6 +44,7 @@ async def sse_client( timeout: HTTP timeout for regular operations. sse_read_timeout: Timeout for SSE read operations. auth: Optional HTTPX authentication handler. + on_session_created: Optional callback invoked with the session ID when received. """ read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] @@ -89,6 +92,13 @@ async def sse_reader( logger.error(error_msg) # pragma: no cover raise ValueError(error_msg) # pragma: no cover + if on_session_created: + session_id = parse_qs(endpoint_parsed.query).get( + "sessionId", [None] + )[0] + if session_id: + on_session_created(session_id) + task_status.started(endpoint_url) case "message": diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 28ac07d092..7dd35977c7 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -184,6 +184,25 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non assert isinstance(ping_result, EmptyResult) +@pytest.mark.anyio +async def test_sse_client_on_session_created(server: None, server_url: str) -> None: + captured_session_id: str | None = None + + def on_session_created(session_id: str) -> None: + nonlocal captured_session_id + captured_session_id = session_id + + async with sse_client( + server_url + "/sse", on_session_created=on_session_created + ) as streams: + async with ClientSession(*streams) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + + assert captured_session_id is not None + assert len(captured_session_id) > 0 + + @pytest.fixture async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: From 9033b813d6331d928afe6e6df6f5df74ee2ebca1 Mon Sep 17 00:00:00 2001 From: Camila Rondinini Date: Mon, 1 Dec 2025 13:53:35 +0000 Subject: [PATCH 2/5] formatting --- src/mcp/client/sse.py | 4 +--- tests/shared/test_sse.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 4fbfffcc9b..e22dde67ec 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -93,9 +93,7 @@ async def sse_reader( raise ValueError(error_msg) # pragma: no cover if on_session_created: - session_id = parse_qs(endpoint_parsed.query).get( - "sessionId", [None] - )[0] + session_id = parse_qs(endpoint_parsed.query).get("sessionId", [None])[0] if session_id: on_session_created(session_id) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 7dd35977c7..579172d37f 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -192,9 +192,7 @@ def on_session_created(session_id: str) -> None: nonlocal captured_session_id captured_session_id = session_id - async with sse_client( - server_url + "/sse", on_session_created=on_session_created - ) as streams: + async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) From 288830e8902929bbd101948a11837010ffcfc502 Mon Sep 17 00:00:00 2001 From: Camila Rondinini Date: Mon, 1 Dec 2025 14:04:53 +0000 Subject: [PATCH 3/5] support both casing options --- src/mcp/client/sse.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index e22dde67ec..eb9675cfa5 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -93,7 +93,11 @@ async def sse_reader( raise ValueError(error_msg) # pragma: no cover if on_session_created: - session_id = parse_qs(endpoint_parsed.query).get("sessionId", [None])[0] + query_params = parse_qs(endpoint_parsed.query) + session_id = ( + query_params.get("sessionId", [None])[0] + or query_params.get("session_id", [None])[0] + ) if session_id: on_session_created(session_id) From 2455043cca73bb89f92eef7298be4f00ddc312ff Mon Sep 17 00:00:00 2001 From: Camila Rondinini Date: Mon, 1 Dec 2025 14:50:05 +0000 Subject: [PATCH 4/5] test coverage --- src/mcp/client/sse.py | 11 ++++++----- tests/shared/test_sse.py | 39 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index eb9675cfa5..5d57cc5a59 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -22,6 +22,11 @@ def remove_request_params(url: str) -> str: return urljoin(url, urlparse(url).path) +def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None: + query_params = parse_qs(urlparse(endpoint_url).query) + return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0] + + @asynccontextmanager async def sse_client( url: str, @@ -93,11 +98,7 @@ async def sse_reader( raise ValueError(error_msg) # pragma: no cover if on_session_created: - query_params = parse_qs(endpoint_parsed.query) - session_id = ( - query_params.get("sessionId", [None])[0] - or query_params.get("session_id", [None])[0] - ) + session_id = _extract_session_id_from_endpoint(endpoint_url) if session_id: on_session_created(session_id) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 579172d37f..109b41f265 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -18,7 +18,7 @@ import mcp.types as types from mcp.client.session import ClientSession -from mcp.client.sse import sse_client +from mcp.client.sse import _extract_session_id_from_endpoint, sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings @@ -201,6 +201,43 @@ def on_session_created(session_id: str) -> None: assert len(captured_session_id) > 0 +@pytest.mark.parametrize( + "endpoint_url,expected", + [ + ("/messages?sessionId=abc123", "abc123"), + ("/messages?session_id=def456", "def456"), + ("/messages?sessionId=abc&session_id=def", "abc"), + ("/messages?other=value", None), + ("/messages", None), + ("", None), + ], +) +def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | None) -> None: + assert _extract_session_id_from_endpoint(endpoint_url) == expected + + +@pytest.mark.anyio +async def test_sse_client_on_session_created_not_called_when_no_session_id( + server: None, server_url: str, monkeypatch: pytest.MonkeyPatch +) -> None: + from mcp.client import sse + + callback_called = False + + def on_session_created(session_id: str) -> None: + nonlocal callback_called + callback_called = True + + monkeypatch.setattr(sse, "_extract_session_id_from_endpoint", lambda url: None) + + async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams: + async with ClientSession(*streams) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + + assert callback_called is False + + @pytest.fixture async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: From 782992809fb25f228462fdbf7f5313c1125fe5b0 Mon Sep 17 00:00:00 2001 From: Camila Rondinini Date: Mon, 1 Dec 2025 14:59:59 +0000 Subject: [PATCH 5/5] test coverage --- tests/shared/test_sse.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 109b41f265..fcad127072 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -4,6 +4,7 @@ import time from collections.abc import AsyncGenerator, Generator from typing import Any +from unittest.mock import Mock import anyio import httpx @@ -16,6 +17,7 @@ from starlette.responses import Response from starlette.routing import Mount, Route +import mcp.client.sse import mcp.types as types from mcp.client.session import ClientSession from mcp.client.sse import _extract_session_id_from_endpoint, sse_client @@ -220,22 +222,19 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non async def test_sse_client_on_session_created_not_called_when_no_session_id( server: None, server_url: str, monkeypatch: pytest.MonkeyPatch ) -> None: - from mcp.client import sse + callback_mock = Mock() - callback_called = False + def mock_extract(url: str) -> None: + return None - def on_session_created(session_id: str) -> None: - nonlocal callback_called - callback_called = True - - monkeypatch.setattr(sse, "_extract_session_id_from_endpoint", lambda url: None) + monkeypatch.setattr(mcp.client.sse, "_extract_session_id_from_endpoint", mock_extract) - async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams: + async with sse_client(server_url + "/sse", on_session_created=callback_mock) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - assert callback_called is False + callback_mock.assert_not_called() @pytest.fixture