diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs
index f67c57afd..c9a5ba87f 100644
--- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs
+++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs
@@ -11,11 +11,11 @@ public static class HttpMcpServerBuilderExtensions
{
///
/// Adds the services necessary for
- /// to handle MCP requests and sessions using the MCP HTTP Streaming transport. For more information on configuring the underlying HTTP server
+ /// to handle MCP requests and sessions using the MCP Streamable HTTP transport. For more information on configuring the underlying HTTP server
/// to control things like port binding custom TLS certificates, see the Minimal APIs quick reference.
///
/// The builder instance.
- /// Configures options for the HTTP Streaming transport. This allows configuring per-session
+ /// Configures options for the Streamable HTTP transport. This allows configuring per-session
/// and running logic before and after a session.
/// The builder provided in .
/// is .
@@ -23,6 +23,8 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder
{
ArgumentNullException.ThrowIfNull(builder);
builder.Services.TryAddSingleton();
+ builder.Services.TryAddSingleton();
+ builder.Services.AddHostedService();
if (configureOptions is not null)
{
diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs
index 216962a8b..fed2f131e 100644
--- a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs
+++ b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs
@@ -1,18 +1,61 @@
using ModelContextProtocol.Protocol.Transport;
+using ModelContextProtocol.Server;
using System.Security.Claims;
namespace ModelContextProtocol.AspNetCore;
-internal class HttpMcpSession
+internal sealed class HttpMcpSession(string sessionId, TTransport transport, ClaimsPrincipal user, TimeProvider timeProvider) : IAsyncDisposable
+ where TTransport : ITransport
{
- public HttpMcpSession(SseResponseStreamTransport transport, ClaimsPrincipal user)
+ private int _referenceCount;
+ private int _getRequestStarted;
+ private CancellationTokenSource _disposeCts = new();
+
+ public string Id { get; } = sessionId;
+ public TTransport Transport { get; } = transport;
+ public (string Type, string Value, string Issuer)? UserIdClaim { get; } = GetUserIdClaim(user);
+
+ public CancellationToken SessionClosed => _disposeCts.Token;
+
+ public bool IsActive => !SessionClosed.IsCancellationRequested && _referenceCount > 0;
+ public long LastActivityTicks { get; private set; } = timeProvider.GetTimestamp();
+
+ public IMcpServer? Server { get; set; }
+ public Task? ServerRunTask { get; set; }
+
+ public IDisposable AcquireReference()
{
- Transport = transport;
- UserIdClaim = GetUserIdClaim(user);
+ Interlocked.Increment(ref _referenceCount);
+ return new UnreferenceDisposable(this, timeProvider);
}
- public SseResponseStreamTransport Transport { get; }
- public (string Type, string Value, string Issuer)? UserIdClaim { get; }
+ public bool TryStartGetRequest() => Interlocked.Exchange(ref _getRequestStarted, 1) == 0;
+
+ public async ValueTask DisposeAsync()
+ {
+ try
+ {
+ await _disposeCts.CancelAsync();
+
+ if (ServerRunTask is not null)
+ {
+ await ServerRunTask;
+ }
+ }
+ catch (OperationCanceledException)
+ {
+ }
+ finally
+ {
+ if (Server is not null)
+ {
+ await Server.DisposeAsync();
+ }
+
+ await Transport.DisposeAsync();
+ _disposeCts.Dispose();
+ }
+ }
public bool HasSameUserId(ClaimsPrincipal user)
=> UserIdClaim == GetUserIdClaim(user);
@@ -36,4 +79,15 @@ private static (string Type, string Value, string Issuer)? GetUserIdClaim(Claims
return null;
}
+
+ private sealed class UnreferenceDisposable(HttpMcpSession session, TimeProvider timeProvider) : IDisposable
+ {
+ public void Dispose()
+ {
+ if (Interlocked.Decrement(ref session._referenceCount) == 0)
+ {
+ session.LastActivityTicks = timeProvider.GetTimestamp();
+ }
+ }
+ }
}
diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs
index 850dac244..23eeddbea 100644
--- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs
+++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs
@@ -21,4 +21,17 @@ public class HttpServerTransportOptions
/// This is useful for running logic before a sessions starts and after it completes.
///
public Func? RunSessionHandler { get; set; }
+
+ ///
+ /// Represents the duration of time the server will wait between any active requests before timing out an
+ /// MCP session. This is checked in background every 5 seconds. A client trying to resume a session will
+ /// receive a 404 status code and should restart their session. A client can keep their session open by
+ /// keeping a GET request open. The default value is set to 2 minutes.
+ ///
+ public TimeSpan IdleTimeout { get; set; } = TimeSpan.FromMinutes(2);
+
+ ///
+ /// Used for testing the .
+ ///
+ public TimeProvider TimeProvider { get; set; } = TimeProvider.System;
}
diff --git a/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs b/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs
new file mode 100644
index 000000000..df3203b5b
--- /dev/null
+++ b/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs
@@ -0,0 +1,105 @@
+using Microsoft.Extensions.Hosting;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Options;
+using ModelContextProtocol.Protocol.Transport;
+
+namespace ModelContextProtocol.AspNetCore;
+
+internal sealed partial class IdleTrackingBackgroundService(
+ StreamableHttpHandler handler,
+ IOptions options,
+ ILogger logger) : BackgroundService
+{
+ // The compiler will complain about the parameter being unused otherwise despite the source generator.
+ private ILogger _logger = logger;
+
+ // We can make this configurable once we properly harden the MCP server. In the meantime, anyone running
+ // this should be taking a cattle not pets approach to their servers and be able to launch more processes
+ // to handle more than 10,000 idle sessions at a time.
+ private const int MaxIdleSessionCount = 10_000;
+
+ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
+ {
+ var timeProvider = options.Value.TimeProvider;
+ using var timer = new PeriodicTimer(TimeSpan.FromSeconds(5), timeProvider);
+
+ try
+ {
+ while (!stoppingToken.IsCancellationRequested && await timer.WaitForNextTickAsync(stoppingToken))
+ {
+ var idleActivityCutoff = timeProvider.GetTimestamp() - options.Value.IdleTimeout.Ticks;
+
+ var idleCount = 0;
+ foreach (var (_, session) in handler.Sessions)
+ {
+ if (session.IsActive || session.SessionClosed.IsCancellationRequested)
+ {
+ // There's a request currently active or the session is already being closed.
+ continue;
+ }
+
+ idleCount++;
+ if (idleCount == MaxIdleSessionCount)
+ {
+ // Emit critical log at most once every 5 seconds the idle count it exceeded,
+ //since the IdleTimeout will no longer be respected.
+ LogMaxSessionIdleCountExceeded();
+ }
+ else if (idleCount < MaxIdleSessionCount && session.LastActivityTicks > idleActivityCutoff)
+ {
+ continue;
+ }
+
+ if (handler.Sessions.TryRemove(session.Id, out var removedSession))
+ {
+ LogSessionIdle(removedSession.Id);
+
+ // Don't slow down the idle tracking loop. DisposeSessionAsync logs. We only await during graceful shutdown.
+ _ = DisposeSessionAsync(removedSession);
+ }
+ }
+ }
+ }
+ catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested)
+ {
+ }
+ finally
+ {
+ if (stoppingToken.IsCancellationRequested)
+ {
+ List disposeSessionTasks = [];
+
+ foreach (var (sessionKey, _) in handler.Sessions)
+ {
+ if (handler.Sessions.TryRemove(sessionKey, out var session))
+ {
+ disposeSessionTasks.Add(DisposeSessionAsync(session));
+ }
+ }
+
+ await Task.WhenAll(disposeSessionTasks);
+ }
+ }
+ }
+
+ private async Task DisposeSessionAsync(HttpMcpSession session)
+ {
+ try
+ {
+ await session.DisposeAsync();
+ }
+ catch (Exception ex)
+ {
+ LogSessionDisposeError(session.Id, ex);
+ }
+ }
+
+ [LoggerMessage(Level = LogLevel.Information, Message = "Closing idle session {sessionId}.")]
+ private partial void LogSessionIdle(string sessionId);
+
+ [LoggerMessage(Level = LogLevel.Critical, Message = "Exceeded static maximum of 10,000 idle connections. Now clearing all inactive connections regardless of timeout.")]
+ private partial void LogMaxSessionIdleCountExceeded();
+
+ [LoggerMessage(Level = LogLevel.Error, Message = "Error disposing the IMcpServer for session {sessionId}.")]
+ private partial void LogSessionDisposeError(string sessionId, Exception ex);
+}
diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs
index ac424cc8b..0eefa52fb 100644
--- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs
+++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs
@@ -1,6 +1,9 @@
-using Microsoft.AspNetCore.Routing;
+using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Http.Metadata;
+using Microsoft.AspNetCore.Routing;
using Microsoft.Extensions.DependencyInjection;
using ModelContextProtocol.AspNetCore;
+using ModelContextProtocol.Protocol.Messages;
using System.Diagnostics.CodeAnalysis;
namespace Microsoft.AspNetCore.Builder;
@@ -11,21 +14,42 @@ namespace Microsoft.AspNetCore.Builder;
public static class McpEndpointRouteBuilderExtensions
{
///
- /// Sets up endpoints for handling MCP HTTP Streaming transport.
- /// See the protocol specification for details about the Streamable HTTP transport.
+ /// Sets up endpoints for handling MCP Streamable HTTP transport.
+ /// See the 2025-03-26 protocol specification for details about the Streamable HTTP transport.
+ /// Also maps legacy SSE endpoints for backward compatibility at the path "/sse" and "/message". the 2024-11-05 protocol specification for details about the HTTP with SSE transport.
///
/// The web application to attach MCP HTTP endpoints.
/// The route pattern prefix to map to.
/// Returns a builder for configuring additional endpoint conventions like authorization policies.
public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern = "")
{
- var handler = endpoints.ServiceProvider.GetService() ??
+ var streamableHttpHandler = endpoints.ServiceProvider.GetService() ??
throw new InvalidOperationException("You must call WithHttpTransport(). Unable to find required services. Call builder.Services.AddMcpServer().WithHttpTransport() in application startup code.");
- var routeGroup = endpoints.MapGroup(pattern);
- routeGroup.MapGet("", handler.HandleRequestAsync);
- routeGroup.MapGet("/sse", handler.HandleRequestAsync);
- routeGroup.MapPost("/message", handler.HandleRequestAsync);
- return routeGroup;
+ var mcpGroup = endpoints.MapGroup(pattern);
+ var streamableHttpGroup = mcpGroup.MapGroup("")
+ .WithDisplayName(b => $"MCP Streamable HTTP | {b.DisplayName}")
+ .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status404NotFound, typeof(JsonRpcError), contentTypes: ["application/json"]));
+
+ streamableHttpGroup.MapPost("", streamableHttpHandler.HandlePostRequestAsync)
+ .WithMetadata(new AcceptsMetadata(["application/json"]))
+ .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"]))
+ .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted));
+ streamableHttpGroup.MapGet("", streamableHttpHandler.HandleGetRequestAsync)
+ .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"]));
+ streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync);
+
+ // Map legacy HTTP with SSE endpoints.
+ var sseHandler = endpoints.ServiceProvider.GetRequiredService();
+ var sseGroup = mcpGroup.MapGroup("")
+ .WithDisplayName(b => $"MCP HTTP with SSE | {b.DisplayName}");
+
+ sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync)
+ .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"]));
+ sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync)
+ .WithMetadata(new AcceptsMetadata(["application/json"]))
+ .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted));
+
+ return mcpGroup;
}
}
diff --git a/src/ModelContextProtocol.AspNetCore/SseHandler.cs b/src/ModelContextProtocol.AspNetCore/SseHandler.cs
new file mode 100644
index 000000000..638592198
--- /dev/null
+++ b/src/ModelContextProtocol.AspNetCore/SseHandler.cs
@@ -0,0 +1,110 @@
+using Microsoft.AspNetCore.Http;
+using Microsoft.Extensions.Hosting;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Options;
+using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Protocol.Transport;
+using ModelContextProtocol.Server;
+using ModelContextProtocol.Utils.Json;
+using System.Collections.Concurrent;
+using System.Diagnostics;
+
+namespace ModelContextProtocol.AspNetCore;
+
+internal sealed class SseHandler(
+ IOptions mcpServerOptionsSnapshot,
+ IOptionsFactory mcpServerOptionsFactory,
+ IOptions httpMcpServerOptions,
+ IHostApplicationLifetime hostApplicationLifetime,
+ ILoggerFactory loggerFactory)
+{
+ private readonly ConcurrentDictionary> _sessions = new(StringComparer.Ordinal);
+
+ public async Task HandleSseRequestAsync(HttpContext context)
+ {
+ var sessionId = StreamableHttpHandler.MakeNewSessionId();
+
+ // If the server is shutting down, we need to cancel all SSE connections immediately without waiting for HostOptions.ShutdownTimeout
+ // which defaults to 30 seconds.
+ using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping);
+ var cancellationToken = sseCts.Token;
+
+ StreamableHttpHandler.InitializeSseResponse(context);
+
+ await using var transport = new SseResponseStreamTransport(context.Response.Body, $"message?sessionId={sessionId}");
+ await using var httpMcpSession = new HttpMcpSession(sessionId, transport, context.User, httpMcpServerOptions.Value.TimeProvider);
+ if (!_sessions.TryAdd(sessionId, httpMcpSession))
+ {
+ throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
+ }
+
+ try
+ {
+ var mcpServerOptions = mcpServerOptionsSnapshot.Value;
+ if (httpMcpServerOptions.Value.ConfigureSessionOptions is { } configureSessionOptions)
+ {
+ mcpServerOptions = mcpServerOptionsFactory.Create(Options.DefaultName);
+ await configureSessionOptions(context, mcpServerOptions, cancellationToken);
+ }
+
+ var transportTask = transport.RunAsync(cancellationToken);
+
+ try
+ {
+ await using var mcpServer = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, context.RequestServices);
+ httpMcpSession.Server = mcpServer;
+ context.Features.Set(mcpServer);
+
+ var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? StreamableHttpHandler.RunSessionAsync;
+ httpMcpSession.ServerRunTask = runSessionAsync(context, mcpServer, cancellationToken);
+ await httpMcpSession.ServerRunTask;
+ }
+ finally
+ {
+ await transport.DisposeAsync();
+ await transportTask;
+ }
+ }
+ catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
+ {
+ // RequestAborted always triggers when the client disconnects before a complete response body is written,
+ // but this is how SSE connections are typically closed.
+ }
+ finally
+ {
+ _sessions.TryRemove(sessionId, out _);
+ }
+ }
+
+ public async Task HandleMessageRequestAsync(HttpContext context)
+ {
+ if (!context.Request.Query.TryGetValue("sessionId", out var sessionId))
+ {
+ await Results.BadRequest("Missing sessionId query parameter.").ExecuteAsync(context);
+ return;
+ }
+
+ if (!_sessions.TryGetValue(sessionId.ToString(), out var httpMcpSession))
+ {
+ await Results.BadRequest($"Session ID not found.").ExecuteAsync(context);
+ return;
+ }
+
+ if (!httpMcpSession.HasSameUserId(context.User))
+ {
+ await Results.Forbid().ExecuteAsync(context);
+ return;
+ }
+
+ var message = (JsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)), context.RequestAborted);
+ if (message is null)
+ {
+ await Results.BadRequest("No message in request body.").ExecuteAsync(context);
+ return;
+ }
+
+ await httpMcpSession.Transport.OnMessageReceivedAsync(message, context.RequestAborted);
+ context.Response.StatusCode = StatusCodes.Status202Accepted;
+ await context.Response.WriteAsync("Accepted");
+ }
+}
diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
index b88cc07c6..1dc976192 100644
--- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
+++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
@@ -1,7 +1,6 @@
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.WebUtilities;
-using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Protocol.Messages;
@@ -10,7 +9,9 @@
using ModelContextProtocol.Utils.Json;
using System.Collections.Concurrent;
using System.Diagnostics;
+using System.IO.Pipelines;
using System.Security.Cryptography;
+using System.Text.Json.Serialization.Metadata;
namespace ModelContextProtocol.AspNetCore;
@@ -18,130 +19,207 @@ internal sealed class StreamableHttpHandler(
IOptions mcpServerOptionsSnapshot,
IOptionsFactory mcpServerOptionsFactory,
IOptions httpMcpServerOptions,
- IHostApplicationLifetime hostApplicationLifetime,
- ILoggerFactory loggerFactory)
+ ILoggerFactory loggerFactory,
+ IServiceProvider applicationServices)
{
+ private static JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo();
- private readonly ConcurrentDictionary _sessions = new(StringComparer.Ordinal);
- private readonly ILogger _logger = loggerFactory.CreateLogger();
+ public ConcurrentDictionary> Sessions { get; } = new(StringComparer.Ordinal);
- public async Task HandleRequestAsync(HttpContext context)
+ public async Task HandlePostRequestAsync(HttpContext context)
{
- if (context.Request.Method == HttpMethods.Get)
+ // The Streamable HTTP spec mandates the client MUST accept both application/json and text/event-stream.
+ // ASP.NET Core Minimal APIs mostly ry to stay out of the business of response content negotiation, so
+ // we have to do this manually. The spec doesn't mandate that servers MUST reject these requests, but it's
+ // probably good to at least start out trying to be strict.
+ var acceptHeader = context.Request.Headers.Accept.ToString();
+ if (!acceptHeader.Contains("application/json", StringComparison.Ordinal) ||
+ !acceptHeader.Contains("text/event-stream", StringComparison.Ordinal))
{
- await HandleSseRequestAsync(context);
+ await WriteJsonRpcErrorAsync(context,
+ "Not Acceptable: Client must accept both application/json and text/event-stream",
+ StatusCodes.Status406NotAcceptable);
+ return;
}
- else if (context.Request.Method == HttpMethods.Post)
+
+ var session = await GetOrCreateSessionAsync(context);
+ if (session is null)
{
- await HandleMessageRequestAsync(context);
+ return;
}
- else
+
+ using var _ = session.AcquireReference();
+ using var cts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, session.SessionClosed);
+ InitializeSseResponse(context);
+ var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), cts.Token);
+ if (!wroteResponse)
{
- context.Response.StatusCode = StatusCodes.Status405MethodNotAllowed;
- await context.Response.WriteAsync("Method Not Allowed");
+ // We wound up writing nothing, so there should be no Content-Type response header.
+ context.Response.Headers.ContentType = (string?)null;
+ context.Response.StatusCode = StatusCodes.Status202Accepted;
}
}
- public async Task HandleSseRequestAsync(HttpContext context)
+ public async Task HandleGetRequestAsync(HttpContext context)
{
- // If the server is shutting down, we need to cancel all SSE connections immediately without waiting for HostOptions.ShutdownTimeout
- // which defaults to 30 seconds.
- using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping);
- var cancellationToken = sseCts.Token;
+ var acceptHeader = context.Request.Headers.Accept.ToString();
+ if (!acceptHeader.Contains("application/json", StringComparison.Ordinal))
+ {
+ await WriteJsonRpcErrorAsync(context,
+ "Not Acceptable: Client must accept text/event-stream",
+ StatusCodes.Status406NotAcceptable);
+ return;
+ }
- var response = context.Response;
- response.Headers.ContentType = "text/event-stream";
- response.Headers.CacheControl = "no-cache,no-store";
+ var sessionId = context.Request.Headers["mcp-session-id"].ToString();
+ var session = await GetSessionAsync(context, sessionId);
+ if (session is null)
+ {
+ return;
+ }
- // Make sure we disable all response buffering for SSE
- context.Response.Headers.ContentEncoding = "identity";
- context.Features.GetRequiredFeature().DisableBuffering();
+ if (!session.TryStartGetRequest())
+ {
+ await WriteJsonRpcErrorAsync(context,
+ "Bad Request: This server does not support multiple GET requests. Start a new session to get a new GET SSE response.",
+ StatusCodes.Status400BadRequest);
+ return;
+ }
- var sessionId = MakeNewSessionId();
- await using var transport = new SseResponseStreamTransport(response.Body, $"message?sessionId={sessionId}");
- var httpMcpSession = new HttpMcpSession(transport, context.User);
- if (!_sessions.TryAdd(sessionId, httpMcpSession))
+ using var _ = session.AcquireReference();
+ InitializeSseResponse(context);
+
+ // We should flush headers to indicate a 200 success quickly, because the initialization response
+ // will be sent in response to a different POST request. It might be a while before we send a message
+ // over this response body.
+ await context.Response.Body.FlushAsync(context.RequestAborted);
+ await session.Transport.HandleGetRequest(context.Response.Body, context.RequestAborted);
+ }
+
+ public async Task HandleDeleteRequestAsync(HttpContext context)
+ {
+ var sessionId = context.Request.Headers["mcp-session-id"].ToString();
+ if (Sessions.TryRemove(sessionId, out var session))
{
- Debug.Fail("Unreachable given good entropy!");
- throw new InvalidOperationException($"Session with ID '{sessionId}' has already been created.");
+ await session.DisposeAsync();
}
+ }
+
+ private void InitializeSessionResponse(HttpContext context, HttpMcpSession session)
+ {
+ context.Response.Headers["mcp-session-id"] = session.Id;
+ context.Features.Set(session.Server);
+ }
- try
+ private async ValueTask?> GetSessionAsync(HttpContext context, string sessionId)
+ {
+ if (Sessions.TryGetValue(sessionId, out var existingSession))
{
- var mcpServerOptions = mcpServerOptionsSnapshot.Value;
- if (httpMcpServerOptions.Value.ConfigureSessionOptions is { } configureSessionOptions)
+ if (!existingSession.HasSameUserId(context.User))
{
- mcpServerOptions = mcpServerOptionsFactory.Create(Options.DefaultName);
- await configureSessionOptions(context, mcpServerOptions, cancellationToken);
+ await WriteJsonRpcErrorAsync(context,
+ "Forbidden: The currently authenticated user does not match the user who initiated the session.",
+ StatusCodes.Status403Forbidden);
+ return null;
}
- var transportTask = transport.RunAsync(cancellationToken);
+ InitializeSessionResponse(context, existingSession);
+ return existingSession;
+ }
- try
- {
- await using var mcpServer = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, context.RequestServices);
- context.Features.Set(mcpServer);
+ // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does.
+ // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this
+ // JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound
+ // https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields
+ await WriteJsonRpcErrorAsync(context, "Session not found", StatusCodes.Status404NotFound, 32001);
+ return null;
+ }
- var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? RunSessionAsync;
- await runSessionAsync(context, mcpServer, cancellationToken);
- }
- finally
+ private async ValueTask?> GetOrCreateSessionAsync(HttpContext context)
+ {
+ var sessionId = context.Request.Headers["mcp-session-id"].ToString();
+ HttpMcpSession? session;
+
+ if (string.IsNullOrEmpty(sessionId))
+ {
+ session = await CreateSessionAsync(context);
+
+ if (!Sessions.TryAdd(session.Id, session))
{
- await transport.DisposeAsync();
- await transportTask;
+ throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
}
+
+ return session;
}
- catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
- {
- // RequestAborted always triggers when the client disconnects before a complete response body is written,
- // but this is how SSE connections are typically closed.
- }
- finally
+ else
{
- _sessions.TryRemove(sessionId, out _);
+ return await GetSessionAsync(context, sessionId);
}
}
- public async Task HandleMessageRequestAsync(HttpContext context)
+ private async ValueTask> CreateSessionAsync(HttpContext context)
{
- if (!context.Request.Query.TryGetValue("sessionId", out var sessionId))
+ var mcpServerOptions = mcpServerOptionsSnapshot.Value;
+ if (httpMcpServerOptions.Value.ConfigureSessionOptions is { } configureSessionOptions)
{
- await Results.BadRequest("Missing sessionId query parameter.").ExecuteAsync(context);
- return;
+ mcpServerOptions = mcpServerOptionsFactory.Create(Options.DefaultName);
+ await configureSessionOptions(context, mcpServerOptions, context.RequestAborted);
}
- if (!_sessions.TryGetValue(sessionId.ToString(), out var httpMcpSession))
- {
- await Results.BadRequest($"Session ID not found.").ExecuteAsync(context);
- return;
- }
+ var transport = new StreamableHttpServerTransport();
+ // Use application instead of request services, because the session will likely outlive the first initialization request.
+ var server = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, applicationServices);
- if (!httpMcpSession.HasSameUserId(context.User))
+ var session = new HttpMcpSession(MakeNewSessionId(), transport, context.User, httpMcpServerOptions.Value.TimeProvider)
{
- await Results.Forbid().ExecuteAsync(context);
- return;
- }
+ Server = server,
+ };
- var message = (IJsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), context.RequestAborted);
- if (message is null)
- {
- await Results.BadRequest("No message in request body.").ExecuteAsync(context);
- return;
- }
+ var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? RunSessionAsync;
+ session.ServerRunTask = runSessionAsync(context, server, session.SessionClosed);
- await httpMcpSession.Transport.OnMessageReceivedAsync(message, context.RequestAborted);
- context.Response.StatusCode = StatusCodes.Status202Accepted;
- await context.Response.WriteAsync("Accepted");
+ InitializeSessionResponse(context, session);
+ return session;
}
- private static Task RunSessionAsync(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted)
- => session.RunAsync(requestAborted);
+ private static Task WriteJsonRpcErrorAsync(HttpContext context, string errorMessage, int statusCode, int errorCode = -32000)
+ {
+ var jsonRpcError = new JsonRpcError
+ {
+ Error = new()
+ {
+ Code = errorCode,
+ Message = errorMessage,
+ },
+ };
+ return Results.Json(jsonRpcError, s_errorTypeInfo, statusCode: statusCode).ExecuteAsync(context);
+ }
- private static string MakeNewSessionId()
+ internal static void InitializeSseResponse(HttpContext context)
+ {
+ context.Response.Headers.ContentType = "text/event-stream";
+ context.Response.Headers.CacheControl = "no-cache,no-store";
+
+ // Make sure we disable all response buffering for SSE.
+ context.Response.Headers.ContentEncoding = "identity";
+ context.Features.GetRequiredFeature().DisableBuffering();
+ }
+
+ internal static string MakeNewSessionId()
{
- // 128 bits
Span buffer = stackalloc byte[16];
RandomNumberGenerator.Fill(buffer);
return WebEncoders.Base64UrlEncode(buffer);
}
+
+ internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted)
+ => session.RunAsync(requestAborted);
+
+ private static JsonTypeInfo GetRequiredJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T));
+
+ private sealed class HttpDuplexPipe(HttpContext context) : IDuplexPipe
+ {
+ public PipeReader Input => context.Request.BodyReader;
+ public PipeWriter Output => context.Response.BodyWriter;
+ }
}
diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs
index 50c0a195f..2bcfacaa8 100644
--- a/src/ModelContextProtocol/Client/McpClient.cs
+++ b/src/ModelContextProtocol/Client/McpClient.cs
@@ -5,6 +5,7 @@
using ModelContextProtocol.Shared;
using ModelContextProtocol.Utils.Json;
using System.Text.Json;
+using System.Threading;
namespace ModelContextProtocol.Client;
@@ -59,7 +60,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, IL
RequestHandlers.Set(
RequestMethods.SamplingCreateMessage,
- (request, cancellationToken) => samplingHandler(
+ (request, _, cancellationToken) => samplingHandler(
request,
request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance,
cancellationToken),
@@ -76,7 +77,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, IL
RequestHandlers.Set(
RequestMethods.RootsList,
- rootsHandler,
+ (request, _, cancellationToken) => rootsHandler(request, cancellationToken),
McpJsonUtilities.JsonContext.Default.ListRootsRequestParams,
McpJsonUtilities.JsonContext.Default.ListRootsResult);
}
diff --git a/src/ModelContextProtocol/Diagnostics.cs b/src/ModelContextProtocol/Diagnostics.cs
index a5d293f0c..52d82e200 100644
--- a/src/ModelContextProtocol/Diagnostics.cs
+++ b/src/ModelContextProtocol/Diagnostics.cs
@@ -38,7 +38,7 @@ internal static Histogram CreateDurationHistogram(string name, string de
};
#endif
- internal static ActivityContext ExtractActivityContext(this DistributedContextPropagator propagator, IJsonRpcMessage message)
+ internal static ActivityContext ExtractActivityContext(this DistributedContextPropagator propagator, JsonRpcMessage message)
{
propagator.ExtractTraceIdAndState(message, ExtractContext, out var traceparent, out var tracestate);
ActivityContext.TryParse(traceparent, tracestate, true, out var activityContext);
@@ -71,7 +71,7 @@ private static void ExtractContext(object? message, string fieldName, out string
}
}
- internal static void InjectActivityContext(this DistributedContextPropagator propagator, Activity? activity, IJsonRpcMessage message)
+ internal static void InjectActivityContext(this DistributedContextPropagator propagator, Activity? activity, JsonRpcMessage message)
{
// noop if activity is null
propagator.Inject(activity, message, InjectContext);
@@ -100,7 +100,7 @@ private static void InjectContext(object? message, string key, string value)
}
}
- internal static bool ShouldInstrumentMessage(IJsonRpcMessage message) =>
+ internal static bool ShouldInstrumentMessage(JsonRpcMessage message) =>
ActivitySource.HasListeners() &&
message switch
{
diff --git a/src/ModelContextProtocol/IMcpEndpoint.cs b/src/ModelContextProtocol/IMcpEndpoint.cs
index a3b941cce..dcfdf687c 100644
--- a/src/ModelContextProtocol/IMcpEndpoint.cs
+++ b/src/ModelContextProtocol/IMcpEndpoint.cs
@@ -45,7 +45,7 @@ public interface IMcpEndpoint : IAsyncDisposable
/// Sends a JSON-RPC message to the connected endpoint.
///
///
- /// The JSON-RPC message to send. This can be any type that implements IJsonRpcMessage, such as
+ /// The JSON-RPC message to send. This can be any type that implements JsonRpcMessage, such as
/// JsonRpcRequest, JsonRpcResponse, JsonRpcNotification, or JsonRpcError.
///
/// The to monitor for cancellation requests. The default is .
@@ -63,7 +63,7 @@ public interface IMcpEndpoint : IAsyncDisposable
/// The method will serialize the message and transmit it using the underlying transport mechanism.
///
///
- Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default);
+ Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default);
/// Registers a handler to be invoked when a notification for the specified method is received.
/// The notification method.
diff --git a/src/ModelContextProtocol/McpEndpointExtensions.cs b/src/ModelContextProtocol/McpEndpointExtensions.cs
index 8c266b29f..3969fa336 100644
--- a/src/ModelContextProtocol/McpEndpointExtensions.cs
+++ b/src/ModelContextProtocol/McpEndpointExtensions.cs
@@ -41,7 +41,7 @@ public static Task SendRequestAsync(
string method,
TParameters parameters,
JsonSerializerOptions? serializerOptions = null,
- RequestId? requestId = null,
+ RequestId requestId = default,
CancellationToken cancellationToken = default)
where TResult : notnull
{
@@ -72,7 +72,7 @@ internal static async Task SendRequestAsync(
TParameters parameters,
JsonTypeInfo parametersTypeInfo,
JsonTypeInfo resultTypeInfo,
- RequestId? requestId = null,
+ RequestId requestId = default,
CancellationToken cancellationToken = default)
where TResult : notnull
{
@@ -83,15 +83,11 @@ internal static async Task SendRequestAsync(
JsonRpcRequest jsonRpcRequest = new()
{
+ Id = requestId,
Method = method,
Params = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo),
};
- if (requestId is { } id)
- {
- jsonRpcRequest.Id = id;
- }
-
JsonRpcResponse response = await endpoint.SendRequestAsync(jsonRpcRequest, cancellationToken).ConfigureAwait(false);
return JsonSerializer.Deserialize(response.Result, resultTypeInfo) ?? throw new JsonException("Unexpected JSON result in response.");
}
diff --git a/src/ModelContextProtocol/Protocol/Messages/IJsonRpcMessage.cs b/src/ModelContextProtocol/Protocol/Messages/IJsonRpcMessage.cs
deleted file mode 100644
index 9880247fb..000000000
--- a/src/ModelContextProtocol/Protocol/Messages/IJsonRpcMessage.cs
+++ /dev/null
@@ -1,21 +0,0 @@
-using ModelContextProtocol.Utils.Json;
-using System.Text.Json.Serialization;
-
-namespace ModelContextProtocol.Protocol.Messages;
-
-///
-/// Represents any JSON-RPC message used in the Model Context Protocol (MCP).
-///
-///
-/// This interface serves as the foundation for all message types in the JSON-RPC 2.0 protocol
-/// used by MCP, including requests, responses, notifications, and errors. JSON-RPC is a stateless,
-/// lightweight remote procedure call (RPC) protocol that uses JSON as its data format.
-///
-[JsonConverter(typeof(JsonRpcMessageConverter))]
-public interface IJsonRpcMessage
-{
- ///
- /// Gets the JSON-RPC protocol version used.
- ///
- string JsonRpc { get; }
-}
diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcError.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcError.cs
index 30ace711d..1109e647b 100644
--- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcError.cs
+++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcError.cs
@@ -16,16 +16,8 @@ namespace ModelContextProtocol.Protocol.Messages;
/// and optional additional data to provide more context about the error.
///
///
-public record JsonRpcError : IJsonRpcMessageWithId
+public class JsonRpcError : JsonRpcMessageWithId
{
- ///
- [JsonPropertyName("jsonrpc")]
- public string JsonRpc { get; init; } = "2.0";
-
- ///
- [JsonPropertyName("id")]
- public required RequestId Id { get; init; }
-
///
/// Gets detailed error information for the failed request, containing an error code,
/// message, and optional additional data
diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcMessage.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcMessage.cs
new file mode 100644
index 000000000..d307598f0
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcMessage.cs
@@ -0,0 +1,35 @@
+using ModelContextProtocol.Protocol.Transport;
+using ModelContextProtocol.Utils.Json;
+using System.Text.Json.Serialization;
+
+namespace ModelContextProtocol.Protocol.Messages;
+
+///
+/// Represents any JSON-RPC message used in the Model Context Protocol (MCP).
+///
+///
+/// This interface serves as the foundation for all message types in the JSON-RPC 2.0 protocol
+/// used by MCP, including requests, responses, notifications, and errors. JSON-RPC is a stateless,
+/// lightweight remote procedure call (RPC) protocol that uses JSON as its data format.
+///
+[JsonConverter(typeof(JsonRpcMessageConverter))]
+public abstract class JsonRpcMessage
+{
+ ///
+ /// Gets the JSON-RPC protocol version used.
+ ///
+ ///
+ [JsonPropertyName("jsonrpc")]
+ public string JsonRpc { get; init; } = "2.0";
+
+ ///
+ /// Gets or sets the transport the was received on or should be sent over.
+ ///
+ ///
+ /// This is used to support the Streamable HTTP transport where the specification states that the server
+ /// SHOULD include JSON-RPC responses in the HTTP response body for the POST request containing
+ /// the corresponding JSON-RPC request. It may be for other transports.
+ ///
+ [JsonIgnore]
+ public ITransport? RelatedTransport { get; set; }
+}
diff --git a/src/ModelContextProtocol/Protocol/Messages/IJsonRpcMessageWithId.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcMessageWithId.cs
similarity index 83%
rename from src/ModelContextProtocol/Protocol/Messages/IJsonRpcMessageWithId.cs
rename to src/ModelContextProtocol/Protocol/Messages/JsonRpcMessageWithId.cs
index 4e2d84f4a..f095c50bc 100644
--- a/src/ModelContextProtocol/Protocol/Messages/IJsonRpcMessageWithId.cs
+++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcMessageWithId.cs
@@ -1,3 +1,5 @@
+using System.Text.Json.Serialization;
+
namespace ModelContextProtocol.Protocol.Messages;
///
@@ -10,7 +12,7 @@ namespace ModelContextProtocol.Protocol.Messages;
/// The ID is used to correlate requests with their responses, allowing asynchronous
/// communication where multiple requests can be sent without waiting for responses.
///
-public interface IJsonRpcMessageWithId : IJsonRpcMessage
+public abstract class JsonRpcMessageWithId : JsonRpcMessage
{
///
/// Gets the message identifier.
@@ -18,5 +20,6 @@ public interface IJsonRpcMessageWithId : IJsonRpcMessage
///
/// Each ID is expected to be unique within the context of a given session.
///
- RequestId Id { get; }
+ [JsonPropertyName("id")]
+ public RequestId Id { get; init; }
}
diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs
index ea0c35cbb..bdbd8d46f 100644
--- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs
+++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs
@@ -11,12 +11,8 @@ namespace ModelContextProtocol.Protocol.Messages;
/// They are useful for one-way communication, such as log notifications and progress updates.
/// Unlike requests, notifications do not include an ID field, since there will be no response to match with it.
///
-public record JsonRpcNotification : IJsonRpcMessage
+public class JsonRpcNotification : JsonRpcMessage
{
- ///
- [JsonPropertyName("jsonrpc")]
- public string JsonRpc { get; init; } = "2.0";
-
///
/// Gets or sets the name of the notification method.
///
diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs
index 275503912..ff7a45044 100644
--- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs
+++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs
@@ -14,16 +14,8 @@ namespace ModelContextProtocol.Protocol.Messages;
/// and return either a with the result, or a
/// if the method execution fails.
///
-public record JsonRpcRequest : IJsonRpcMessageWithId
+public class JsonRpcRequest : JsonRpcMessageWithId
{
- ///
- [JsonPropertyName("jsonrpc")]
- public string JsonRpc { get; init; } = "2.0";
-
- ///
- [JsonPropertyName("id")]
- public RequestId Id { get; set; }
-
///
/// Name of the method to invoke.
///
@@ -35,4 +27,15 @@ public record JsonRpcRequest : IJsonRpcMessageWithId
///
[JsonPropertyName("params")]
public JsonNode? Params { get; init; }
+
+ internal JsonRpcRequest WithId(RequestId id)
+ {
+ return new JsonRpcRequest
+ {
+ JsonRpc = JsonRpc,
+ Id = id,
+ Method = Method,
+ Params = Params
+ };
+ }
}
diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs
index c10e0fb55..01eef51e9 100644
--- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs
+++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs
@@ -16,16 +16,8 @@ namespace ModelContextProtocol.Protocol.Messages;
/// This class represents a successful response with a result. For error responses, see .
///
///
-public record JsonRpcResponse : IJsonRpcMessageWithId
+public class JsonRpcResponse : JsonRpcMessageWithId
{
- ///
- [JsonPropertyName("jsonrpc")]
- public string JsonRpc { get; init; } = "2.0";
-
- ///
- [JsonPropertyName("id")]
- public required RequestId Id { get; init; }
-
///
/// Gets the result of the method invocation.
///
diff --git a/src/ModelContextProtocol/Protocol/Transport/ITransport.cs b/src/ModelContextProtocol/Protocol/Transport/ITransport.cs
index de4760488..b8cb7e3bb 100644
--- a/src/ModelContextProtocol/Protocol/Transport/ITransport.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/ITransport.cs
@@ -26,22 +26,6 @@ namespace ModelContextProtocol.Protocol.Transport;
///
public interface ITransport : IAsyncDisposable
{
- ///
- /// Gets whether the transport is currently connected and able to send/receive messages.
- ///
- ///
- ///
- /// The property indicates the current state of the transport connection.
- /// When , the transport is ready to send and receive messages. When ,
- /// any attempt to send messages will typically result in exceptions being thrown.
- ///
- ///
- /// The property transitions to when the transport successfully establishes a connection,
- /// and transitions to when the transport is disposed or encounters a connection error.
- ///
- ///
- bool IsConnected { get; }
-
///
/// Gets a channel reader for receiving messages from the transport.
///
@@ -56,7 +40,7 @@ public interface ITransport : IAsyncDisposable
/// any already transmitted messages are consumed.
///
///
- ChannelReader MessageReader { get; }
+ ChannelReader MessageReader { get; }
///
/// Sends a JSON-RPC message through the transport.
@@ -76,5 +60,5 @@ public interface ITransport : IAsyncDisposable
/// rather than accessing this method directly.
///
///
- Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default);
+ Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default);
}
diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs
index e2c7cc5cc..5d952f8a6 100644
--- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs
@@ -68,21 +68,21 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
///
public override async Task SendMessageAsync(
- IJsonRpcMessage message,
+ JsonRpcMessage message,
CancellationToken cancellationToken = default)
{
if (_messageEndpoint == null)
throw new InvalidOperationException("Transport not connected");
using var content = new StringContent(
- JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage),
+ JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage),
Encoding.UTF8,
"application/json"
);
string messageId = "(no id)";
- if (message is IJsonRpcMessageWithId messageWithId)
+ if (message is JsonRpcMessageWithId messageWithId)
{
messageId = messageWithId.Id.ToString();
}
@@ -213,17 +213,20 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken)
}
}
}
- catch when (cancellationToken.IsCancellationRequested)
- {
- // Normal shutdown
- LogTransportReadMessagesCancelled(Name);
- _connectionEstablished.TrySetCanceled(cancellationToken);
- }
catch (Exception ex)
{
- LogTransportReadMessagesFailed(Name, ex);
- _connectionEstablished.TrySetException(ex);
- throw;
+ if (cancellationToken.IsCancellationRequested)
+ {
+ // Normal shutdown
+ LogTransportReadMessagesCancelled(Name);
+ _connectionEstablished.TrySetCanceled(cancellationToken);
+ }
+ else
+ {
+ LogTransportReadMessagesFailed(Name, ex);
+ _connectionEstablished.TrySetException(ex);
+ throw;
+ }
}
finally
{
@@ -241,7 +244,7 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation
try
{
- var message = JsonSerializer.Deserialize(data, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage);
+ var message = JsonSerializer.Deserialize(data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);
if (message == null)
{
LogTransportMessageParseUnexpectedTypeSensitive(Name, data);
@@ -249,7 +252,7 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation
}
string messageId = "(no id)";
- if (message is IJsonRpcMessageWithId messageWithId)
+ if (message is JsonRpcMessageWithId messageWithId)
{
messageId = messageWithId.Id.ToString();
}
diff --git a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs
index da86fd16a..f830862d1 100644
--- a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs
@@ -1,10 +1,6 @@
-using System.Text;
-using System.Buffers;
-using System.Net.ServerSentEvents;
-using System.Text.Json;
-using System.Threading.Channels;
using ModelContextProtocol.Protocol.Messages;
-using ModelContextProtocol.Utils.Json;
+using ModelContextProtocol.Utils;
+using System.Threading.Channels;
namespace ModelContextProtocol.Protocol.Transport;
@@ -22,16 +18,22 @@ namespace ModelContextProtocol.Protocol.Transport;
/// such as when streaming completion results or providing progress updates during long-running operations.
///
///
-public sealed class SseResponseStreamTransport(Stream sseResponseStream, string messageEndpoint = "/message") : ITransport
+/// The response stream to write MCP JSON-RPC messages as SSE events to.
+///
+/// The relative or absolute URI the client should use to post MCP JSON-RPC messages for this session.
+/// These messages should be passed to .
+/// Defaults to "/message".
+///
+public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? messageEndpoint = "/message") : ITransport
{
- private readonly Channel _incomingChannel = CreateBoundedChannel();
- private readonly Channel> _outgoingSseChannel = CreateBoundedChannel>();
-
- private Task? _sseWriteTask;
- private Utf8JsonWriter? _jsonWriter;
+ private readonly SseWriter _sseWriter = new(messageEndpoint);
+ private readonly Channel _incomingChannel = Channel.CreateBounded(new BoundedChannelOptions(1)
+ {
+ SingleReader = true,
+ SingleWriter = false,
+ });
- ///
- public bool IsConnected { get; private set; }
+ private bool _isConnected;
///
/// Starts the transport and writes the JSON-RPC messages sent via
@@ -39,54 +41,27 @@ public sealed class SseResponseStreamTransport(Stream sseResponseStream, string
///
/// The to monitor for cancellation requests. The default is .
/// A task representing the send loop that writes JSON-RPC messages to the SSE response stream.
- public Task RunAsync(CancellationToken cancellationToken)
+ public async Task RunAsync(CancellationToken cancellationToken)
{
- // The very first SSE event isn't really an IJsonRpcMessage, but there's no API to write a single item of a different type,
- // so we fib and special-case the "endpoint" event type in the formatter.
- if (!_outgoingSseChannel.Writer.TryWrite(new SseItem(null, "endpoint")))
- {
- throw new InvalidOperationException($"You must call ${nameof(RunAsync)} before calling ${nameof(SendMessageAsync)}.");
- }
-
- IsConnected = true;
-
- var sseItems = _outgoingSseChannel.Reader.ReadAllAsync(cancellationToken);
- return _sseWriteTask = SseFormatter.WriteAsync(sseItems, sseResponseStream, WriteJsonRpcMessageToBuffer, cancellationToken);
- }
-
- private void WriteJsonRpcMessageToBuffer(SseItem item, IBufferWriter writer)
- {
- if (item.EventType == "endpoint")
- {
- writer.Write(Encoding.UTF8.GetBytes(messageEndpoint));
- return;
- }
-
- JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage!);
+ _isConnected = true;
+ await _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken).ConfigureAwait(false);
}
///
- public ChannelReader MessageReader => _incomingChannel.Reader;
+ public ChannelReader MessageReader => _incomingChannel.Reader;
///
- public ValueTask DisposeAsync()
+ public async ValueTask DisposeAsync()
{
- IsConnected = false;
+ _isConnected = false;
_incomingChannel.Writer.TryComplete();
- _outgoingSseChannel.Writer.TryComplete();
- return new ValueTask(_sseWriteTask ?? Task.CompletedTask);
+ await _sseWriter.DisposeAsync().ConfigureAwait(false);
}
///
- public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
- if (!IsConnected)
- {
- throw new InvalidOperationException($"Transport is not connected. Make sure to call {nameof(RunAsync)} first.");
- }
-
- // Emit redundant "event: message" lines for better compatibility with other SDKs.
- await _outgoingSseChannel.Writer.WriteAsync(new SseItem(message, SseParser.EventTypeDefault), cancellationToken).ConfigureAwait(false);
+ await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
}
///
@@ -111,34 +86,15 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca
/// sequencing of operations in the transport lifecycle.
///
///
- public async Task OnMessageReceivedAsync(IJsonRpcMessage message, CancellationToken cancellationToken)
+ public async Task OnMessageReceivedAsync(JsonRpcMessage message, CancellationToken cancellationToken)
{
- if (!IsConnected)
+ Throw.IfNull(message);
+
+ if (!_isConnected)
{
throw new InvalidOperationException($"Transport is not connected. Make sure to call {nameof(RunAsync)} first.");
}
await _incomingChannel.Writer.WriteAsync(message, cancellationToken).ConfigureAwait(false);
}
-
- private static Channel CreateBoundedChannel(int capacity = 1) =>
- Channel.CreateBounded(new BoundedChannelOptions(capacity)
- {
- SingleReader = true,
- SingleWriter = false,
- });
-
- private Utf8JsonWriter GetUtf8JsonWriter(IBufferWriter writer)
- {
- if (_jsonWriter is null)
- {
- _jsonWriter = new Utf8JsonWriter(writer);
- }
- else
- {
- _jsonWriter.Reset(writer);
- }
-
- return _jsonWriter;
- }
}
diff --git a/src/ModelContextProtocol/Protocol/Transport/SseWriter.cs b/src/ModelContextProtocol/Protocol/Transport/SseWriter.cs
new file mode 100644
index 000000000..a3eb0ce46
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/SseWriter.cs
@@ -0,0 +1,120 @@
+using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Utils;
+using ModelContextProtocol.Utils.Json;
+using System.Buffers;
+using System.Net.ServerSentEvents;
+using System.Text;
+using System.Text.Json;
+using System.Threading.Channels;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+internal sealed class SseWriter(string? messageEndpoint = null, BoundedChannelOptions? channelOptions = null) : IAsyncDisposable
+{
+ private readonly Channel> _messages = Channel.CreateBounded>(channelOptions ?? new BoundedChannelOptions(1)
+ {
+ SingleReader = true,
+ SingleWriter = false,
+ });
+
+ private Utf8JsonWriter? _jsonWriter;
+ private Task? _writeTask;
+ private CancellationToken? _writeCancellationToken;
+
+ private readonly SemaphoreSlim _disposeLock = new(1, 1);
+ private bool _disposed;
+
+ public Func>, CancellationToken, IAsyncEnumerable>>? MessageFilter { get; set; }
+
+ public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellationToken)
+ {
+ // When messageEndpoint is set, the very first SSE event isn't really an IJsonRpcMessage, but there's no API to write a single
+ // item of a different type, so we fib and special-case the "endpoint" event type in the formatter.
+ if (messageEndpoint is not null && !_messages.Writer.TryWrite(new SseItem(null, "endpoint")))
+ {
+ throw new InvalidOperationException("You must call RunAsync before calling SendMessageAsync.");
+ }
+
+ _writeCancellationToken = cancellationToken;
+
+ var messages = _messages.Reader.ReadAllAsync(cancellationToken);
+ if (MessageFilter is not null)
+ {
+ messages = MessageFilter(messages, cancellationToken);
+ }
+
+ _writeTask = SseFormatter.WriteAsync(messages, sseResponseStream, WriteJsonRpcMessageToBuffer, cancellationToken);
+ return _writeTask;
+ }
+
+ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ Throw.IfNull(message);
+
+ using var _ = await _disposeLock.LockAsync().ConfigureAwait(false);
+
+ if (_disposed)
+ {
+ // Don't throw an ODE, because this is disposed internally when the transport disconnects due to an abort
+ // or sending all the responses for the a give given Streamable HTTP POST request, so the user might not be at fault.
+ // There's precedence for no-oping here similar to writing to the response body of an aborted request in ASP.NET Core.
+ return;
+ }
+
+ // Emit redundant "event: message" lines for better compatibility with other SDKs.
+ await _messages.Writer.WriteAsync(new SseItem(message, SseParser.EventTypeDefault), cancellationToken).ConfigureAwait(false);
+ }
+
+ public async ValueTask DisposeAsync()
+ {
+ using var _ = await _disposeLock.LockAsync().ConfigureAwait(false);
+
+ if (_disposed)
+ {
+ return;
+ }
+
+ _messages.Writer.Complete();
+ try
+ {
+ if (_writeTask is not null)
+ {
+ await _writeTask.ConfigureAwait(false);
+ }
+ }
+ catch (OperationCanceledException) when (_writeCancellationToken?.IsCancellationRequested == true)
+ {
+ // Ignore exceptions caused by intentional cancellation during shutdown.
+ }
+ finally
+ {
+ _jsonWriter?.Dispose();
+ _disposed = true;
+ }
+ }
+
+ private void WriteJsonRpcMessageToBuffer(SseItem item, IBufferWriter writer)
+ {
+ if (item.EventType == "endpoint" && messageEndpoint is not null)
+ {
+ writer.Write(Encoding.UTF8.GetBytes(messageEndpoint));
+ return;
+ }
+
+ JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage!);
+ }
+
+ private Utf8JsonWriter GetUtf8JsonWriter(IBufferWriter writer)
+ {
+ if (_jsonWriter is null)
+ {
+ _jsonWriter = new Utf8JsonWriter(writer);
+ }
+ else
+ {
+ _jsonWriter.Reset(writer);
+ }
+
+ return _jsonWriter;
+ }
+}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs
index aa92c9d4b..015b74913 100644
--- a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs
@@ -33,7 +33,7 @@ public StdioClientSessionTransport(StdioClientTransportOptions options, Process
///
/// Thrown when the underlying process has exited or cannot be accessed.
///
- public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
Exception? processException = null;
bool hasExited = false;
diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs
index 3c24416bb..6fcdf0a8b 100644
--- a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs
@@ -57,7 +57,7 @@ public StreamClientSessionTransport(
}
///
- public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
if (!IsConnected)
{
@@ -65,12 +65,12 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio
}
string id = "(no id)";
- if (message is IJsonRpcMessageWithId messageWithId)
+ if (message is JsonRpcMessageWithId messageWithId)
{
id = messageWithId.Id.ToString();
}
- var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)));
+ var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)));
using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);
try
@@ -143,11 +143,11 @@ private async Task ProcessMessageAsync(string line, CancellationToken cancellati
{
try
{
- var message = (IJsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)));
+ var message = (JsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)));
if (message != null)
{
string messageId = "(no id)";
- if (message is IJsonRpcMessageWithId messageWithId)
+ if (message is JsonRpcMessageWithId messageWithId)
{
messageId = messageWithId.Id.ToString();
}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs
index c94509955..acf18984a 100644
--- a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs
@@ -56,7 +56,7 @@ public StreamServerTransport(Stream inputStream, Stream outputStream, string? se
}
///
- public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
if (!IsConnected)
{
@@ -66,14 +66,14 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio
using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);
string id = "(no id)";
- if (message is IJsonRpcMessageWithId messageWithId)
+ if (message is JsonRpcMessageWithId messageWithId)
{
id = messageWithId.Id.ToString();
}
try
{
- await JsonSerializer.SerializeAsync(_outputStream, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), cancellationToken).ConfigureAwait(false);
+ await JsonSerializer.SerializeAsync(_outputStream, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)), cancellationToken).ConfigureAwait(false);
await _outputStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false);
await _outputStream.FlushAsync(cancellationToken).ConfigureAwait(false);
}
@@ -109,10 +109,10 @@ private async Task ReadMessagesAsync()
try
{
- if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage))) is IJsonRpcMessage message)
+ if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) is JsonRpcMessage message)
{
string messageId = "(no id)";
- if (message is IJsonRpcMessageWithId messageWithId)
+ if (message is JsonRpcMessageWithId messageWithId)
{
messageId = messageWithId.Id.ToString();
}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs
new file mode 100644
index 000000000..2bd8f2784
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs
@@ -0,0 +1,134 @@
+using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Utils;
+using ModelContextProtocol.Utils.Json;
+using System.Buffers;
+using System.IO.Pipelines;
+using System.Net.ServerSentEvents;
+using System.Runtime.CompilerServices;
+using System.Text.Json;
+using System.Threading.Channels;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+///
+/// Handles processing the request/response body pairs for the Streamable HTTP transport.
+/// This is typically used via .
+///
+internal sealed class StreamableHttpPostTransport(ChannelWriter? incomingChannel, IDuplexPipe httpBodies) : ITransport
+{
+ private readonly SseWriter _sseWriter = new();
+ private readonly HashSet _pendingRequests = [];
+
+ // REVIEW: Should we introduce a send-only interface for RelatedTransport?
+ public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.RelatedTransport should only be used for sending messages.");
+
+ ///
+ /// True, if data was written to the respond body.
+ /// False, if nothing was written because the request body did not contain any messages to respond to.
+ /// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario.
+ ///
+ public async ValueTask RunAsync(CancellationToken cancellationToken)
+ {
+ // The incomingChannel is null to handle the potential client GET request to handle unsolicited JsonRpcMessages.
+ if (incomingChannel is not null)
+ {
+ await OnPostBodyReceivedAsync(httpBodies.Input, cancellationToken).ConfigureAwait(false);
+ }
+
+ if (_pendingRequests.Count == 0)
+ {
+ return false;
+ }
+
+ _sseWriter.MessageFilter = StopOnFinalResponseFilter;
+ await _sseWriter.WriteAllAsync(httpBodies.Output.AsStream(), cancellationToken).ConfigureAwait(false);
+ return true;
+ }
+
+ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
+ }
+
+ public async ValueTask DisposeAsync()
+ {
+ await _sseWriter.DisposeAsync().ConfigureAwait(false);
+ }
+
+ private async IAsyncEnumerable> StopOnFinalResponseFilter(IAsyncEnumerable> messages, [EnumeratorCancellation] CancellationToken cancellationToken)
+ {
+ await foreach (var message in messages.WithCancellation(cancellationToken))
+ {
+ yield return message;
+
+ if (message.Data is JsonRpcResponse response)
+ {
+ if (_pendingRequests.Remove(response.Id) && _pendingRequests.Count == 0)
+ {
+ // Complete the SSE response stream now that all pending requests have been processed.
+ break;
+ }
+ }
+ }
+ }
+
+ private async ValueTask OnPostBodyReceivedAsync(PipeReader streamableHttpRequestBody, CancellationToken cancellationToken)
+ {
+ if (!await IsJsonArrayAsync(streamableHttpRequestBody, cancellationToken).ConfigureAwait(false))
+ {
+ var message = await JsonSerializer.DeserializeAsync(streamableHttpRequestBody.AsStream(), McpJsonUtilities.JsonContext.Default.JsonRpcMessage, cancellationToken).ConfigureAwait(false);
+ await OnMessageReceivedAsync(message, cancellationToken).ConfigureAwait(false);
+ }
+ else
+ {
+ // Batched JSON-RPC message
+ var messages = JsonSerializer.DeserializeAsyncEnumerable(streamableHttpRequestBody.AsStream(), McpJsonUtilities.JsonContext.Default.JsonRpcMessage, cancellationToken).ConfigureAwait(false);
+ await foreach (var message in messages.WithCancellation(cancellationToken))
+ {
+ await OnMessageReceivedAsync(message, cancellationToken).ConfigureAwait(false);
+ }
+ }
+ }
+
+ private async ValueTask OnMessageReceivedAsync(JsonRpcMessage? message, CancellationToken cancellationToken)
+ {
+ if (message is null)
+ {
+ throw new InvalidOperationException("Received invalid null message.");
+ }
+
+ if (message is JsonRpcRequest request)
+ {
+ _pendingRequests.Add(request.Id);
+ }
+
+ message.RelatedTransport = this;
+
+ // Really an assertion. This doesn't get called when incomingChannel is null for GET requests.
+ Throw.IfNull(incomingChannel);
+ await incomingChannel.WriteAsync(message, cancellationToken).ConfigureAwait(false);
+ }
+
+ private async ValueTask IsJsonArrayAsync(PipeReader requestBody, CancellationToken cancellationToken)
+ {
+ // REVIEW: Should we bother trimming whitespace before checking for '['?
+ var firstCharacterResult = await requestBody.ReadAtLeastAsync(1, cancellationToken).ConfigureAwait(false);
+
+ try
+ {
+ if (firstCharacterResult.Buffer.Length == 0)
+ {
+ return false;
+ }
+
+ Span firstCharBuffer = stackalloc byte[1];
+ firstCharacterResult.Buffer.Slice(0, 1).CopyTo(firstCharBuffer);
+ return firstCharBuffer[0] == (byte)'[';
+ }
+ finally
+ {
+ // Never consume data when checking for '['. System.Text.Json still needs to consume it.
+ requestBody.AdvanceTo(firstCharacterResult.Buffer.Start);
+ }
+ }
+}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs
new file mode 100644
index 000000000..42e3ff70b
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs
@@ -0,0 +1,99 @@
+using ModelContextProtocol.Protocol.Messages;
+using System.IO.Pipelines;
+using System.Threading.Channels;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+///
+/// Provides an implementation using Server-Sent Events (SSE) for server-to-client communication.
+///
+///
+///
+/// This transport provides one-way communication from server to client using the SSE protocol over HTTP,
+/// while receiving client messages through a separate mechanism. It writes messages as
+/// SSE events to a response stream, typically associated with an HTTP response.
+///
+///
+/// This transport is used in scenarios where the server needs to push messages to the client in real-time,
+/// such as when streaming completion results or providing progress updates during long-running operations.
+///
+///
+public sealed class StreamableHttpServerTransport : ITransport
+{
+ // For JsonRpcMessages without a RelatedTransport, we don't want to block just because the client didn't make a GET request to handle unsolicited messages.
+ private readonly SseWriter _sseWriter = new(channelOptions: new BoundedChannelOptions(1)
+ {
+ SingleReader = true,
+ SingleWriter = false,
+ FullMode = BoundedChannelFullMode.DropOldest,
+ });
+ private readonly Channel _incomingChannel = Channel.CreateBounded(new BoundedChannelOptions(1)
+ {
+ SingleReader = true,
+ SingleWriter = false,
+ });
+ private readonly CancellationTokenSource _disposeCts = new();
+
+ private int _getRequestStarted;
+
+ ///
+ /// Handles an optional SSE GET request a client using the Streamable HTTP transport might make by
+ /// writing any unsolicited JSON-RPC messages sent via
+ /// to the SSE response stream until cancellation is requested or the transport is disposed.
+ ///
+ /// The response stream to write MCP JSON-RPC messages as SSE events to.
+ /// The to monitor for cancellation requests. The default is .
+ /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream.
+ public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken cancellationToken)
+ {
+ if (Interlocked.Exchange(ref _getRequestStarted, 1) == 1)
+ {
+ throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session.");
+ }
+
+ using var getCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, cancellationToken);
+ await _sseWriter.WriteAllAsync(sseResponseStream, getCts.Token).ConfigureAwait(false);
+ }
+
+ ///
+ /// Handles a Streamable HTTP POST request processing both the request body and response body ensuring that
+ /// and other correlated messages are sent back to the client directly in response
+ /// to the that initiated the message.
+ ///
+ /// The duplex pipe facilitates the reading and writing of HTTP request and response data.
+ /// This token allows for the operation to be canceled if needed.
+ ///
+ /// True, if data was written to the respond body.
+ /// False, if nothing was written because the request body did not contain any messages to respond to.
+ /// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario.
+ ///
+ public async Task HandlePostRequest(IDuplexPipe httpBodies, CancellationToken cancellationToken)
+ {
+ using var postCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, cancellationToken);
+ await using var postTransport = new StreamableHttpPostTransport(_incomingChannel.Writer, httpBodies);
+ return await postTransport.RunAsync(postCts.Token).ConfigureAwait(false);
+ }
+
+ ///
+ public ChannelReader MessageReader => _incomingChannel.Reader;
+
+ ///
+ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
+ }
+
+ ///
+ public async ValueTask DisposeAsync()
+ {
+ _disposeCts.Cancel();
+ try
+ {
+ await _sseWriter.DisposeAsync().ConfigureAwait(false);
+ }
+ finally
+ {
+ _disposeCts.Dispose();
+ }
+ }
+}
diff --git a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs
index 0d496d0a6..af9cdaefd 100644
--- a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs
@@ -15,13 +15,13 @@ namespace ModelContextProtocol.Protocol.Transport;
///
///
/// Custom transport implementations should inherit from this class and implement the abstract
-/// and methods
+/// and methods
/// to handle the specific transport mechanism being used.
///
///
public abstract partial class TransportBase : ITransport
{
- private readonly Channel _messageChannel;
+ private readonly Channel _messageChannel;
private readonly ILogger _logger;
private int _isConnected;
@@ -34,7 +34,7 @@ protected TransportBase(string name, ILoggerFactory? loggerFactory)
_logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance;
// Unbounded channel to prevent blocking on writes
- _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions
+ _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions
{
SingleReader = true,
SingleWriter = true,
@@ -56,10 +56,10 @@ protected TransportBase(string name, ILoggerFactory? loggerFactory)
public bool IsConnected => _isConnected == 1;
///
- public ChannelReader MessageReader => _messageChannel.Reader;
+ public ChannelReader MessageReader => _messageChannel.Reader;
///
- public abstract Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default);
+ public abstract Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default);
///
public abstract ValueTask DisposeAsync();
@@ -69,7 +69,7 @@ protected TransportBase(string name, ILoggerFactory? loggerFactory)
///
/// The message to write.
/// The to monitor for cancellation requests. The default is .
- protected async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ protected async Task WriteMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
if (!IsConnected)
{
diff --git a/src/ModelContextProtocol/Server/DestinationBoundMcpServer.cs b/src/ModelContextProtocol/Server/DestinationBoundMcpServer.cs
new file mode 100644
index 000000000..0d86480de
--- /dev/null
+++ b/src/ModelContextProtocol/Server/DestinationBoundMcpServer.cs
@@ -0,0 +1,37 @@
+using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Protocol.Transport;
+using ModelContextProtocol.Protocol.Types;
+using System.Diagnostics;
+
+namespace ModelContextProtocol.Server;
+
+internal sealed class DestinationBoundMcpServer(McpServer server, ITransport? transport) : IMcpServer
+{
+ public string EndpointName => server.EndpointName;
+ public ClientCapabilities? ClientCapabilities => server.ClientCapabilities;
+ public Implementation? ClientInfo => server.ClientInfo;
+ public McpServerOptions ServerOptions => server.ServerOptions;
+ public IServiceProvider? Services => server.Services;
+ public LoggingLevel? LoggingLevel => server.LoggingLevel;
+
+ public ValueTask DisposeAsync() => server.DisposeAsync();
+
+ public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => server.RegisterNotificationHandler(method, handler);
+
+ // This will throw because the server must already be running for this class to be constructed, but it should give us a good Exception message.
+ public Task RunAsync(CancellationToken cancellationToken) => server.RunAsync(cancellationToken);
+
+ public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ Debug.Assert(message.RelatedTransport is null);
+ message.RelatedTransport = transport;
+ return server.SendMessageAsync(message, cancellationToken);
+ }
+
+ public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default)
+ {
+ Debug.Assert(request.RelatedTransport is null);
+ request.RelatedTransport = transport;
+ return server.SendRequestAsync(request, cancellationToken);
+ }
+}
diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs
index 2f7b59a90..ae0e7afc5 100644
--- a/src/ModelContextProtocol/Server/McpServer.cs
+++ b/src/ModelContextProtocol/Server/McpServer.cs
@@ -7,8 +7,7 @@
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using System.Runtime.CompilerServices;
-
-#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
+using System.Text.Json.Serialization.Metadata;
namespace ModelContextProtocol.Server;
@@ -158,7 +157,7 @@ public override async ValueTask DisposeUnsynchronizedAsync()
private void SetPingHandler()
{
- RequestHandlers.Set(RequestMethods.Ping,
+ SetHandler(RequestMethods.Ping,
async (request, _) => new PingResult(),
McpJsonUtilities.JsonContext.Default.JsonNode,
McpJsonUtilities.JsonContext.Default.PingResult);
@@ -167,7 +166,7 @@ private void SetPingHandler()
private void SetInitializeHandler(McpServerOptions options)
{
RequestHandlers.Set(RequestMethods.Initialize,
- async (request, _) =>
+ async (request, _, _) =>
{
ClientCapabilities = request?.Capabilities ?? new();
ClientInfo = request?.ClientInfo;
@@ -201,9 +200,9 @@ private void SetCompletionHandler(McpServerOptions options)
$"but {nameof(CompletionsCapability.CompleteHandler)} was not specified.");
// This capability is not optional, so return an empty result if there is no handler.
- RequestHandlers.Set(
+ SetHandler(
RequestMethods.CompletionComplete,
- (request, cancellationToken) => InvokeHandlerAsync(completeHandler, request, cancellationToken),
+ completeHandler,
McpJsonUtilities.JsonContext.Default.CompleteRequestParams,
McpJsonUtilities.JsonContext.Default.CompleteResult);
}
@@ -228,22 +227,22 @@ private void SetResourcesHandler(McpServerOptions options)
listResourcesHandler ??= static async (_, _) => new ListResourcesResult();
- RequestHandlers.Set(
+ SetHandler(
RequestMethods.ResourcesList,
- (request, cancellationToken) => InvokeHandlerAsync(listResourcesHandler, request, cancellationToken),
+ listResourcesHandler,
McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams,
McpJsonUtilities.JsonContext.Default.ListResourcesResult);
- RequestHandlers.Set(
+ SetHandler(
RequestMethods.ResourcesRead,
- (request, cancellationToken) => InvokeHandlerAsync(readResourceHandler, request, cancellationToken),
+ readResourceHandler,
McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams,
McpJsonUtilities.JsonContext.Default.ReadResourceResult);
listResourceTemplatesHandler ??= static async (_, _) => new ListResourceTemplatesResult();
- RequestHandlers.Set(
+ SetHandler(
RequestMethods.ResourcesTemplatesList,
- (request, cancellationToken) => InvokeHandlerAsync(listResourceTemplatesHandler, request, cancellationToken),
+ listResourceTemplatesHandler,
McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams,
McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult);
@@ -261,15 +260,15 @@ private void SetResourcesHandler(McpServerOptions options)
$"but {nameof(ResourcesCapability.SubscribeToResourcesHandler)} or {nameof(ResourcesCapability.UnsubscribeFromResourcesHandler)} was not specified.");
}
- RequestHandlers.Set(
+ SetHandler(
RequestMethods.ResourcesSubscribe,
- (request, cancellationToken) => InvokeHandlerAsync(subscribeHandler, request, cancellationToken),
+ subscribeHandler,
McpJsonUtilities.JsonContext.Default.SubscribeRequestParams,
McpJsonUtilities.JsonContext.Default.EmptyResult);
- RequestHandlers.Set(
+ SetHandler(
RequestMethods.ResourcesUnsubscribe,
- (request, cancellationToken) => InvokeHandlerAsync(unsubscribeHandler, request, cancellationToken),
+ unsubscribeHandler,
McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams,
McpJsonUtilities.JsonContext.Default.EmptyResult);
}
@@ -359,15 +358,15 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals
}
}
- RequestHandlers.Set(
+ SetHandler(
RequestMethods.PromptsList,
- (request, cancellationToken) => InvokeHandlerAsync(listPromptsHandler, request, cancellationToken),
+ listPromptsHandler,
McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams,
McpJsonUtilities.JsonContext.Default.ListPromptsResult);
- RequestHandlers.Set(
+ SetHandler(
RequestMethods.PromptsGet,
- (request, cancellationToken) => InvokeHandlerAsync(getPromptHandler, request, cancellationToken),
+ getPromptHandler,
McpJsonUtilities.JsonContext.Default.GetPromptRequestParams,
McpJsonUtilities.JsonContext.Default.GetPromptResult);
}
@@ -457,15 +456,15 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false)
}
}
- RequestHandlers.Set(
+ SetHandler(
RequestMethods.ToolsList,
- (request, cancellationToken) => InvokeHandlerAsync(listToolsHandler, request, cancellationToken),
+ listToolsHandler,
McpJsonUtilities.JsonContext.Default.ListToolsRequestParams,
McpJsonUtilities.JsonContext.Default.ListToolsResult);
- RequestHandlers.Set(
+ SetHandler(
RequestMethods.ToolsCall,
- (request, cancellationToken) => InvokeHandlerAsync(callToolHandler, request, cancellationToken),
+ callToolHandler,
McpJsonUtilities.JsonContext.Default.CallToolRequestParams,
McpJsonUtilities.JsonContext.Default.CallToolResponse);
}
@@ -478,7 +477,7 @@ private void SetSetLoggingLevelHandler(McpServerOptions options)
RequestHandlers.Set(
RequestMethods.LoggingSetLevel,
- (request, cancellationToken) =>
+ (request, destinationTransport, cancellationToken) =>
{
// Store the provided level.
if (request is not null)
@@ -494,7 +493,7 @@ private void SetSetLoggingLevelHandler(McpServerOptions options)
// If a handler was provided, now delegate to it.
if (setLoggingLevelHandler is not null)
{
- return InvokeHandlerAsync(setLoggingLevelHandler, request, cancellationToken);
+ return InvokeHandlerAsync(setLoggingLevelHandler, request, destinationTransport, cancellationToken);
}
// Otherwise, consider it handled.
@@ -507,11 +506,12 @@ private void SetSetLoggingLevelHandler(McpServerOptions options)
private ValueTask InvokeHandlerAsync(
Func, CancellationToken, ValueTask> handler,
TParams? args,
- CancellationToken cancellationToken)
+ ITransport? destinationTransport = null,
+ CancellationToken cancellationToken = default)
{
return _servicesScopePerRequest ?
InvokeScopedAsync(handler, args, cancellationToken) :
- handler(new(this) { Params = args }, cancellationToken);
+ handler(new(new DestinationBoundMcpServer(this, destinationTransport)) { Params = args }, cancellationToken);
async ValueTask InvokeScopedAsync(
Func, CancellationToken, ValueTask> handler,
@@ -522,7 +522,7 @@ async ValueTask InvokeScopedAsync(
try
{
return await handler(
- new RequestContext(this)
+ new RequestContext(new DestinationBoundMcpServer(this, destinationTransport))
{
Services = scope?.ServiceProvider ?? Services,
Params = args
@@ -539,6 +539,18 @@ async ValueTask InvokeScopedAsync(
}
}
+ private void SetHandler(
+ string method,
+ Func, CancellationToken, ValueTask> handler,
+ JsonTypeInfo requestTypeInfo,
+ JsonTypeInfo responseTypeInfo)
+ {
+ RequestHandlers.Set(method,
+ (request, destinationTransport, cancellationToken) =>
+ InvokeHandlerAsync(handler, request, destinationTransport, cancellationToken),
+ requestTypeInfo, responseTypeInfo);
+ }
+
/// Maps a to a .
internal static LoggingLevel ToLoggingLevel(LogLevel level) =>
level switch
@@ -551,4 +563,4 @@ internal static LoggingLevel ToLoggingLevel(LogLevel level) =>
LogLevel.Critical => Protocol.Types.LoggingLevel.Critical,
_ => Protocol.Types.LoggingLevel.Emergency,
};
-}
\ No newline at end of file
+}
diff --git a/src/ModelContextProtocol/Shared/McpEndpoint.cs b/src/ModelContextProtocol/Shared/McpEndpoint.cs
index 45566cefe..394ccaa7b 100644
--- a/src/ModelContextProtocol/Shared/McpEndpoint.cs
+++ b/src/ModelContextProtocol/Shared/McpEndpoint.cs
@@ -45,7 +45,7 @@ protected McpEndpoint(ILoggerFactory? loggerFactory = null)
public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default)
=> GetSessionOrThrow().SendRequestAsync(request, cancellationToken);
- public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
=> GetSessionOrThrow().SendMessageAsync(message, cancellationToken);
public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) =>
diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs
index 1b51fc8fd..56f0674c9 100644
--- a/src/ModelContextProtocol/Shared/McpSession.cs
+++ b/src/ModelContextProtocol/Shared/McpSession.cs
@@ -39,7 +39,7 @@ internal sealed partial class McpSession : IDisposable
private readonly DistributedContextPropagator _propagator = DistributedContextPropagator.Current;
/// Collection of requests sent on this session and waiting for responses.
- private readonly ConcurrentDictionary> _pendingRequests = [];
+ private readonly ConcurrentDictionary> _pendingRequests = [];
///
/// Collection of requests received on this session and currently being handled. The value provides a
/// that can be used to request cancellation of the in-flight handler.
@@ -47,8 +47,9 @@ internal sealed partial class McpSession : IDisposable
private readonly ConcurrentDictionary _handlingRequests = new();
private readonly ILogger _logger;
- private readonly string _id = Guid.NewGuid().ToString("N");
- private long _nextRequestId;
+ // This _sessionId is solely used to identify the session in telemetry and logs.
+ private readonly string _sessionId = Guid.NewGuid().ToString("N");
+ private long _lastRequestId;
///
/// Initializes a new instance of the class.
@@ -74,6 +75,7 @@ public McpSession(
StdioClientSessionTransport or StdioServerTransport => "stdio",
StreamClientSessionTransport or StreamServerTransport => "stream",
SseClientSessionTransport or SseResponseStreamTransport => "sse",
+ StreamableHttpServerTransport or StreamableHttpPostTransport => "http",
_ => "unknownTransport"
};
@@ -102,10 +104,11 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken)
{
LogMessageRead(EndpointName, message.GetType().Name);
+ // Fire and forget the message handling to avoid blocking the transport.
_ = ProcessMessageAsync();
async Task ProcessMessageAsync()
{
- IJsonRpcMessageWithId? messageWithId = message as IJsonRpcMessageWithId;
+ JsonRpcMessageWithId? messageWithId = message as JsonRpcMessageWithId;
CancellationTokenSource? combinedCts = null;
try
{
@@ -118,10 +121,8 @@ async Task ProcessMessageAsync()
_handlingRequests[messageWithId.Id] = combinedCts;
}
- // Fire and forget the message handling to avoid blocking the transport
- // If awaiting the task, the transport will not be able to read more messages,
- // which could lead to a deadlock if the handler sends a message back
-
+ // If we await the handler without yielding first, the transport may not be able to read more messages,
+ // which could lead to a deadlock if the handler sends a message back.
#if NET
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
#else
@@ -155,18 +156,19 @@ ex is OperationCanceledException &&
Message = "An error occurred.",
};
- await _transport.SendMessageAsync(new JsonRpcError
+ await SendMessageAsync(new JsonRpcError
{
Id = request.Id,
JsonRpc = "2.0",
Error = detail,
+ RelatedTransport = request.RelatedTransport,
}, cancellationToken).ConfigureAwait(false);
}
else if (ex is not OperationCanceledException)
{
if (_logger.IsEnabled(LogLevel.Trace))
{
- LogMessageHandlerExceptionSensitive(EndpointName, message.GetType().Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage), ex);
+ LogMessageHandlerExceptionSensitive(EndpointName, message.GetType().Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), ex);
}
else
{
@@ -200,7 +202,7 @@ await _transport.SendMessageAsync(new JsonRpcError
}
}
- private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken)
+ private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken)
{
Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration;
string method = GetMethodName(message);
@@ -235,7 +237,7 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken
await HandleNotification(notification, cancellationToken).ConfigureAwait(false);
break;
- case IJsonRpcMessageWithId messageWithId:
+ case JsonRpcMessageWithId messageWithId:
HandleMessageWithId(message, messageWithId);
break;
@@ -279,7 +281,7 @@ private async Task HandleNotification(JsonRpcNotification notification, Cancella
await _notificationHandlers.InvokeHandlers(notification.Method, notification, cancellationToken).ConfigureAwait(false);
}
- private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId messageWithId)
+ private void HandleMessageWithId(JsonRpcMessage message, JsonRpcMessageWithId messageWithId)
{
if (_pendingRequests.TryRemove(messageWithId.Id, out var tcs))
{
@@ -303,17 +305,17 @@ private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId
JsonNode? result = await handler(request, cancellationToken).ConfigureAwait(false);
LogRequestHandlerCompleted(EndpointName, request.Method);
- await _transport.SendMessageAsync(new JsonRpcResponse
+ await SendMessageAsync(new JsonRpcResponse
{
Id = request.Id,
- JsonRpc = "2.0",
- Result = result
+ Result = result,
+ RelatedTransport = request.RelatedTransport,
}, cancellationToken).ConfigureAwait(false);
return result;
}
- private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, RequestId requestId)
+ private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, JsonRpcRequest request)
{
if (!cancellationToken.CanBeCanceled)
{
@@ -322,13 +324,14 @@ private CancellationTokenRegistration RegisterCancellation(CancellationToken can
return cancellationToken.Register(static objState =>
{
- var state = (Tuple)objState!;
+ var state = (Tuple)objState!;
_ = state.Item1.SendMessageAsync(new JsonRpcNotification
{
Method = NotificationMethods.CancelledNotification,
- Params = JsonSerializer.SerializeToNode(new CancelledNotification { RequestId = state.Item2 }, McpJsonUtilities.JsonContext.Default.CancelledNotification)
+ Params = JsonSerializer.SerializeToNode(new CancelledNotification { RequestId = state.Item2.Id }, McpJsonUtilities.JsonContext.Default.CancelledNotification),
+ RelatedTransport = state.Item2.RelatedTransport,
});
- }, Tuple.Create(this, requestId));
+ }, Tuple.Create(this, request));
}
public IAsyncDisposable RegisterNotificationHandler(string method, Func handler)
@@ -349,11 +352,6 @@ public IAsyncDisposable RegisterNotificationHandler(string method, FuncA task containing the server's response.
public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken)
{
- if (!_transport.IsConnected)
- {
- throw new InvalidOperationException("Transport is not connected");
- }
-
cancellationToken.ThrowIfCancellationRequested();
Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration;
@@ -367,7 +365,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc
// Set request ID
if (request.Id.Id is null)
{
- request.Id = new RequestId($"{_id}-{Interlocked.Increment(ref _nextRequestId)}");
+ request = request.WithId(new RequestId(Interlocked.Increment(ref _lastRequestId)));
}
_propagator.InjectActivityContext(activity, request);
@@ -375,7 +373,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc
TagList tags = default;
bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null;
- var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
_pendingRequests[request.Id] = tcs;
try
{
@@ -386,21 +384,21 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc
if (_logger.IsEnabled(LogLevel.Trace))
{
- LogSendingRequestSensitive(EndpointName, request.Method, JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage));
+ LogSendingRequestSensitive(EndpointName, request.Method, JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.JsonRpcMessage));
}
else
{
LogSendingRequest(EndpointName, request.Method);
}
- await _transport.SendMessageAsync(request, cancellationToken).ConfigureAwait(false);
+ await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false);
// Now that the request has been sent, register for cancellation. If we registered before,
// a cancellation request could arrive before the server knew about that request ID, in which
// case the server could ignore it.
LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id);
- IJsonRpcMessage? response;
- using (var registration = RegisterCancellation(cancellationToken, request.Id))
+ JsonRpcMessage? response;
+ using (var registration = RegisterCancellation(cancellationToken, request))
{
response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
}
@@ -446,15 +444,10 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc
}
}
- public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
Throw.IfNull(message);
- if (!_transport.IsConnected)
- {
- throw new InvalidOperationException("Transport is not connected");
- }
-
cancellationToken.ThrowIfCancellationRequested();
Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration;
@@ -480,14 +473,14 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca
if (_logger.IsEnabled(LogLevel.Trace))
{
- LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage));
+ LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage));
}
else
{
LogSendingMessage(EndpointName);
}
- await _transport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
+ await SendToRelatedTransportAsync(message, cancellationToken).ConfigureAwait(false);
// If the sent notification was a cancellation notification, cancel the pending request's await, as either the
// server won't be sending a response, or per the specification, the response should be ignored. There are inherent
@@ -510,6 +503,12 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca
}
}
+ // The JsonRpcMessage should be sent over the RelatedTransport if set. This is used to support the
+ // Streamable HTTP transport where the specification states that the server SHOULD include JSON-RPC responses in
+ // the HTTP response body for the POST request containing the corresponding JSON-RPC request.
+ private Task SendToRelatedTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken)
+ => (message.RelatedTransport ?? _transport).SendMessageAsync(message, cancellationToken);
+
private static CancelledNotification? GetCancelledNotificationParams(JsonNode? notificationParams)
{
try
@@ -524,7 +523,7 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca
private string CreateActivityName(string method) => method;
- private static string GetMethodName(IJsonRpcMessage message) =>
+ private static string GetMethodName(JsonRpcMessage message) =>
message switch
{
JsonRpcRequest request => request.Method,
@@ -532,7 +531,7 @@ private static string GetMethodName(IJsonRpcMessage message) =>
_ => "unknownMethod"
};
- private void AddTags(ref TagList tags, Activity? activity, IJsonRpcMessage message, string method)
+ private void AddTags(ref TagList tags, Activity? activity, JsonRpcMessage message, string method)
{
tags.Add("mcp.method.name", method);
tags.Add("network.transport", _transportKind);
@@ -543,9 +542,9 @@ private void AddTags(ref TagList tags, Activity? activity, IJsonRpcMessage messa
if (activity is { IsAllDataRequested: true })
{
// session and request id have high cardinality, so not applying to metric tags
- activity.AddTag("mcp.session.id", _id);
+ activity.AddTag("mcp.session.id", _sessionId);
- if (message is IJsonRpcMessageWithId withId)
+ if (message is JsonRpcMessageWithId withId)
{
activity.AddTag("mcp.request.id", withId.Id.Id?.ToString());
}
diff --git a/src/ModelContextProtocol/Shared/RequestHandlers.cs b/src/ModelContextProtocol/Shared/RequestHandlers.cs
index 184fd9077..93a3dbbf4 100644
--- a/src/ModelContextProtocol/Shared/RequestHandlers.cs
+++ b/src/ModelContextProtocol/Shared/RequestHandlers.cs
@@ -1,4 +1,5 @@
using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Utils;
using System.Text.Json;
using System.Text.Json.Nodes;
@@ -30,7 +31,7 @@ internal sealed class RequestHandlers : Dictionary
public void Set(
string method,
- Func> handler,
+ Func> handler,
JsonTypeInfo requestTypeInfo,
JsonTypeInfo responseTypeInfo)
{
@@ -42,7 +43,7 @@ public void Set(
this[method] = async (request, cancellationToken) =>
{
TRequest? typedRequest = JsonSerializer.Deserialize(request.Params, requestTypeInfo);
- object? result = await handler(typedRequest, cancellationToken).ConfigureAwait(false);
+ object? result = await handler(typedRequest, request.RelatedTransport, cancellationToken).ConfigureAwait(false);
return JsonSerializer.SerializeToNode(result, responseTypeInfo);
};
}
diff --git a/src/ModelContextProtocol/Utils/Json/JsonRpcMessageConverter.cs b/src/ModelContextProtocol/Utils/Json/JsonRpcMessageConverter.cs
index 54fd4be05..146185cac 100644
--- a/src/ModelContextProtocol/Utils/Json/JsonRpcMessageConverter.cs
+++ b/src/ModelContextProtocol/Utils/Json/JsonRpcMessageConverter.cs
@@ -6,7 +6,7 @@
namespace ModelContextProtocol.Utils.Json;
///
-/// Provides a for messages,
+/// Provides a for messages,
/// handling polymorphic deserialization of different message types.
///
///
@@ -26,10 +26,10 @@ namespace ModelContextProtocol.Utils.Json;
///
///
[EditorBrowsable(EditorBrowsableState.Never)]
-public sealed class JsonRpcMessageConverter : JsonConverter
+public sealed class JsonRpcMessageConverter : JsonConverter
{
///
- public override IJsonRpcMessage? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ public override JsonRpcMessage? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
if (reader.TokenType != JsonTokenType.StartObject)
{
@@ -87,7 +87,7 @@ public sealed class JsonRpcMessageConverter : JsonConverter
}
///
- public override void Write(Utf8JsonWriter writer, IJsonRpcMessage value, JsonSerializerOptions options)
+ public override void Write(Utf8JsonWriter writer, JsonRpcMessage value, JsonSerializerOptions options)
{
switch (value)
{
diff --git a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs
index 1c6dbd9c0..b759ba975 100644
--- a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs
+++ b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs
@@ -82,7 +82,8 @@ internal static bool IsValidMcpToolSchema(JsonElement element)
NumberHandling = JsonNumberHandling.AllowReadingFromString)]
// JSON-RPC
- [JsonSerializable(typeof(IJsonRpcMessage))]
+ [JsonSerializable(typeof(JsonRpcMessage))]
+ [JsonSerializable(typeof(JsonRpcMessage[]))]
[JsonSerializable(typeof(JsonRpcRequest))]
[JsonSerializable(typeof(JsonRpcNotification))]
[JsonSerializable(typeof(JsonRpcResponse))]
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs
index f7cb2c8a6..2f4cd7fbe 100644
--- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs
@@ -13,7 +13,7 @@ namespace ModelContextProtocol.AspNetCore.Tests;
public class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper)
{
- private async Task ConnectAsync(string? path = null)
+ private async Task ConnectAsync(string? path = "/sse")
{
var sseClientTransportOptions = new SseClientTransportOptions()
{
@@ -48,9 +48,7 @@ public async Task Allows_Customizing_Route()
[Theory]
[InlineData("/a", "/a/sse")]
- [InlineData("/a", "/a/")]
[InlineData("/a/", "/a/sse")]
- [InlineData("/a/", "/a/")]
public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePattern, string requestPath)
{
Builder.Services.AddMcpServer(options =>
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs
index 8cb5cef1f..b659ff172 100644
--- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs
@@ -198,7 +198,7 @@ public async Task AdditionalHeaders_AreSent_InGetAndPostRequests()
}
[Fact]
- public async Task EmptyAdditionalHeadersKey_Throws_InvalidOpearionException()
+ public async Task EmptyAdditionalHeadersKey_Throws_InvalidOperationException()
{
Builder.Services.AddMcpServer()
.WithHttpTransport();
@@ -269,7 +269,7 @@ private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints)
await Results.BadRequest("Session not started.").ExecuteAsync(context);
return;
}
- var message = (IJsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), context.RequestAborted);
+ var message = (JsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)), context.RequestAborted);
if (message is null)
{
await Results.BadRequest("No message in request body.").ExecuteAsync(context);
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpTests.cs
new file mode 100644
index 000000000..cf3aa4f41
--- /dev/null
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpTests.cs
@@ -0,0 +1,510 @@
+using Microsoft.AspNetCore.Builder;
+using Microsoft.Extensions.DependencyInjection;
+using ModelContextProtocol.AspNetCore.Tests.Utils;
+using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Protocol.Types;
+using ModelContextProtocol.Server;
+using ModelContextProtocol.Utils.Json;
+using System.Net;
+using System.Net.ServerSentEvents;
+using System.Text;
+using System.Text.Json;
+using System.Text.Json.Nodes;
+using System.Text.Json.Serialization.Metadata;
+
+namespace ModelContextProtocol.AspNetCore.Tests;
+
+public class StreamableHttpTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper), IAsyncDisposable
+{
+ private static McpServerTool[] Tools { get; } = [
+ McpServerTool.Create(EchoAsync),
+ McpServerTool.Create(LongRunningAsync),
+ McpServerTool.Create(Progress),
+ McpServerTool.Create(Throw),
+ ];
+
+ private WebApplication? _app;
+
+ private async Task StartAsync()
+ {
+ AddDefaultHttpClientRequestHeaders();
+
+ Builder.Services.AddMcpServer(options =>
+ {
+ options.ServerInfo = new Implementation
+ {
+ Name = nameof(StreamableHttpTests),
+ Version = "73",
+ };
+ }).WithTools(Tools).WithHttpTransport();
+
+ _app = Builder.Build();
+
+ _app.MapMcp();
+
+ await _app.StartAsync(TestContext.Current.CancellationToken);
+ }
+
+ public async ValueTask DisposeAsync()
+ {
+ if (_app is not null)
+ {
+ await _app.DisposeAsync();
+ }
+ base.Dispose();
+ }
+
+ [Fact]
+ public async Task InitialPostResponse_Includes_McpSessionIdHeader()
+ {
+ await StartAsync();
+
+ using var response = await HttpClient.PostAsync("", JsonContent(InitializeRequest), TestContext.Current.CancellationToken);
+ Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+ Assert.Single(response.Headers.GetValues("mcp-session-id"));
+ Assert.Equal("text/event-stream", Assert.Single(response.Content.Headers.GetValues("content-type")));
+ }
+
+ [Fact]
+ public async Task PostRequest_IsUnsupportedMediaType_WithoutJsonContentType()
+ {
+ await StartAsync();
+
+ using var response = await HttpClient.PostAsync("", new StringContent(InitializeRequest, Encoding.UTF8, "text/javascript"), TestContext.Current.CancellationToken);
+ Assert.Equal(HttpStatusCode.UnsupportedMediaType, response.StatusCode);
+ }
+
+ [Fact]
+ public async Task PostRequest_IsNotAcceptable_WithoutApplicationJsonAcceptHeader()
+ {
+ await StartAsync();
+
+ HttpClient.DefaultRequestHeaders.Accept.Clear();
+ HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream"));
+
+ using var response = await HttpClient.PostAsync("", JsonContent(InitializeRequest), TestContext.Current.CancellationToken);
+ Assert.Equal(HttpStatusCode.NotAcceptable, response.StatusCode);
+ }
+
+
+ [Fact]
+ public async Task PostRequest_IsNotAcceptable_WithoutTextEventStreamAcceptHeader()
+ {
+ await StartAsync();
+
+ HttpClient.DefaultRequestHeaders.Accept.Clear();
+ HttpClient.DefaultRequestHeaders.Accept.Add(new("text/json"));
+
+ using var response = await HttpClient.PostAsync("", JsonContent(InitializeRequest), TestContext.Current.CancellationToken);
+ Assert.Equal(HttpStatusCode.NotAcceptable, response.StatusCode);
+ }
+
+ [Fact]
+ public async Task GetRequest_IsNotAcceptable_WithoutTextEventStreamAcceptHeader()
+ {
+ await StartAsync();
+
+ HttpClient.DefaultRequestHeaders.Accept.Clear();
+ HttpClient.DefaultRequestHeaders.Accept.Add(new("text/json"));
+
+ using var response = await HttpClient.GetAsync("", TestContext.Current.CancellationToken);
+ Assert.Equal(HttpStatusCode.NotAcceptable, response.StatusCode);
+ }
+
+ [Fact]
+ public async Task PostRequest_IsNotFound_WithUnrecognizedSessionId()
+ {
+ await StartAsync();
+
+ using var request = new HttpRequestMessage(HttpMethod.Post, "")
+ {
+ Content = JsonContent(EchoRequest),
+ Headers =
+ {
+ { "mcp-session-id", "fakeSession" },
+ },
+ };
+ using var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken);
+ Assert.Equal(HttpStatusCode.NotFound, response.StatusCode);
+ }
+
+ [Fact]
+ public async Task InitializeRequest_Matches_CustomRoute()
+ {
+ AddDefaultHttpClientRequestHeaders();
+ Builder.Services.AddMcpServer().WithHttpTransport();
+ await using var app = Builder.Build();
+
+ app.MapMcp("/custom-route");
+
+ await app.StartAsync(TestContext.Current.CancellationToken);
+
+ using var response = await HttpClient.PostAsync("/custom-route", JsonContent(InitializeRequest), TestContext.Current.CancellationToken);
+ Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+ }
+
+ [Fact]
+ public async Task PostWithSingleNotification_IsAccepted_WithEmptyResponse()
+ {
+ await StartAsync();
+ await CallInitializeAndValidateAsync();
+
+ var response = await HttpClient.PostAsync("", JsonContent(ProgressNotification("1")), TestContext.Current.CancellationToken);
+ Assert.Equal(HttpStatusCode.Accepted, response.StatusCode);
+ Assert.Equal("", await response.Content.ReadAsStringAsync(TestContext.Current.CancellationToken));
+ }
+
+ [Fact]
+ public async Task InitializeJsonRpcRequest_IsHandled_WithCompleteSseResponse()
+ {
+ await StartAsync();
+ await CallInitializeAndValidateAsync();
+ }
+
+ [Fact]
+ public async Task BatchedJsonRpcRequests_IsHandled_WithCompleteSseResponse()
+ {
+ await StartAsync();
+
+ using var response = await HttpClient.PostAsync("", JsonContent($"[{InitializeRequest},{EchoRequest}]"), TestContext.Current.CancellationToken);
+ Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+
+ var eventCount = 0;
+ await foreach (var sseEvent in ReadSseAsync(response.Content))
+ {
+ var jsonRpcResponse = JsonSerializer.Deserialize(sseEvent, GetJsonTypeInfo());
+ Assert.NotNull(jsonRpcResponse);
+ var responseId = Assert.IsType(jsonRpcResponse.Id.Id);
+
+ switch (responseId)
+ {
+ case 1:
+ AssertServerInfo(jsonRpcResponse);
+ break;
+ case 2:
+ AssertEchoResponse(jsonRpcResponse);
+ break;
+ default:
+ throw new Exception($"Unexpected response ID: {jsonRpcResponse.Id}");
+ }
+
+ eventCount++;
+ }
+
+ Assert.Equal(2, eventCount);
+ }
+
+ [Fact]
+ public async Task SingleJsonRpcRequest_ThatThrowsIsHandled_WithCompleteSseResponse()
+ {
+ await StartAsync();
+ await CallInitializeAndValidateAsync();
+
+ var response = await HttpClient.PostAsync("", JsonContent(CallTool("throw")), TestContext.Current.CancellationToken);
+ var rpcError = await AssertSingleSseResponseAsync(response);
+
+ var error = AssertType(rpcError.Result);
+ var content = Assert.Single(error.Content);
+ Assert.Contains("'throw'", content.Text);
+ }
+
+ [Fact]
+ public async Task MultipleSerialJsonRpcRequests_IsHandled_OneAtATime()
+ {
+ await StartAsync();
+
+ await CallInitializeAndValidateAsync();
+ await CallEchoAndValidateAsync();
+ await CallEchoAndValidateAsync();
+ }
+
+ [Fact]
+ public async Task MultipleConcurrentJsonRpcRequests_IsHandled_InParallel()
+ {
+ await StartAsync();
+
+ await CallInitializeAndValidateAsync();
+
+ var echoTasks = new Task[100];
+ for (int i = 0; i < echoTasks.Length; i++)
+ {
+ echoTasks[i] = CallEchoAndValidateAsync();
+ }
+
+ await Task.WhenAll(echoTasks);
+ }
+
+ [Fact]
+ public async Task GetRequest_Receives_UnsolicitedNotifications()
+ {
+ IMcpServer? server = null;
+
+ Builder.Services.AddMcpServer()
+ .WithHttpTransport(options =>
+ {
+ options.RunSessionHandler = (httpContext, mcpServer, cancellationToken) =>
+ {
+ server = mcpServer;
+ return mcpServer.RunAsync(cancellationToken);
+ };
+ });
+
+ await StartAsync();
+
+ await CallInitializeAndValidateAsync();
+ Assert.NotNull(server);
+
+ // Headers should be sent even before any messages are ready on the GET endpoint.
+ using var getResponse = await HttpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken);
+ async Task GetFirstNotificationAsync()
+ {
+ await foreach (var sseEvent in ReadSseAsync(getResponse.Content))
+ {
+ var notification = JsonSerializer.Deserialize(sseEvent, GetJsonTypeInfo());
+ Assert.NotNull(notification);
+ return notification.Method;
+ }
+
+ throw new Exception("No notifications received.");
+ }
+
+ await server.SendNotificationAsync("test-method", TestContext.Current.CancellationToken);
+ Assert.Equal("test-method", await GetFirstNotificationAsync());
+ }
+
+ [Fact]
+ public async Task SecondGetRequests_IsRejected_AsBadRequest()
+ {
+ await StartAsync();
+
+ await CallInitializeAndValidateAsync();
+ using var getResponse1 = await HttpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken);
+ using var getResponse2 = await HttpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken);
+
+ Assert.Equal(HttpStatusCode.OK, getResponse1.StatusCode);
+ Assert.Equal(HttpStatusCode.BadRequest, getResponse2.StatusCode);
+ }
+
+ [Fact]
+ public async Task DeleteRequest_CompletesSession_WhichIsNoLongerFound()
+ {
+ await StartAsync();
+
+ await CallInitializeAndValidateAsync();
+ await CallEchoAndValidateAsync();
+ await HttpClient.DeleteAsync("", TestContext.Current.CancellationToken);
+
+ using var response = await HttpClient.PostAsync("", JsonContent(EchoRequest), TestContext.Current.CancellationToken);
+ Assert.Equal(HttpStatusCode.NotFound, response.StatusCode);
+ }
+
+ [Fact]
+ public async Task DeleteRequest_CompletesSession_WhichCancelsLongRunningToolCalls()
+ {
+ await StartAsync();
+
+ await CallInitializeAndValidateAsync();
+
+ Task CallLongRunningToolAsync() =>
+ HttpClient.PostAsync("", JsonContent(CallTool("long-running")), TestContext.Current.CancellationToken);
+
+ var longRunningToolTasks = new Task[10];
+ for (int i = 0; i < longRunningToolTasks.Length; i++)
+ {
+ longRunningToolTasks[i] = CallLongRunningToolAsync();
+ Assert.False(longRunningToolTasks[i].IsCompleted);
+ }
+ await HttpClient.DeleteAsync("", TestContext.Current.CancellationToken);
+
+ // Currently, the OCE thrown by the canceled session is unhandled and turned into a 500 error by Kestrel.
+ // The spec suggests sending CancelledNotifications. That would be good, but we can do that later.
+ // For now, the important thing is that request completes without indicating success.
+ await Task.WhenAll(longRunningToolTasks);
+ foreach (var task in longRunningToolTasks)
+ {
+ var response = await task;
+ Assert.False(response.IsSuccessStatusCode);
+ }
+ }
+
+ [Fact]
+ public async Task Progress_IsReported_InSameSseResponseAsRpcResponse()
+ {
+ await StartAsync();
+
+ await CallInitializeAndValidateAsync();
+
+ using var response = await HttpClient.PostAsync("", JsonContent(CallToolWithProgressToken("progress")), TestContext.Current.CancellationToken);
+ Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+
+ var currentSseItem = 0;
+ await foreach (var sseEvent in ReadSseAsync(response.Content))
+ {
+ currentSseItem++;
+
+ if (currentSseItem <= 10)
+ {
+ var notification = JsonSerializer.Deserialize(sseEvent, GetJsonTypeInfo());
+ var progressNotification = AssertType(notification?.Params);
+ Assert.Equal($"Progress {currentSseItem - 1}", progressNotification.Progress.Message);
+ }
+ else
+ {
+ var rpcResponse = JsonSerializer.Deserialize(sseEvent, GetJsonTypeInfo());
+ var callToolResponse = AssertType(rpcResponse?.Result);
+ var callToolContent = Assert.Single(callToolResponse.Content);
+ Assert.Equal("text", callToolContent.Type);
+ Assert.Equal("done", callToolContent.Text);
+ }
+ }
+
+ Assert.Equal(11, currentSseItem);
+ }
+
+ private void AddDefaultHttpClientRequestHeaders()
+ {
+ HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json"));
+ HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream"));
+ }
+
+ private static StringContent JsonContent(string json) => new StringContent(json, Encoding.UTF8, "application/json");
+ private static JsonTypeInfo GetJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T));
+
+ private static T AssertType(JsonNode? jsonNode)
+ {
+ var type = JsonSerializer.Deserialize(jsonNode, GetJsonTypeInfo());
+ Assert.NotNull(type);
+ return type;
+ }
+
+ private static async IAsyncEnumerable ReadSseAsync(HttpContent responseContent)
+ {
+ var responseStream = await responseContent.ReadAsStreamAsync(TestContext.Current.CancellationToken);
+ await foreach (var sseItem in SseParser.Create(responseStream).EnumerateAsync(TestContext.Current.CancellationToken))
+ {
+ Assert.Equal("message", sseItem.EventType);
+ yield return sseItem.Data;
+ }
+ }
+
+ private static async Task AssertSingleSseResponseAsync(HttpResponseMessage response)
+ {
+ Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+ Assert.Equal("text/event-stream", response.Content.Headers.ContentType?.MediaType);
+
+ var sseItem = Assert.Single(await ReadSseAsync(response.Content).ToListAsync(TestContext.Current.CancellationToken));
+ var jsonRpcResponse = JsonSerializer.Deserialize(sseItem, GetJsonTypeInfo());
+
+ Assert.NotNull(jsonRpcResponse);
+ return jsonRpcResponse;
+ }
+
+ private static string InitializeRequest => """
+ {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"IntegrationTestClient","version":"1.0.0"}}}
+ """;
+
+ private long _lastRequestId = 1;
+ private string EchoRequest
+ {
+ get
+ {
+ var id = Interlocked.Increment(ref _lastRequestId);
+ return $$$$"""
+ {"jsonrpc":"2.0","id":{{{{id}}}},"method":"tools/call","params":{"name":"echo","arguments":{"message":"Hello world! ({{{{id}}}})"}}}
+ """;
+ }
+ }
+
+ private string ProgressNotification(string progress)
+ {
+ return $$$"""
+ {"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"","progress":{{{progress}}}}}
+ """;
+ }
+
+ private string Request(string method, string parameters = "{}")
+ {
+ var id = Interlocked.Increment(ref _lastRequestId);
+ return $$"""
+ {"jsonrpc":"2.0","id":{{id}},"method":"{{method}}","params":{{parameters}}}
+ """;
+ }
+
+ private string CallTool(string toolName, string arguments = "{}") =>
+ Request("tools/call", $$"""
+ {"name":"{{toolName}}","arguments":{{arguments}}}
+ """);
+
+ private string CallToolWithProgressToken(string toolName, string arguments = "{}") =>
+ Request("tools/call", $$$"""
+ {"name":"{{{toolName}}}","arguments":{{{arguments}}}, "_meta":{"progressToken": "abc123"}}
+ """);
+
+ private static InitializeResult AssertServerInfo(JsonRpcResponse rpcResponse)
+ {
+ var initializeResult = AssertType(rpcResponse.Result);
+ Assert.Equal(nameof(StreamableHttpTests), initializeResult.ServerInfo.Name);
+ Assert.Equal("73", initializeResult.ServerInfo.Version);
+ return initializeResult;
+ }
+
+ private static CallToolResponse AssertEchoResponse(JsonRpcResponse rpcResponse)
+ {
+ var callToolResponse = AssertType(rpcResponse.Result);
+ var callToolContent = Assert.Single(callToolResponse.Content);
+ Assert.Equal("text", callToolContent.Type);
+ Assert.Equal($"Hello world! ({rpcResponse.Id})", callToolContent.Text);
+ return callToolResponse;
+ }
+
+ private async Task CallInitializeAndValidateAsync()
+ {
+ using var response = await HttpClient.PostAsync("", JsonContent(InitializeRequest), TestContext.Current.CancellationToken);
+ var rpcResponse = await AssertSingleSseResponseAsync(response);
+ AssertServerInfo(rpcResponse);
+
+ var sessionId = Assert.Single(response.Headers.GetValues("mcp-session-id"));
+ HttpClient.DefaultRequestHeaders.Add("mcp-session-id", sessionId);
+ }
+
+ private async Task CallEchoAndValidateAsync()
+ {
+ using var response = await HttpClient.PostAsync("", JsonContent(EchoRequest), TestContext.Current.CancellationToken);
+ var rpcResponse = await AssertSingleSseResponseAsync(response);
+ AssertEchoResponse(rpcResponse);
+ }
+
+ [McpServerTool(Name = "echo")]
+ private static async Task EchoAsync(string message)
+ {
+ // McpSession.ProcessMessagesAsync() already yields before calling any handlers, but this makes it even
+ // more explicit that we're not relying on synchronous execution of the tool.
+ await Task.Yield();
+ return message;
+ }
+
+ [McpServerTool(Name = "long-running")]
+ private static async Task LongRunningAsync(CancellationToken cancellation)
+ {
+ // McpSession.ProcessMessagesAsync() already yields before calling any handlers, but this makes it even
+ // more explicit that we're not relying on synchronous execution of the tool.
+ await Task.Delay(Timeout.Infinite, cancellation);
+ }
+
+ [McpServerTool(Name = "progress")]
+ public static string Progress(IProgress progress)
+ {
+ for (int i = 0; i < 10; i++)
+ {
+ progress.Report(new() { Progress = i, Total = 10, Message = $"Progress {i}" });
+ }
+
+ return "done";
+ }
+
+ [McpServerTool(Name = "throw")]
+ private static void Throw()
+ {
+ throw new Exception();
+ }
+}
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs
index b10a0d674..597b39de7 100644
--- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs
@@ -27,7 +27,11 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper)
var connection = _inMemoryTransport.CreateConnection();
return new(connection.ClientStream);
},
- });
+ })
+ {
+ BaseAddress = new Uri("http://localhost:5000/"),
+ Timeout = TimeSpan.FromSeconds(10),
+ };
}
public WebApplicationBuilder Builder { get; }
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs
index 0c9cb4377..a221b8a38 100644
--- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs
@@ -8,19 +8,19 @@ namespace ModelContextProtocol.Tests.Utils;
public class TestServerTransport : ITransport
{
- private readonly Channel _messageChannel;
+ private readonly Channel _messageChannel;
public bool IsConnected { get; set; }
- public ChannelReader MessageReader => _messageChannel;
+ public ChannelReader MessageReader => _messageChannel;
- public List SentMessages { get; } = [];
+ public List SentMessages { get; } = [];
- public Action? OnMessageSent { get; set; }
+ public Action? OnMessageSent { get; set; }
public TestServerTransport()
{
- _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions
+ _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions
{
SingleReader = true,
SingleWriter = true,
@@ -35,7 +35,7 @@ public ValueTask DisposeAsync()
return ValueTask.CompletedTask;
}
- public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
SentMessages.Add(message);
if (message is JsonRpcRequest request)
@@ -76,7 +76,7 @@ await WriteMessageAsync(new JsonRpcResponse
}, cancellationToken);
}
- private async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ private async Task WriteMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
await _messageChannel.Writer.WriteAsync(message, cancellationToken);
}
diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs
index 454222e6b..4b3fe3a97 100644
--- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs
+++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs
@@ -106,11 +106,11 @@ public async Task CreateAsync_WithCapabilitiesOptions(Type transportType)
private class NopTransport : ITransport, IClientTransport
{
- private readonly Channel _channel = Channel.CreateUnbounded();
+ private readonly Channel _channel = Channel.CreateUnbounded();
public bool IsConnected => true;
- public ChannelReader MessageReader => _channel.Reader;
+ public ChannelReader MessageReader => _channel.Reader;
public Task ConnectAsync(CancellationToken cancellationToken = default) => Task.FromResult(this);
@@ -118,7 +118,7 @@ private class NopTransport : ITransport, IClientTransport
public string Name => "Test Nop Transport";
- public virtual Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ public virtual Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
switch (message)
{
@@ -148,7 +148,7 @@ private sealed class FailureTransport : NopTransport
{
public const string ExpectedMessage = "Something failed";
- public override Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
throw new InvalidOperationException(ExpectedMessage);
}
diff --git a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs
index 1d42c3d8a..0d18667e9 100644
--- a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs
+++ b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs
@@ -33,7 +33,7 @@ public async Task RegistrationsAreRemovedWhenDisposed()
}
}
- Assert.Equal(Iterations, counter);
+ Assert.Equal(Iterations, counter);
}
[Fact]
diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs
index d7ccb6e15..958fe124d 100644
--- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs
+++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs
@@ -619,7 +619,7 @@ public Task SendRequestAsync(JsonRpcRequest request, Cancellati
public McpServerOptions ServerOptions => throw new NotImplementedException();
public IServiceProvider? Services => throw new NotImplementedException();
public LoggingLevel? LoggingLevel => throw new NotImplementedException();
- public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) =>
+ public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) =>
throw new NotImplementedException();
public Task RunAsync(CancellationToken cancellationToken = default) =>
throw new NotImplementedException();
diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs
index 6cfdd06cb..1f74a9565 100644
--- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs
+++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs
@@ -216,6 +216,7 @@ public async Task DisposeAsync_Should_Dispose_Resources()
await session.DisposeAsync();
- Assert.False(session.IsConnected);
+ var transportBase = Assert.IsAssignableFrom(session);
+ Assert.False(transportBase.IsConnected);
}
}
\ No newline at end of file
diff --git a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs
index 570708d54..aa01a894d 100644
--- a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs
+++ b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs
@@ -9,19 +9,19 @@ namespace ModelContextProtocol.Tests.Utils;
public class TestServerTransport : ITransport
{
- private readonly Channel _messageChannel;
+ private readonly Channel _messageChannel;
public bool IsConnected { get; set; }
- public ChannelReader MessageReader => _messageChannel;
+ public ChannelReader MessageReader => _messageChannel;
- public List SentMessages { get; } = [];
+ public List SentMessages { get; } = [];
- public Action? OnMessageSent { get; set; }
+ public Action? OnMessageSent { get; set; }
public TestServerTransport()
{
- _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions
+ _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions
{
SingleReader = true,
SingleWriter = true,
@@ -36,7 +36,7 @@ public ValueTask DisposeAsync()
return default;
}
- public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
SentMessages.Add(message);
if (message is JsonRpcRequest request)
@@ -77,7 +77,7 @@ await WriteMessageAsync(new JsonRpcResponse
}, cancellationToken);
}
- private async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ private async Task WriteMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
await _messageChannel.Writer.WriteAsync(message, cancellationToken);
}