Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ private async Task InitializeAsync(JsonRpcMessage message, CancellationToken can

private async Task InitializeSseTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken)
{
if (_options.KnownSessionId is not null)
{
throw new InvalidOperationException("Streamable HTTP transport is required to resume an existing session.");
}

var sseTransport = new SseClientSessionTransport(_name, _options, _httpClient, _messageChannel, _loggerFactory);

try
Expand Down
5 changes: 5 additions & 0 deletions src/ModelContextProtocol.Core/Client/HttpClientTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ public HttpClientTransport(HttpClientTransportOptions transportOptions, HttpClie
/// <inheritdoc />
public async Task<ITransport> ConnectAsync(CancellationToken cancellationToken = default)
{
if (_options.KnownSessionId is not null && _options.TransportMode == HttpTransportMode.Sse)
{
throw new InvalidOperationException("SSE transport does not support resuming an existing session.");
}

return _options.TransportMode switch
{
HttpTransportMode.AutoDetect => new AutoDetectingClientSessionTransport(Name, _options, _mcpHttpClient, _loggerFactory),
Expand Down
27 changes: 27 additions & 0 deletions src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,33 @@ public required Uri Endpoint
/// </remarks>
public IDictionary<string, string>? AdditionalHeaders { get; set; }

/// <summary>
/// Gets or sets a session identifier that should be reused when connecting to a Streamable HTTP server.
/// </summary>
/// <remarks>
/// <para>
/// When non-<see langword="null"/>, the transport assumes the server already created the session and will include the
/// specified session identifier in every HTTP request. This allows reconnecting to an existing session created in a
/// previous process. This option is only supported by the Streamable HTTP transport mode.
/// </para>
/// <para>
/// Clients should pair this with
/// <see cref="McpClient.ResumeSessionAsync(IClientTransport, ResumeClientSessionOptions, McpClientOptions?, Microsoft.Extensions.Logging.ILoggerFactory?, CancellationToken)"/>
/// to skip the initialization handshake when rehydrating a previously negotiated session.
/// </para>
/// </remarks>
public string? KnownSessionId { get; set; }

/// <summary>
/// Gets or sets a value indicating whether this transport endpoint is responsible for ending the session on dispose.
/// </summary>
/// <remarks>
/// When <see langword="true"/> (default), the transport sends a DELETE request that informs the server the session is
/// complete. Set this to <see langword="false"/> when creating a transport used solely to bootstrap session information
/// that will later be resumed elsewhere.
/// </remarks>
public bool OwnsSession { get; set; } = true;

/// <summary>
/// Gets sor sets the authorization provider to use for authentication.
/// </summary>
Expand Down
30 changes: 30 additions & 0 deletions src/ModelContextProtocol.Core/Client/McpClient.Methods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,36 @@ public static async Task<McpClient> CreateAsync(
return clientSession;
}

/// <summary>
/// Recreates an <see cref="McpClient"/> using an existing transport session without sending a new initialize request.
/// </summary>
/// <param name="clientTransport">The transport instance already configured to connect to the target server.</param>
/// <param name="resumeOptions">The metadata captured from the original session that should be applied when resuming.</param>
/// <param name="clientOptions">Optional client settings that should mirror those used to create the original session.</param>
/// <param name="loggerFactory">An optional logger factory for diagnostics.</param>
/// <param name="cancellationToken">Token used when establishing the transport connection.</param>
/// <returns>An <see cref="McpClient"/> bound to the resumed session.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="clientTransport"/> or <paramref name="resumeOptions"/> is <see langword="null"/>.</exception>
public static async Task<McpClient> ResumeSessionAsync(
IClientTransport clientTransport,
ResumeClientSessionOptions resumeOptions,
McpClientOptions? clientOptions = null,
ILoggerFactory? loggerFactory = null,
CancellationToken cancellationToken = default)
{
Throw.IfNull(clientTransport);
Throw.IfNull(resumeOptions);
Throw.IfNull(resumeOptions.ServerCapabilities);
Throw.IfNull(resumeOptions.ServerInfo);

var transport = await clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
var endpointName = clientTransport.Name;

var clientSession = new McpClientImpl(transport, endpointName, clientOptions, loggerFactory);
clientSession.ResumeSession(resumeOptions);
return clientSession;
}

/// <summary>
/// Sends a ping request to verify server connectivity.
/// </summary>
Expand Down
27 changes: 26 additions & 1 deletion src/ModelContextProtocol.Core/Client/McpClientImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not
cancellationToken),
McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams,
McpJsonUtilities.JsonContext.Default.CreateMessageResult);

_options.Capabilities ??= new();
_options.Capabilities.Sampling ??= new();
}
Expand Down Expand Up @@ -207,6 +207,28 @@ await this.SendNotificationAsync(
LogClientConnected(_endpointName);
}

/// <summary>
/// Configures the client to use an already initialized session without performing the handshake.
/// </summary>
/// <param name="resumeOptions">The metadata captured from the previous session that should be applied to the resumed client.</param>
internal void ResumeSession(ResumeClientSessionOptions resumeOptions)
{
Throw.IfNull(resumeOptions);
Throw.IfNull(resumeOptions.ServerCapabilities);
Throw.IfNull(resumeOptions.ServerInfo);

_ = _sessionHandler.ProcessMessagesAsync(CancellationToken.None);

_serverCapabilities = resumeOptions.ServerCapabilities;
_serverInfo = resumeOptions.ServerInfo;
_serverInstructions = resumeOptions.ServerInstructions;
_negotiatedProtocolVersion = resumeOptions.NegotiatedProtocolVersion
?? _options.ProtocolVersion
?? McpSessionHandler.LatestProtocolVersion;

LogClientSessionResumed(_endpointName);
}

/// <inheritdoc/>
public override Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default)
=> _sessionHandler.SendRequestAsync(request, cancellationToken);
Expand Down Expand Up @@ -249,4 +271,7 @@ public override async ValueTask DisposeAsync()

[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client created and connected.")]
private partial void LogClientConnected(string endpointName);

[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client resumed existing session.")]
private partial void LogClientSessionResumed(string endpointName);
}
29 changes: 29 additions & 0 deletions src/ModelContextProtocol.Core/Client/ResumeClientSessionOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using ModelContextProtocol.Protocol;

namespace ModelContextProtocol.Client;

/// <summary>
/// Provides the metadata captured from a previous MCP client session that is required to resume it.
/// </summary>
public sealed class ResumeClientSessionOptions
{
/// <summary>
/// Gets or sets the server capabilities that were negotiated during the original session setialization.
/// </summary>
public required ServerCapabilities ServerCapabilities { get; set; }

/// <summary>
/// Gets or sets the server implementation metadata that identifies the connected MCP server.
/// </summary>
public required Implementation ServerInfo { get; set; }

/// <summary>
/// Gets or sets any instructions previously supplied by the server.
/// </summary>
public string? ServerInstructions { get; set; }

/// <summary>
/// Gets or sets the protocol version that was negotiated with the server.
/// </summary>
public string? NegotiatedProtocolVersion { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ public StreamableHttpClientSessionTransport(
// until the first call to SendMessageAsync. Fortunately, that happens internally in McpClient.ConnectAsync
// so we still throw any connection-related Exceptions from there and never expose a pre-connected client to the user.
SetConnected();

if (_options.KnownSessionId is { } knownSessionId)
{
SessionId = knownSessionId;
_getReceiveTask = ReceiveUnsolicitedMessagesAsync();
}
}

/// <inheritdoc/>
Expand All @@ -60,6 +66,14 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation
// This is used by the auto transport so it can fall back and try SSE given a non-200 response without catching an exception.
internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage message, CancellationToken cancellationToken)
{
if (_options.KnownSessionId is not null &&
message is JsonRpcRequest { Method: RequestMethods.Initialize })
{
throw new InvalidOperationException(
$"Cannot send '{RequestMethods.Initialize}' when {nameof(HttpClientTransportOptions)}.{nameof(HttpClientTransportOptions.KnownSessionId)} is configured. " +
$"Call {nameof(McpClient)}.{nameof(McpClient.ResumeSessionAsync)} to resume existing sessions.");
}

using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token);
cancellationToken = sendCts.Token;

Expand Down Expand Up @@ -116,7 +130,7 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
var initializeResult = JsonSerializer.Deserialize(initResponse.Result, McpJsonUtilities.JsonContext.Default.InitializeResult);
_negotiatedProtocolVersion = initializeResult?.ProtocolVersion;

_getReceiveTask = ReceiveUnsolicitedMessagesAsync();
_getReceiveTask ??= ReceiveUnsolicitedMessagesAsync();
}

return response;
Expand All @@ -139,7 +153,7 @@ public override async ValueTask DisposeAsync()
try
{
// Send DELETE request to terminate the session. Only send if we have a session ID, per MCP spec.
if (!string.IsNullOrEmpty(SessionId))
if (_options.OwnsSession && !string.IsNullOrEmpty(SessionId))
{
await SendDeleteRequest();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Primitives;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;

namespace ModelContextProtocol.AspNetCore.Tests;

Expand Down Expand Up @@ -188,4 +192,95 @@ public async Task StreamableHttpClient_SendsMcpProtocolVersionHeader_AfterInitia
Assert.True(protocolVersionHeaderValues.Count > 1);
Assert.All(protocolVersionHeaderValues, v => Assert.Equal("2025-03-26", v));
}

[Fact]
public async Task CanResumeSessionWithMapMcpAndRunSessionHandler()
{
Assert.SkipWhen(Stateless, "Session resumption relies on server-side session tracking.");

var runSessionCount = 0;
var serverTcs = new TaskCompletionSource<McpServer>(TaskCreationOptions.RunContinuationsAsynchronously);

Builder.Services.AddMcpServer(options =>
{
options.ServerInfo = new Implementation
{
Name = "ResumeServer",
Version = "1.0.0",
};
}).WithHttpTransport(opts =>
{
ConfigureStateless(opts);
opts.RunSessionHandler = async (context, server, cancellationToken) =>
{
Interlocked.Increment(ref runSessionCount);
serverTcs.TrySetResult(server);
await server.RunAsync(cancellationToken);
};
}).WithTools<EchoHttpContextUserTools>();

await using var app = Builder.Build();
app.MapMcp();
await app.StartAsync(TestContext.Current.CancellationToken);

ServerCapabilities? serverCapabilities = null;
Implementation? serverInfo = null;
string? serverInstructions = null;
string? negotiatedProtocolVersion = null;
string? resumedSessionId = null;

await using var initialTransport = new HttpClientTransport(new()
{
Endpoint = new("http://localhost:5000/"),
TransportMode = HttpTransportMode.StreamableHttp,
OwnsSession = false,
}, HttpClient, LoggerFactory);

await using (var initialClient = await McpClient.CreateAsync(initialTransport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken))
{
resumedSessionId = initialClient.SessionId ?? throw new InvalidOperationException("SessionId not negotiated.");
serverCapabilities = initialClient.ServerCapabilities;
serverInfo = initialClient.ServerInfo;
serverInstructions = initialClient.ServerInstructions;
negotiatedProtocolVersion = initialClient.NegotiatedProtocolVersion;

await initialClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
}

Assert.NotNull(serverCapabilities);
Assert.NotNull(serverInfo);
Assert.False(string.IsNullOrEmpty(resumedSessionId));

await serverTcs.Task.WaitAsync(TimeSpan.FromSeconds(5), TestContext.Current.CancellationToken);

await using var resumeTransport = new HttpClientTransport(new()
{
Endpoint = new("http://localhost:5000/"),
TransportMode = HttpTransportMode.StreamableHttp,
KnownSessionId = resumedSessionId!,
}, HttpClient, LoggerFactory);

var resumeOptions = new ResumeClientSessionOptions
{
ServerCapabilities = serverCapabilities!,
ServerInfo = serverInfo!,
ServerInstructions = serverInstructions,
NegotiatedProtocolVersion = negotiatedProtocolVersion,
};

await using (var resumedClient = await McpClient.ResumeSessionAsync(
resumeTransport,
resumeOptions,
loggerFactory: LoggerFactory,
cancellationToken: TestContext.Current.CancellationToken))
{
var tools = await resumedClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.NotEmpty(tools);

Assert.Equal(serverInstructions, resumedClient.ServerInstructions);
Assert.Equal(negotiatedProtocolVersion, resumedClient.NegotiatedProtocolVersion);
}

Assert.Equal(1, runSessionCount);
}
}
Loading
Loading