diff --git a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs index 2e49babc..209d644d 100644 --- a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs @@ -93,6 +93,11 @@ private async Task InitializeAsync(JsonRpcMessage message, CancellationToken can private async Task InitializeSseTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken) { + if (_options.KnownSessionId is not null) + { + throw new InvalidOperationException("Streamable HTTP transport is required to resume an existing session."); + } + var sseTransport = new SseClientSessionTransport(_name, _options, _httpClient, _messageChannel, _loggerFactory); try diff --git a/src/ModelContextProtocol.Core/Client/HttpClientTransport.cs b/src/ModelContextProtocol.Core/Client/HttpClientTransport.cs index 322b9175..7c6eeecc 100644 --- a/src/ModelContextProtocol.Core/Client/HttpClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/HttpClientTransport.cs @@ -72,6 +72,11 @@ public HttpClientTransport(HttpClientTransportOptions transportOptions, HttpClie /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { + if (_options.KnownSessionId is not null && _options.TransportMode == HttpTransportMode.Sse) + { + throw new InvalidOperationException("SSE transport does not support resuming an existing session."); + } + return _options.TransportMode switch { HttpTransportMode.AutoDetect => new AutoDetectingClientSessionTransport(Name, _options, _mcpHttpClient, _loggerFactory), diff --git a/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs index 94b95eec..c181966d 100644 --- a/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs +++ b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs @@ -73,6 +73,33 @@ public required Uri Endpoint /// public IDictionary? AdditionalHeaders { get; set; } + /// + /// Gets or sets a session identifier that should be reused when connecting to a Streamable HTTP server. + /// + /// + /// + /// When non-, the transport assumes the server already created the session and will include the + /// specified session identifier in every HTTP request. This allows reconnecting to an existing session created in a + /// previous process. This option is only supported by the Streamable HTTP transport mode. + /// + /// + /// Clients should pair this with + /// + /// to skip the initialization handshake when rehydrating a previously negotiated session. + /// + /// + public string? KnownSessionId { get; set; } + + /// + /// Gets or sets a value indicating whether this transport endpoint is responsible for ending the session on dispose. + /// + /// + /// When (default), the transport sends a DELETE request that informs the server the session is + /// complete. Set this to when creating a transport used solely to bootstrap session information + /// that will later be resumed elsewhere. + /// + public bool OwnsSession { get; set; } = true; + /// /// Gets sor sets the authorization provider to use for authentication. /// diff --git a/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs b/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs index 6397d8e7..f246ee05 100644 --- a/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs +++ b/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs @@ -50,6 +50,36 @@ public static async Task CreateAsync( return clientSession; } + /// + /// Recreates an using an existing transport session without sending a new initialize request. + /// + /// The transport instance already configured to connect to the target server. + /// The metadata captured from the original session that should be applied when resuming. + /// Optional client settings that should mirror those used to create the original session. + /// An optional logger factory for diagnostics. + /// Token used when establishing the transport connection. + /// An bound to the resumed session. + /// Thrown when or is . + public static async Task ResumeSessionAsync( + IClientTransport clientTransport, + ResumeClientSessionOptions resumeOptions, + McpClientOptions? clientOptions = null, + ILoggerFactory? loggerFactory = null, + CancellationToken cancellationToken = default) + { + Throw.IfNull(clientTransport); + Throw.IfNull(resumeOptions); + Throw.IfNull(resumeOptions.ServerCapabilities); + Throw.IfNull(resumeOptions.ServerInfo); + + var transport = await clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); + var endpointName = clientTransport.Name; + + var clientSession = new McpClientImpl(transport, endpointName, clientOptions, loggerFactory); + clientSession.ResumeSession(resumeOptions); + return clientSession; + } + /// /// Sends a ping request to verify server connectivity. /// diff --git a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs index 3a289d13..ccc7b9c6 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs @@ -80,7 +80,7 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not cancellationToken), McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, McpJsonUtilities.JsonContext.Default.CreateMessageResult); - + _options.Capabilities ??= new(); _options.Capabilities.Sampling ??= new(); } @@ -207,6 +207,28 @@ await this.SendNotificationAsync( LogClientConnected(_endpointName); } + /// + /// Configures the client to use an already initialized session without performing the handshake. + /// + /// The metadata captured from the previous session that should be applied to the resumed client. + internal void ResumeSession(ResumeClientSessionOptions resumeOptions) + { + Throw.IfNull(resumeOptions); + Throw.IfNull(resumeOptions.ServerCapabilities); + Throw.IfNull(resumeOptions.ServerInfo); + + _ = _sessionHandler.ProcessMessagesAsync(CancellationToken.None); + + _serverCapabilities = resumeOptions.ServerCapabilities; + _serverInfo = resumeOptions.ServerInfo; + _serverInstructions = resumeOptions.ServerInstructions; + _negotiatedProtocolVersion = resumeOptions.NegotiatedProtocolVersion + ?? _options.ProtocolVersion + ?? McpSessionHandler.LatestProtocolVersion; + + LogClientSessionResumed(_endpointName); + } + /// public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) => _sessionHandler.SendRequestAsync(request, cancellationToken); @@ -249,4 +271,7 @@ public override async ValueTask DisposeAsync() [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client created and connected.")] private partial void LogClientConnected(string endpointName); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client resumed existing session.")] + private partial void LogClientSessionResumed(string endpointName); } diff --git a/src/ModelContextProtocol.Core/Client/ResumeClientSessionOptions.cs b/src/ModelContextProtocol.Core/Client/ResumeClientSessionOptions.cs new file mode 100644 index 00000000..ae01caf3 --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/ResumeClientSessionOptions.cs @@ -0,0 +1,29 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Client; + +/// +/// Provides the metadata captured from a previous MCP client session that is required to resume it. +/// +public sealed class ResumeClientSessionOptions +{ + /// + /// Gets or sets the server capabilities that were negotiated during the original session setialization. + /// + public required ServerCapabilities ServerCapabilities { get; set; } + + /// + /// Gets or sets the server implementation metadata that identifies the connected MCP server. + /// + public required Implementation ServerInfo { get; set; } + + /// + /// Gets or sets any instructions previously supplied by the server. + /// + public string? ServerInstructions { get; set; } + + /// + /// Gets or sets the protocol version that was negotiated with the server. + /// + public string? NegotiatedProtocolVersion { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index 2b9700f4..2f0ec99b 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -47,6 +47,12 @@ public StreamableHttpClientSessionTransport( // until the first call to SendMessageAsync. Fortunately, that happens internally in McpClient.ConnectAsync // so we still throw any connection-related Exceptions from there and never expose a pre-connected client to the user. SetConnected(); + + if (_options.KnownSessionId is { } knownSessionId) + { + SessionId = knownSessionId; + _getReceiveTask = ReceiveUnsolicitedMessagesAsync(); + } } /// @@ -60,6 +66,14 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation // This is used by the auto transport so it can fall back and try SSE given a non-200 response without catching an exception. internal async Task SendHttpRequestAsync(JsonRpcMessage message, CancellationToken cancellationToken) { + if (_options.KnownSessionId is not null && + message is JsonRpcRequest { Method: RequestMethods.Initialize }) + { + throw new InvalidOperationException( + $"Cannot send '{RequestMethods.Initialize}' when {nameof(HttpClientTransportOptions)}.{nameof(HttpClientTransportOptions.KnownSessionId)} is configured. " + + $"Call {nameof(McpClient)}.{nameof(McpClient.ResumeSessionAsync)} to resume existing sessions."); + } + using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token); cancellationToken = sendCts.Token; @@ -116,7 +130,7 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes var initializeResult = JsonSerializer.Deserialize(initResponse.Result, McpJsonUtilities.JsonContext.Default.InitializeResult); _negotiatedProtocolVersion = initializeResult?.ProtocolVersion; - _getReceiveTask = ReceiveUnsolicitedMessagesAsync(); + _getReceiveTask ??= ReceiveUnsolicitedMessagesAsync(); } return response; @@ -139,7 +153,7 @@ public override async ValueTask DisposeAsync() try { // Send DELETE request to terminate the session. Only send if we have a session ID, per MCP spec. - if (!string.IsNullOrEmpty(SessionId)) + if (_options.OwnsSession && !string.IsNullOrEmpty(SessionId)) { await SendDeleteRequest(); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index 04eceb8d..cce2e4f0 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -2,7 +2,11 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Primitives; using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; using System.Collections.Concurrent; +using System.Threading; +using System.Threading.Tasks; namespace ModelContextProtocol.AspNetCore.Tests; @@ -188,4 +192,95 @@ public async Task StreamableHttpClient_SendsMcpProtocolVersionHeader_AfterInitia Assert.True(protocolVersionHeaderValues.Count > 1); Assert.All(protocolVersionHeaderValues, v => Assert.Equal("2025-03-26", v)); } + + [Fact] + public async Task CanResumeSessionWithMapMcpAndRunSessionHandler() + { + Assert.SkipWhen(Stateless, "Session resumption relies on server-side session tracking."); + + var runSessionCount = 0; + var serverTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new Implementation + { + Name = "ResumeServer", + Version = "1.0.0", + }; + }).WithHttpTransport(opts => + { + ConfigureStateless(opts); + opts.RunSessionHandler = async (context, server, cancellationToken) => + { + Interlocked.Increment(ref runSessionCount); + serverTcs.TrySetResult(server); + await server.RunAsync(cancellationToken); + }; + }).WithTools(); + + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + + ServerCapabilities? serverCapabilities = null; + Implementation? serverInfo = null; + string? serverInstructions = null; + string? negotiatedProtocolVersion = null; + string? resumedSessionId = null; + + await using var initialTransport = new HttpClientTransport(new() + { + Endpoint = new("http://localhost:5000/"), + TransportMode = HttpTransportMode.StreamableHttp, + OwnsSession = false, + }, HttpClient, LoggerFactory); + + await using (var initialClient = await McpClient.CreateAsync(initialTransport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)) + { + resumedSessionId = initialClient.SessionId ?? throw new InvalidOperationException("SessionId not negotiated."); + serverCapabilities = initialClient.ServerCapabilities; + serverInfo = initialClient.ServerInfo; + serverInstructions = initialClient.ServerInstructions; + negotiatedProtocolVersion = initialClient.NegotiatedProtocolVersion; + + await initialClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + } + + Assert.NotNull(serverCapabilities); + Assert.NotNull(serverInfo); + Assert.False(string.IsNullOrEmpty(resumedSessionId)); + + await serverTcs.Task.WaitAsync(TimeSpan.FromSeconds(5), TestContext.Current.CancellationToken); + + await using var resumeTransport = new HttpClientTransport(new() + { + Endpoint = new("http://localhost:5000/"), + TransportMode = HttpTransportMode.StreamableHttp, + KnownSessionId = resumedSessionId!, + }, HttpClient, LoggerFactory); + + var resumeOptions = new ResumeClientSessionOptions + { + ServerCapabilities = serverCapabilities!, + ServerInfo = serverInfo!, + ServerInstructions = serverInstructions, + NegotiatedProtocolVersion = negotiatedProtocolVersion, + }; + + await using (var resumedClient = await McpClient.ResumeSessionAsync( + resumeTransport, + resumeOptions, + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken)) + { + var tools = await resumedClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.NotEmpty(tools); + + Assert.Equal(serverInstructions, resumedClient.ServerInstructions); + Assert.Equal(negotiatedProtocolVersion, resumedClient.NegotiatedProtocolVersion); + } + + Assert.Equal(1, runSessionCount); + } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs index f1cd458f..366b9f41 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs @@ -6,6 +6,8 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using System.Threading; +using System.Threading.Tasks; using System.Text.Json; using System.Text.Json.Serialization.Metadata; @@ -107,6 +109,22 @@ private async Task StartAsync(bool enableDelete = false) await _app.StartAsync(TestContext.Current.CancellationToken); } + private async Task StartResumeServerAsync(string expectedSessionId) + { + Builder.Services.Configure(options => + { + options.SerializerOptions.TypeInfoResolverChain.Add(McpJsonUtilities.DefaultOptions.TypeInfoResolver!); + }); + + _app = Builder.Build(); + + var resumeServer = new ResumeTestServer(expectedSessionId); + resumeServer.MapEndpoints(_app); + + await _app.StartAsync(TestContext.Current.CancellationToken); + return resumeServer; + } + [Fact] public async Task CanCallToolOnSessionlessStreamableHttpServer() { @@ -174,6 +192,94 @@ public async Task SendsDeleteRequestOnDispose() Assert.Equal("test-session-123", sessionId); } + [Fact] + public async Task DoesNotSendDeleteWhenTransportDoesNotOwnSession() + { + await StartAsync(enableDelete: true); + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new("http://localhost:5000/mcp"), + TransportMode = HttpTransportMode.StreamableHttp, + OwnsSession = false, + }, HttpClient, LoggerFactory); + + await using (await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)) + { + // No-op. Disposing the client should not trigger a DELETE request. + } + + Assert.Empty(_deleteRequestSessionIds); + } + + [Fact] + public async Task ResumeSessionStartsGetImmediately() + { + const string sessionId = "resume-session-123"; + const string resumeInstructions = "Use cached instructions"; + const string resumeProtocolVersion = "2025-06-18"; + var resumeServer = await StartResumeServerAsync(sessionId); + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new("http://localhost:5000/mcp"), + TransportMode = HttpTransportMode.StreamableHttp, + KnownSessionId = sessionId, + }, HttpClient, LoggerFactory); + + var serverCapabilities = new ServerCapabilities + { + Tools = new(), + }; + var resumeOptions = new ResumeClientSessionOptions + { + ServerCapabilities = serverCapabilities, + ServerInfo = new Implementation { Name = "resume-server", Version = "1.0.0" }, + ServerInstructions = resumeInstructions, + NegotiatedProtocolVersion = resumeProtocolVersion, + }; + + await using (var client = await McpClient.ResumeSessionAsync( + transport, + resumeOptions, + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken)) + { + var observedSessionId = await resumeServer.GetStarted.WaitAsync(TimeSpan.FromSeconds(5), TestContext.Current.CancellationToken); + Assert.Equal(sessionId, observedSessionId); + + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + var tool = Assert.Single(tools); + Assert.Equal("resume-echo", tool.Name); + + Assert.Equal(sessionId, Assert.Single(resumeServer.PostSessionIds)); + Assert.Same(serverCapabilities, client.ServerCapabilities); + Assert.Same(resumeOptions.ServerInfo, client.ServerInfo); + Assert.Equal(resumeInstructions, client.ServerInstructions); + Assert.Equal(resumeProtocolVersion, client.NegotiatedProtocolVersion); + } + + Assert.Equal(sessionId, Assert.Single(resumeServer.DeleteSessionIds)); + } + + [Fact] + public async Task CreateAsyncWithKnownSessionIdThrows() + { + await StartAsync(); + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new("http://localhost:5000/mcp"), + TransportMode = HttpTransportMode.StreamableHttp, + KnownSessionId = "already-initialized", + }, HttpClient, LoggerFactory); + + var exception = await Assert.ThrowsAsync(() => + McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Contains(nameof(McpClient.ResumeSessionAsync), exception.Message); + } + private static async Task CallEchoAndValidateAsync(McpClientTool echoTool) { var response = await echoTool.CallAsync(new Dictionary() { ["message"] = "Hello world!" }, cancellationToken: TestContext.Current.CancellationToken); @@ -198,4 +304,98 @@ private static string Echo(string message) { return message; } + + private sealed class ResumeTestServer + { + private static readonly Tool ResumeTool = new() + { + Name = "resume-echo", + Description = "Echoes the provided message.", + }; + + private readonly string _expectedSessionId; + private readonly TaskCompletionSource _getStarted = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly List _postSessionIds = []; + private readonly List _deleteSessionIds = []; + + public ResumeTestServer(string expectedSessionId) + { + _expectedSessionId = expectedSessionId; + } + + public Task GetStarted => _getStarted.Task; + public IReadOnlyList PostSessionIds => _postSessionIds; + public IReadOnlyList DeleteSessionIds => _deleteSessionIds; + + public void MapEndpoints(WebApplication app) + { + app.MapGet("/mcp", HandleGetAsync); + app.MapPost("/mcp", HandlePostAsync); + app.MapDelete("/mcp", HandleDeleteAsync); + } + + private async Task HandleGetAsync(HttpContext context) + { + var sessionId = context.Request.Headers["mcp-session-id"].ToString(); + if (!string.Equals(sessionId, _expectedSessionId, StringComparison.Ordinal)) + { + context.Response.StatusCode = StatusCodes.Status404NotFound; + return; + } + + context.Response.Headers.ContentType = "text/event-stream"; + _getStarted.TrySetResult(sessionId); + await context.Response.Body.FlushAsync(); + + try + { + await Task.Delay(Timeout.Infinite, context.RequestAborted); + } + catch (OperationCanceledException) + { + } + } + + private async Task HandlePostAsync(HttpContext context) + { + var sessionId = context.Request.Headers["mcp-session-id"].ToString(); + _postSessionIds.Add(sessionId); + + if (!string.Equals(sessionId, _expectedSessionId, StringComparison.Ordinal)) + { + context.Response.StatusCode = StatusCodes.Status404NotFound; + return; + } + + var request = await context.Request.ReadFromJsonAsync(GetJsonTypeInfo(), context.RequestAborted); + if (request is null) + { + context.Response.StatusCode = StatusCodes.Status400BadRequest; + return; + } + + if (request.Method == RequestMethods.ToolsList) + { + var response = new JsonRpcResponse + { + Id = request.Id, + Result = JsonSerializer.SerializeToNode(new ListToolsResult + { + Tools = [ResumeTool], + }, McpJsonUtilities.DefaultOptions), + }; + + await context.Response.WriteAsJsonAsync(response, cancellationToken: context.RequestAborted); + return; + } + + context.Response.StatusCode = StatusCodes.Status202Accepted; + } + + private Task HandleDeleteAsync(HttpContext context) + { + _deleteSessionIds.Add(context.Request.Headers["mcp-session-id"].ToString()); + return Task.CompletedTask; + } + } }