diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index f67c57afd..c9a5ba87f 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -11,11 +11,11 @@ public static class HttpMcpServerBuilderExtensions { /// /// Adds the services necessary for - /// to handle MCP requests and sessions using the MCP HTTP Streaming transport. For more information on configuring the underlying HTTP server + /// to handle MCP requests and sessions using the MCP Streamable HTTP transport. For more information on configuring the underlying HTTP server /// to control things like port binding custom TLS certificates, see the Minimal APIs quick reference. /// /// The builder instance. - /// Configures options for the HTTP Streaming transport. This allows configuring per-session + /// Configures options for the Streamable HTTP transport. This allows configuring per-session /// and running logic before and after a session. /// The builder provided in . /// is . @@ -23,6 +23,8 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder { ArgumentNullException.ThrowIfNull(builder); builder.Services.TryAddSingleton(); + builder.Services.TryAddSingleton(); + builder.Services.AddHostedService(); if (configureOptions is not null) { diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs index 216962a8b..fed2f131e 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs @@ -1,18 +1,61 @@ using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; using System.Security.Claims; namespace ModelContextProtocol.AspNetCore; -internal class HttpMcpSession +internal sealed class HttpMcpSession(string sessionId, TTransport transport, ClaimsPrincipal user, TimeProvider timeProvider) : IAsyncDisposable + where TTransport : ITransport { - public HttpMcpSession(SseResponseStreamTransport transport, ClaimsPrincipal user) + private int _referenceCount; + private int _getRequestStarted; + private CancellationTokenSource _disposeCts = new(); + + public string Id { get; } = sessionId; + public TTransport Transport { get; } = transport; + public (string Type, string Value, string Issuer)? UserIdClaim { get; } = GetUserIdClaim(user); + + public CancellationToken SessionClosed => _disposeCts.Token; + + public bool IsActive => !SessionClosed.IsCancellationRequested && _referenceCount > 0; + public long LastActivityTicks { get; private set; } = timeProvider.GetTimestamp(); + + public IMcpServer? Server { get; set; } + public Task? ServerRunTask { get; set; } + + public IDisposable AcquireReference() { - Transport = transport; - UserIdClaim = GetUserIdClaim(user); + Interlocked.Increment(ref _referenceCount); + return new UnreferenceDisposable(this, timeProvider); } - public SseResponseStreamTransport Transport { get; } - public (string Type, string Value, string Issuer)? UserIdClaim { get; } + public bool TryStartGetRequest() => Interlocked.Exchange(ref _getRequestStarted, 1) == 0; + + public async ValueTask DisposeAsync() + { + try + { + await _disposeCts.CancelAsync(); + + if (ServerRunTask is not null) + { + await ServerRunTask; + } + } + catch (OperationCanceledException) + { + } + finally + { + if (Server is not null) + { + await Server.DisposeAsync(); + } + + await Transport.DisposeAsync(); + _disposeCts.Dispose(); + } + } public bool HasSameUserId(ClaimsPrincipal user) => UserIdClaim == GetUserIdClaim(user); @@ -36,4 +79,15 @@ private static (string Type, string Value, string Issuer)? GetUserIdClaim(Claims return null; } + + private sealed class UnreferenceDisposable(HttpMcpSession session, TimeProvider timeProvider) : IDisposable + { + public void Dispose() + { + if (Interlocked.Decrement(ref session._referenceCount) == 0) + { + session.LastActivityTicks = timeProvider.GetTimestamp(); + } + } + } } diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index 850dac244..23eeddbea 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -21,4 +21,17 @@ public class HttpServerTransportOptions /// This is useful for running logic before a sessions starts and after it completes. /// public Func? RunSessionHandler { get; set; } + + /// + /// Represents the duration of time the server will wait between any active requests before timing out an + /// MCP session. This is checked in background every 5 seconds. A client trying to resume a session will + /// receive a 404 status code and should restart their session. A client can keep their session open by + /// keeping a GET request open. The default value is set to 2 minutes. + /// + public TimeSpan IdleTimeout { get; set; } = TimeSpan.FromMinutes(2); + + /// + /// Used for testing the . + /// + public TimeProvider TimeProvider { get; set; } = TimeProvider.System; } diff --git a/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs b/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs new file mode 100644 index 000000000..df3203b5b --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs @@ -0,0 +1,105 @@ +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol.Transport; + +namespace ModelContextProtocol.AspNetCore; + +internal sealed partial class IdleTrackingBackgroundService( + StreamableHttpHandler handler, + IOptions options, + ILogger logger) : BackgroundService +{ + // The compiler will complain about the parameter being unused otherwise despite the source generator. + private ILogger _logger = logger; + + // We can make this configurable once we properly harden the MCP server. In the meantime, anyone running + // this should be taking a cattle not pets approach to their servers and be able to launch more processes + // to handle more than 10,000 idle sessions at a time. + private const int MaxIdleSessionCount = 10_000; + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + var timeProvider = options.Value.TimeProvider; + using var timer = new PeriodicTimer(TimeSpan.FromSeconds(5), timeProvider); + + try + { + while (!stoppingToken.IsCancellationRequested && await timer.WaitForNextTickAsync(stoppingToken)) + { + var idleActivityCutoff = timeProvider.GetTimestamp() - options.Value.IdleTimeout.Ticks; + + var idleCount = 0; + foreach (var (_, session) in handler.Sessions) + { + if (session.IsActive || session.SessionClosed.IsCancellationRequested) + { + // There's a request currently active or the session is already being closed. + continue; + } + + idleCount++; + if (idleCount == MaxIdleSessionCount) + { + // Emit critical log at most once every 5 seconds the idle count it exceeded, + //since the IdleTimeout will no longer be respected. + LogMaxSessionIdleCountExceeded(); + } + else if (idleCount < MaxIdleSessionCount && session.LastActivityTicks > idleActivityCutoff) + { + continue; + } + + if (handler.Sessions.TryRemove(session.Id, out var removedSession)) + { + LogSessionIdle(removedSession.Id); + + // Don't slow down the idle tracking loop. DisposeSessionAsync logs. We only await during graceful shutdown. + _ = DisposeSessionAsync(removedSession); + } + } + } + } + catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested) + { + } + finally + { + if (stoppingToken.IsCancellationRequested) + { + List disposeSessionTasks = []; + + foreach (var (sessionKey, _) in handler.Sessions) + { + if (handler.Sessions.TryRemove(sessionKey, out var session)) + { + disposeSessionTasks.Add(DisposeSessionAsync(session)); + } + } + + await Task.WhenAll(disposeSessionTasks); + } + } + } + + private async Task DisposeSessionAsync(HttpMcpSession session) + { + try + { + await session.DisposeAsync(); + } + catch (Exception ex) + { + LogSessionDisposeError(session.Id, ex); + } + } + + [LoggerMessage(Level = LogLevel.Information, Message = "Closing idle session {sessionId}.")] + private partial void LogSessionIdle(string sessionId); + + [LoggerMessage(Level = LogLevel.Critical, Message = "Exceeded static maximum of 10,000 idle connections. Now clearing all inactive connections regardless of timeout.")] + private partial void LogMaxSessionIdleCountExceeded(); + + [LoggerMessage(Level = LogLevel.Error, Message = "Error disposing the IMcpServer for session {sessionId}.")] + private partial void LogSessionDisposeError(string sessionId, Exception ex); +} diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index ac424cc8b..0eefa52fb 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -1,6 +1,9 @@ -using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Metadata; +using Microsoft.AspNetCore.Routing; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.AspNetCore; +using ModelContextProtocol.Protocol.Messages; using System.Diagnostics.CodeAnalysis; namespace Microsoft.AspNetCore.Builder; @@ -11,21 +14,42 @@ namespace Microsoft.AspNetCore.Builder; public static class McpEndpointRouteBuilderExtensions { /// - /// Sets up endpoints for handling MCP HTTP Streaming transport. - /// See the protocol specification for details about the Streamable HTTP transport. + /// Sets up endpoints for handling MCP Streamable HTTP transport. + /// See the 2025-03-26 protocol specification for details about the Streamable HTTP transport. + /// Also maps legacy SSE endpoints for backward compatibility at the path "/sse" and "/message". the 2024-11-05 protocol specification for details about the HTTP with SSE transport. /// /// The web application to attach MCP HTTP endpoints. /// The route pattern prefix to map to. /// Returns a builder for configuring additional endpoint conventions like authorization policies. public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern = "") { - var handler = endpoints.ServiceProvider.GetService() ?? + var streamableHttpHandler = endpoints.ServiceProvider.GetService() ?? throw new InvalidOperationException("You must call WithHttpTransport(). Unable to find required services. Call builder.Services.AddMcpServer().WithHttpTransport() in application startup code."); - var routeGroup = endpoints.MapGroup(pattern); - routeGroup.MapGet("", handler.HandleRequestAsync); - routeGroup.MapGet("/sse", handler.HandleRequestAsync); - routeGroup.MapPost("/message", handler.HandleRequestAsync); - return routeGroup; + var mcpGroup = endpoints.MapGroup(pattern); + var streamableHttpGroup = mcpGroup.MapGroup("") + .WithDisplayName(b => $"MCP Streamable HTTP | {b.DisplayName}") + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status404NotFound, typeof(JsonRpcError), contentTypes: ["application/json"])); + + streamableHttpGroup.MapPost("", streamableHttpHandler.HandlePostRequestAsync) + .WithMetadata(new AcceptsMetadata(["application/json"])) + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])) + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); + streamableHttpGroup.MapGet("", streamableHttpHandler.HandleGetRequestAsync) + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); + streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync); + + // Map legacy HTTP with SSE endpoints. + var sseHandler = endpoints.ServiceProvider.GetRequiredService(); + var sseGroup = mcpGroup.MapGroup("") + .WithDisplayName(b => $"MCP HTTP with SSE | {b.DisplayName}"); + + sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync) + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); + sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync) + .WithMetadata(new AcceptsMetadata(["application/json"])) + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); + + return mcpGroup; } } diff --git a/src/ModelContextProtocol.AspNetCore/SseHandler.cs b/src/ModelContextProtocol.AspNetCore/SseHandler.cs new file mode 100644 index 000000000..638592198 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/SseHandler.cs @@ -0,0 +1,110 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; +using ModelContextProtocol.Utils.Json; +using System.Collections.Concurrent; +using System.Diagnostics; + +namespace ModelContextProtocol.AspNetCore; + +internal sealed class SseHandler( + IOptions mcpServerOptionsSnapshot, + IOptionsFactory mcpServerOptionsFactory, + IOptions httpMcpServerOptions, + IHostApplicationLifetime hostApplicationLifetime, + ILoggerFactory loggerFactory) +{ + private readonly ConcurrentDictionary> _sessions = new(StringComparer.Ordinal); + + public async Task HandleSseRequestAsync(HttpContext context) + { + var sessionId = StreamableHttpHandler.MakeNewSessionId(); + + // If the server is shutting down, we need to cancel all SSE connections immediately without waiting for HostOptions.ShutdownTimeout + // which defaults to 30 seconds. + using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping); + var cancellationToken = sseCts.Token; + + StreamableHttpHandler.InitializeSseResponse(context); + + await using var transport = new SseResponseStreamTransport(context.Response.Body, $"message?sessionId={sessionId}"); + await using var httpMcpSession = new HttpMcpSession(sessionId, transport, context.User, httpMcpServerOptions.Value.TimeProvider); + if (!_sessions.TryAdd(sessionId, httpMcpSession)) + { + throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); + } + + try + { + var mcpServerOptions = mcpServerOptionsSnapshot.Value; + if (httpMcpServerOptions.Value.ConfigureSessionOptions is { } configureSessionOptions) + { + mcpServerOptions = mcpServerOptionsFactory.Create(Options.DefaultName); + await configureSessionOptions(context, mcpServerOptions, cancellationToken); + } + + var transportTask = transport.RunAsync(cancellationToken); + + try + { + await using var mcpServer = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, context.RequestServices); + httpMcpSession.Server = mcpServer; + context.Features.Set(mcpServer); + + var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? StreamableHttpHandler.RunSessionAsync; + httpMcpSession.ServerRunTask = runSessionAsync(context, mcpServer, cancellationToken); + await httpMcpSession.ServerRunTask; + } + finally + { + await transport.DisposeAsync(); + await transportTask; + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // RequestAborted always triggers when the client disconnects before a complete response body is written, + // but this is how SSE connections are typically closed. + } + finally + { + _sessions.TryRemove(sessionId, out _); + } + } + + public async Task HandleMessageRequestAsync(HttpContext context) + { + if (!context.Request.Query.TryGetValue("sessionId", out var sessionId)) + { + await Results.BadRequest("Missing sessionId query parameter.").ExecuteAsync(context); + return; + } + + if (!_sessions.TryGetValue(sessionId.ToString(), out var httpMcpSession)) + { + await Results.BadRequest($"Session ID not found.").ExecuteAsync(context); + return; + } + + if (!httpMcpSession.HasSameUserId(context.User)) + { + await Results.Forbid().ExecuteAsync(context); + return; + } + + var message = (JsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)), context.RequestAborted); + if (message is null) + { + await Results.BadRequest("No message in request body.").ExecuteAsync(context); + return; + } + + await httpMcpSession.Transport.OnMessageReceivedAsync(message, context.RequestAborted); + context.Response.StatusCode = StatusCodes.Status202Accepted; + await context.Response.WriteAsync("Accepted"); + } +} diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index b88cc07c6..1dc976192 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -1,7 +1,6 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.WebUtilities; -using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using ModelContextProtocol.Protocol.Messages; @@ -10,7 +9,9 @@ using ModelContextProtocol.Utils.Json; using System.Collections.Concurrent; using System.Diagnostics; +using System.IO.Pipelines; using System.Security.Cryptography; +using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.AspNetCore; @@ -18,130 +19,207 @@ internal sealed class StreamableHttpHandler( IOptions mcpServerOptionsSnapshot, IOptionsFactory mcpServerOptionsFactory, IOptions httpMcpServerOptions, - IHostApplicationLifetime hostApplicationLifetime, - ILoggerFactory loggerFactory) + ILoggerFactory loggerFactory, + IServiceProvider applicationServices) { + private static JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); - private readonly ConcurrentDictionary _sessions = new(StringComparer.Ordinal); - private readonly ILogger _logger = loggerFactory.CreateLogger(); + public ConcurrentDictionary> Sessions { get; } = new(StringComparer.Ordinal); - public async Task HandleRequestAsync(HttpContext context) + public async Task HandlePostRequestAsync(HttpContext context) { - if (context.Request.Method == HttpMethods.Get) + // The Streamable HTTP spec mandates the client MUST accept both application/json and text/event-stream. + // ASP.NET Core Minimal APIs mostly ry to stay out of the business of response content negotiation, so + // we have to do this manually. The spec doesn't mandate that servers MUST reject these requests, but it's + // probably good to at least start out trying to be strict. + var acceptHeader = context.Request.Headers.Accept.ToString(); + if (!acceptHeader.Contains("application/json", StringComparison.Ordinal) || + !acceptHeader.Contains("text/event-stream", StringComparison.Ordinal)) { - await HandleSseRequestAsync(context); + await WriteJsonRpcErrorAsync(context, + "Not Acceptable: Client must accept both application/json and text/event-stream", + StatusCodes.Status406NotAcceptable); + return; } - else if (context.Request.Method == HttpMethods.Post) + + var session = await GetOrCreateSessionAsync(context); + if (session is null) { - await HandleMessageRequestAsync(context); + return; } - else + + using var _ = session.AcquireReference(); + using var cts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, session.SessionClosed); + InitializeSseResponse(context); + var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), cts.Token); + if (!wroteResponse) { - context.Response.StatusCode = StatusCodes.Status405MethodNotAllowed; - await context.Response.WriteAsync("Method Not Allowed"); + // We wound up writing nothing, so there should be no Content-Type response header. + context.Response.Headers.ContentType = (string?)null; + context.Response.StatusCode = StatusCodes.Status202Accepted; } } - public async Task HandleSseRequestAsync(HttpContext context) + public async Task HandleGetRequestAsync(HttpContext context) { - // If the server is shutting down, we need to cancel all SSE connections immediately without waiting for HostOptions.ShutdownTimeout - // which defaults to 30 seconds. - using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping); - var cancellationToken = sseCts.Token; + var acceptHeader = context.Request.Headers.Accept.ToString(); + if (!acceptHeader.Contains("application/json", StringComparison.Ordinal)) + { + await WriteJsonRpcErrorAsync(context, + "Not Acceptable: Client must accept text/event-stream", + StatusCodes.Status406NotAcceptable); + return; + } - var response = context.Response; - response.Headers.ContentType = "text/event-stream"; - response.Headers.CacheControl = "no-cache,no-store"; + var sessionId = context.Request.Headers["mcp-session-id"].ToString(); + var session = await GetSessionAsync(context, sessionId); + if (session is null) + { + return; + } - // Make sure we disable all response buffering for SSE - context.Response.Headers.ContentEncoding = "identity"; - context.Features.GetRequiredFeature().DisableBuffering(); + if (!session.TryStartGetRequest()) + { + await WriteJsonRpcErrorAsync(context, + "Bad Request: This server does not support multiple GET requests. Start a new session to get a new GET SSE response.", + StatusCodes.Status400BadRequest); + return; + } - var sessionId = MakeNewSessionId(); - await using var transport = new SseResponseStreamTransport(response.Body, $"message?sessionId={sessionId}"); - var httpMcpSession = new HttpMcpSession(transport, context.User); - if (!_sessions.TryAdd(sessionId, httpMcpSession)) + using var _ = session.AcquireReference(); + InitializeSseResponse(context); + + // We should flush headers to indicate a 200 success quickly, because the initialization response + // will be sent in response to a different POST request. It might be a while before we send a message + // over this response body. + await context.Response.Body.FlushAsync(context.RequestAborted); + await session.Transport.HandleGetRequest(context.Response.Body, context.RequestAborted); + } + + public async Task HandleDeleteRequestAsync(HttpContext context) + { + var sessionId = context.Request.Headers["mcp-session-id"].ToString(); + if (Sessions.TryRemove(sessionId, out var session)) { - Debug.Fail("Unreachable given good entropy!"); - throw new InvalidOperationException($"Session with ID '{sessionId}' has already been created."); + await session.DisposeAsync(); } + } + + private void InitializeSessionResponse(HttpContext context, HttpMcpSession session) + { + context.Response.Headers["mcp-session-id"] = session.Id; + context.Features.Set(session.Server); + } - try + private async ValueTask?> GetSessionAsync(HttpContext context, string sessionId) + { + if (Sessions.TryGetValue(sessionId, out var existingSession)) { - var mcpServerOptions = mcpServerOptionsSnapshot.Value; - if (httpMcpServerOptions.Value.ConfigureSessionOptions is { } configureSessionOptions) + if (!existingSession.HasSameUserId(context.User)) { - mcpServerOptions = mcpServerOptionsFactory.Create(Options.DefaultName); - await configureSessionOptions(context, mcpServerOptions, cancellationToken); + await WriteJsonRpcErrorAsync(context, + "Forbidden: The currently authenticated user does not match the user who initiated the session.", + StatusCodes.Status403Forbidden); + return null; } - var transportTask = transport.RunAsync(cancellationToken); + InitializeSessionResponse(context, existingSession); + return existingSession; + } - try - { - await using var mcpServer = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, context.RequestServices); - context.Features.Set(mcpServer); + // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does. + // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this + // JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound + // https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields + await WriteJsonRpcErrorAsync(context, "Session not found", StatusCodes.Status404NotFound, 32001); + return null; + } - var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? RunSessionAsync; - await runSessionAsync(context, mcpServer, cancellationToken); - } - finally + private async ValueTask?> GetOrCreateSessionAsync(HttpContext context) + { + var sessionId = context.Request.Headers["mcp-session-id"].ToString(); + HttpMcpSession? session; + + if (string.IsNullOrEmpty(sessionId)) + { + session = await CreateSessionAsync(context); + + if (!Sessions.TryAdd(session.Id, session)) { - await transport.DisposeAsync(); - await transportTask; + throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); } + + return session; } - catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) - { - // RequestAborted always triggers when the client disconnects before a complete response body is written, - // but this is how SSE connections are typically closed. - } - finally + else { - _sessions.TryRemove(sessionId, out _); + return await GetSessionAsync(context, sessionId); } } - public async Task HandleMessageRequestAsync(HttpContext context) + private async ValueTask> CreateSessionAsync(HttpContext context) { - if (!context.Request.Query.TryGetValue("sessionId", out var sessionId)) + var mcpServerOptions = mcpServerOptionsSnapshot.Value; + if (httpMcpServerOptions.Value.ConfigureSessionOptions is { } configureSessionOptions) { - await Results.BadRequest("Missing sessionId query parameter.").ExecuteAsync(context); - return; + mcpServerOptions = mcpServerOptionsFactory.Create(Options.DefaultName); + await configureSessionOptions(context, mcpServerOptions, context.RequestAborted); } - if (!_sessions.TryGetValue(sessionId.ToString(), out var httpMcpSession)) - { - await Results.BadRequest($"Session ID not found.").ExecuteAsync(context); - return; - } + var transport = new StreamableHttpServerTransport(); + // Use application instead of request services, because the session will likely outlive the first initialization request. + var server = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, applicationServices); - if (!httpMcpSession.HasSameUserId(context.User)) + var session = new HttpMcpSession(MakeNewSessionId(), transport, context.User, httpMcpServerOptions.Value.TimeProvider) { - await Results.Forbid().ExecuteAsync(context); - return; - } + Server = server, + }; - var message = (IJsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), context.RequestAborted); - if (message is null) - { - await Results.BadRequest("No message in request body.").ExecuteAsync(context); - return; - } + var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? RunSessionAsync; + session.ServerRunTask = runSessionAsync(context, server, session.SessionClosed); - await httpMcpSession.Transport.OnMessageReceivedAsync(message, context.RequestAborted); - context.Response.StatusCode = StatusCodes.Status202Accepted; - await context.Response.WriteAsync("Accepted"); + InitializeSessionResponse(context, session); + return session; } - private static Task RunSessionAsync(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted) - => session.RunAsync(requestAborted); + private static Task WriteJsonRpcErrorAsync(HttpContext context, string errorMessage, int statusCode, int errorCode = -32000) + { + var jsonRpcError = new JsonRpcError + { + Error = new() + { + Code = errorCode, + Message = errorMessage, + }, + }; + return Results.Json(jsonRpcError, s_errorTypeInfo, statusCode: statusCode).ExecuteAsync(context); + } - private static string MakeNewSessionId() + internal static void InitializeSseResponse(HttpContext context) + { + context.Response.Headers.ContentType = "text/event-stream"; + context.Response.Headers.CacheControl = "no-cache,no-store"; + + // Make sure we disable all response buffering for SSE. + context.Response.Headers.ContentEncoding = "identity"; + context.Features.GetRequiredFeature().DisableBuffering(); + } + + internal static string MakeNewSessionId() { - // 128 bits Span buffer = stackalloc byte[16]; RandomNumberGenerator.Fill(buffer); return WebEncoders.Base64UrlEncode(buffer); } + + internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted) + => session.RunAsync(requestAborted); + + private static JsonTypeInfo GetRequiredJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); + + private sealed class HttpDuplexPipe(HttpContext context) : IDuplexPipe + { + public PipeReader Input => context.Request.BodyReader; + public PipeWriter Output => context.Response.BodyWriter; + } } diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index 50c0a195f..2bcfacaa8 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -5,6 +5,7 @@ using ModelContextProtocol.Shared; using ModelContextProtocol.Utils.Json; using System.Text.Json; +using System.Threading; namespace ModelContextProtocol.Client; @@ -59,7 +60,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, IL RequestHandlers.Set( RequestMethods.SamplingCreateMessage, - (request, cancellationToken) => samplingHandler( + (request, _, cancellationToken) => samplingHandler( request, request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, cancellationToken), @@ -76,7 +77,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, IL RequestHandlers.Set( RequestMethods.RootsList, - rootsHandler, + (request, _, cancellationToken) => rootsHandler(request, cancellationToken), McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, McpJsonUtilities.JsonContext.Default.ListRootsResult); } diff --git a/src/ModelContextProtocol/Diagnostics.cs b/src/ModelContextProtocol/Diagnostics.cs index a5d293f0c..52d82e200 100644 --- a/src/ModelContextProtocol/Diagnostics.cs +++ b/src/ModelContextProtocol/Diagnostics.cs @@ -38,7 +38,7 @@ internal static Histogram CreateDurationHistogram(string name, string de }; #endif - internal static ActivityContext ExtractActivityContext(this DistributedContextPropagator propagator, IJsonRpcMessage message) + internal static ActivityContext ExtractActivityContext(this DistributedContextPropagator propagator, JsonRpcMessage message) { propagator.ExtractTraceIdAndState(message, ExtractContext, out var traceparent, out var tracestate); ActivityContext.TryParse(traceparent, tracestate, true, out var activityContext); @@ -71,7 +71,7 @@ private static void ExtractContext(object? message, string fieldName, out string } } - internal static void InjectActivityContext(this DistributedContextPropagator propagator, Activity? activity, IJsonRpcMessage message) + internal static void InjectActivityContext(this DistributedContextPropagator propagator, Activity? activity, JsonRpcMessage message) { // noop if activity is null propagator.Inject(activity, message, InjectContext); @@ -100,7 +100,7 @@ private static void InjectContext(object? message, string key, string value) } } - internal static bool ShouldInstrumentMessage(IJsonRpcMessage message) => + internal static bool ShouldInstrumentMessage(JsonRpcMessage message) => ActivitySource.HasListeners() && message switch { diff --git a/src/ModelContextProtocol/IMcpEndpoint.cs b/src/ModelContextProtocol/IMcpEndpoint.cs index a3b941cce..dcfdf687c 100644 --- a/src/ModelContextProtocol/IMcpEndpoint.cs +++ b/src/ModelContextProtocol/IMcpEndpoint.cs @@ -45,7 +45,7 @@ public interface IMcpEndpoint : IAsyncDisposable /// Sends a JSON-RPC message to the connected endpoint. /// /// - /// The JSON-RPC message to send. This can be any type that implements IJsonRpcMessage, such as + /// The JSON-RPC message to send. This can be any type that implements JsonRpcMessage, such as /// JsonRpcRequest, JsonRpcResponse, JsonRpcNotification, or JsonRpcError. /// /// The to monitor for cancellation requests. The default is . @@ -63,7 +63,7 @@ public interface IMcpEndpoint : IAsyncDisposable /// The method will serialize the message and transmit it using the underlying transport mechanism. /// /// - Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); + Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default); /// Registers a handler to be invoked when a notification for the specified method is received. /// The notification method. diff --git a/src/ModelContextProtocol/McpEndpointExtensions.cs b/src/ModelContextProtocol/McpEndpointExtensions.cs index 8c266b29f..3969fa336 100644 --- a/src/ModelContextProtocol/McpEndpointExtensions.cs +++ b/src/ModelContextProtocol/McpEndpointExtensions.cs @@ -41,7 +41,7 @@ public static Task SendRequestAsync( string method, TParameters parameters, JsonSerializerOptions? serializerOptions = null, - RequestId? requestId = null, + RequestId requestId = default, CancellationToken cancellationToken = default) where TResult : notnull { @@ -72,7 +72,7 @@ internal static async Task SendRequestAsync( TParameters parameters, JsonTypeInfo parametersTypeInfo, JsonTypeInfo resultTypeInfo, - RequestId? requestId = null, + RequestId requestId = default, CancellationToken cancellationToken = default) where TResult : notnull { @@ -83,15 +83,11 @@ internal static async Task SendRequestAsync( JsonRpcRequest jsonRpcRequest = new() { + Id = requestId, Method = method, Params = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo), }; - if (requestId is { } id) - { - jsonRpcRequest.Id = id; - } - JsonRpcResponse response = await endpoint.SendRequestAsync(jsonRpcRequest, cancellationToken).ConfigureAwait(false); return JsonSerializer.Deserialize(response.Result, resultTypeInfo) ?? throw new JsonException("Unexpected JSON result in response."); } diff --git a/src/ModelContextProtocol/Protocol/Messages/IJsonRpcMessage.cs b/src/ModelContextProtocol/Protocol/Messages/IJsonRpcMessage.cs deleted file mode 100644 index 9880247fb..000000000 --- a/src/ModelContextProtocol/Protocol/Messages/IJsonRpcMessage.cs +++ /dev/null @@ -1,21 +0,0 @@ -using ModelContextProtocol.Utils.Json; -using System.Text.Json.Serialization; - -namespace ModelContextProtocol.Protocol.Messages; - -/// -/// Represents any JSON-RPC message used in the Model Context Protocol (MCP). -/// -/// -/// This interface serves as the foundation for all message types in the JSON-RPC 2.0 protocol -/// used by MCP, including requests, responses, notifications, and errors. JSON-RPC is a stateless, -/// lightweight remote procedure call (RPC) protocol that uses JSON as its data format. -/// -[JsonConverter(typeof(JsonRpcMessageConverter))] -public interface IJsonRpcMessage -{ - /// - /// Gets the JSON-RPC protocol version used. - /// - string JsonRpc { get; } -} diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcError.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcError.cs index 30ace711d..1109e647b 100644 --- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcError.cs +++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcError.cs @@ -16,16 +16,8 @@ namespace ModelContextProtocol.Protocol.Messages; /// and optional additional data to provide more context about the error. /// /// -public record JsonRpcError : IJsonRpcMessageWithId +public class JsonRpcError : JsonRpcMessageWithId { - /// - [JsonPropertyName("jsonrpc")] - public string JsonRpc { get; init; } = "2.0"; - - /// - [JsonPropertyName("id")] - public required RequestId Id { get; init; } - /// /// Gets detailed error information for the failed request, containing an error code, /// message, and optional additional data diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcMessage.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcMessage.cs new file mode 100644 index 000000000..d307598f0 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcMessage.cs @@ -0,0 +1,35 @@ +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Utils.Json; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Messages; + +/// +/// Represents any JSON-RPC message used in the Model Context Protocol (MCP). +/// +/// +/// This interface serves as the foundation for all message types in the JSON-RPC 2.0 protocol +/// used by MCP, including requests, responses, notifications, and errors. JSON-RPC is a stateless, +/// lightweight remote procedure call (RPC) protocol that uses JSON as its data format. +/// +[JsonConverter(typeof(JsonRpcMessageConverter))] +public abstract class JsonRpcMessage +{ + /// + /// Gets the JSON-RPC protocol version used. + /// + /// + [JsonPropertyName("jsonrpc")] + public string JsonRpc { get; init; } = "2.0"; + + /// + /// Gets or sets the transport the was received on or should be sent over. + /// + /// + /// This is used to support the Streamable HTTP transport where the specification states that the server + /// SHOULD include JSON-RPC responses in the HTTP response body for the POST request containing + /// the corresponding JSON-RPC request. It may be for other transports. + /// + [JsonIgnore] + public ITransport? RelatedTransport { get; set; } +} diff --git a/src/ModelContextProtocol/Protocol/Messages/IJsonRpcMessageWithId.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcMessageWithId.cs similarity index 83% rename from src/ModelContextProtocol/Protocol/Messages/IJsonRpcMessageWithId.cs rename to src/ModelContextProtocol/Protocol/Messages/JsonRpcMessageWithId.cs index 4e2d84f4a..f095c50bc 100644 --- a/src/ModelContextProtocol/Protocol/Messages/IJsonRpcMessageWithId.cs +++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcMessageWithId.cs @@ -1,3 +1,5 @@ +using System.Text.Json.Serialization; + namespace ModelContextProtocol.Protocol.Messages; /// @@ -10,7 +12,7 @@ namespace ModelContextProtocol.Protocol.Messages; /// The ID is used to correlate requests with their responses, allowing asynchronous /// communication where multiple requests can be sent without waiting for responses. /// -public interface IJsonRpcMessageWithId : IJsonRpcMessage +public abstract class JsonRpcMessageWithId : JsonRpcMessage { /// /// Gets the message identifier. @@ -18,5 +20,6 @@ public interface IJsonRpcMessageWithId : IJsonRpcMessage /// /// Each ID is expected to be unique within the context of a given session. /// - RequestId Id { get; } + [JsonPropertyName("id")] + public RequestId Id { get; init; } } diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs index ea0c35cbb..bdbd8d46f 100644 --- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs +++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs @@ -11,12 +11,8 @@ namespace ModelContextProtocol.Protocol.Messages; /// They are useful for one-way communication, such as log notifications and progress updates. /// Unlike requests, notifications do not include an ID field, since there will be no response to match with it. /// -public record JsonRpcNotification : IJsonRpcMessage +public class JsonRpcNotification : JsonRpcMessage { - /// - [JsonPropertyName("jsonrpc")] - public string JsonRpc { get; init; } = "2.0"; - /// /// Gets or sets the name of the notification method. /// diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs index 275503912..ff7a45044 100644 --- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs +++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs @@ -14,16 +14,8 @@ namespace ModelContextProtocol.Protocol.Messages; /// and return either a with the result, or a /// if the method execution fails. /// -public record JsonRpcRequest : IJsonRpcMessageWithId +public class JsonRpcRequest : JsonRpcMessageWithId { - /// - [JsonPropertyName("jsonrpc")] - public string JsonRpc { get; init; } = "2.0"; - - /// - [JsonPropertyName("id")] - public RequestId Id { get; set; } - /// /// Name of the method to invoke. /// @@ -35,4 +27,15 @@ public record JsonRpcRequest : IJsonRpcMessageWithId /// [JsonPropertyName("params")] public JsonNode? Params { get; init; } + + internal JsonRpcRequest WithId(RequestId id) + { + return new JsonRpcRequest + { + JsonRpc = JsonRpc, + Id = id, + Method = Method, + Params = Params + }; + } } diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs index c10e0fb55..01eef51e9 100644 --- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs +++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs @@ -16,16 +16,8 @@ namespace ModelContextProtocol.Protocol.Messages; /// This class represents a successful response with a result. For error responses, see . /// /// -public record JsonRpcResponse : IJsonRpcMessageWithId +public class JsonRpcResponse : JsonRpcMessageWithId { - /// - [JsonPropertyName("jsonrpc")] - public string JsonRpc { get; init; } = "2.0"; - - /// - [JsonPropertyName("id")] - public required RequestId Id { get; init; } - /// /// Gets the result of the method invocation. /// diff --git a/src/ModelContextProtocol/Protocol/Transport/ITransport.cs b/src/ModelContextProtocol/Protocol/Transport/ITransport.cs index de4760488..b8cb7e3bb 100644 --- a/src/ModelContextProtocol/Protocol/Transport/ITransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/ITransport.cs @@ -26,22 +26,6 @@ namespace ModelContextProtocol.Protocol.Transport; /// public interface ITransport : IAsyncDisposable { - /// - /// Gets whether the transport is currently connected and able to send/receive messages. - /// - /// - /// - /// The property indicates the current state of the transport connection. - /// When , the transport is ready to send and receive messages. When , - /// any attempt to send messages will typically result in exceptions being thrown. - /// - /// - /// The property transitions to when the transport successfully establishes a connection, - /// and transitions to when the transport is disposed or encounters a connection error. - /// - /// - bool IsConnected { get; } - /// /// Gets a channel reader for receiving messages from the transport. /// @@ -56,7 +40,7 @@ public interface ITransport : IAsyncDisposable /// any already transmitted messages are consumed. /// /// - ChannelReader MessageReader { get; } + ChannelReader MessageReader { get; } /// /// Sends a JSON-RPC message through the transport. @@ -76,5 +60,5 @@ public interface ITransport : IAsyncDisposable /// rather than accessing this method directly. /// /// - Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); + Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default); } diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index e2c7cc5cc..5d952f8a6 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -68,21 +68,21 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) /// public override async Task SendMessageAsync( - IJsonRpcMessage message, + JsonRpcMessage message, CancellationToken cancellationToken = default) { if (_messageEndpoint == null) throw new InvalidOperationException("Transport not connected"); using var content = new StringContent( - JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage), + JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), Encoding.UTF8, "application/json" ); string messageId = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) + if (message is JsonRpcMessageWithId messageWithId) { messageId = messageWithId.Id.ToString(); } @@ -213,17 +213,20 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) } } } - catch when (cancellationToken.IsCancellationRequested) - { - // Normal shutdown - LogTransportReadMessagesCancelled(Name); - _connectionEstablished.TrySetCanceled(cancellationToken); - } catch (Exception ex) { - LogTransportReadMessagesFailed(Name, ex); - _connectionEstablished.TrySetException(ex); - throw; + if (cancellationToken.IsCancellationRequested) + { + // Normal shutdown + LogTransportReadMessagesCancelled(Name); + _connectionEstablished.TrySetCanceled(cancellationToken); + } + else + { + LogTransportReadMessagesFailed(Name, ex); + _connectionEstablished.TrySetException(ex); + throw; + } } finally { @@ -241,7 +244,7 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation try { - var message = JsonSerializer.Deserialize(data, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage); + var message = JsonSerializer.Deserialize(data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); if (message == null) { LogTransportMessageParseUnexpectedTypeSensitive(Name, data); @@ -249,7 +252,7 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation } string messageId = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) + if (message is JsonRpcMessageWithId messageWithId) { messageId = messageWithId.Id.ToString(); } diff --git a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs index da86fd16a..f830862d1 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs @@ -1,10 +1,6 @@ -using System.Text; -using System.Buffers; -using System.Net.ServerSentEvents; -using System.Text.Json; -using System.Threading.Channels; using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Utils.Json; +using ModelContextProtocol.Utils; +using System.Threading.Channels; namespace ModelContextProtocol.Protocol.Transport; @@ -22,16 +18,22 @@ namespace ModelContextProtocol.Protocol.Transport; /// such as when streaming completion results or providing progress updates during long-running operations. /// /// -public sealed class SseResponseStreamTransport(Stream sseResponseStream, string messageEndpoint = "/message") : ITransport +/// The response stream to write MCP JSON-RPC messages as SSE events to. +/// +/// The relative or absolute URI the client should use to post MCP JSON-RPC messages for this session. +/// These messages should be passed to . +/// Defaults to "/message". +/// +public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? messageEndpoint = "/message") : ITransport { - private readonly Channel _incomingChannel = CreateBoundedChannel(); - private readonly Channel> _outgoingSseChannel = CreateBoundedChannel>(); - - private Task? _sseWriteTask; - private Utf8JsonWriter? _jsonWriter; + private readonly SseWriter _sseWriter = new(messageEndpoint); + private readonly Channel _incomingChannel = Channel.CreateBounded(new BoundedChannelOptions(1) + { + SingleReader = true, + SingleWriter = false, + }); - /// - public bool IsConnected { get; private set; } + private bool _isConnected; /// /// Starts the transport and writes the JSON-RPC messages sent via @@ -39,54 +41,27 @@ public sealed class SseResponseStreamTransport(Stream sseResponseStream, string /// /// The to monitor for cancellation requests. The default is . /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. - public Task RunAsync(CancellationToken cancellationToken) + public async Task RunAsync(CancellationToken cancellationToken) { - // The very first SSE event isn't really an IJsonRpcMessage, but there's no API to write a single item of a different type, - // so we fib and special-case the "endpoint" event type in the formatter. - if (!_outgoingSseChannel.Writer.TryWrite(new SseItem(null, "endpoint"))) - { - throw new InvalidOperationException($"You must call ${nameof(RunAsync)} before calling ${nameof(SendMessageAsync)}."); - } - - IsConnected = true; - - var sseItems = _outgoingSseChannel.Reader.ReadAllAsync(cancellationToken); - return _sseWriteTask = SseFormatter.WriteAsync(sseItems, sseResponseStream, WriteJsonRpcMessageToBuffer, cancellationToken); - } - - private void WriteJsonRpcMessageToBuffer(SseItem item, IBufferWriter writer) - { - if (item.EventType == "endpoint") - { - writer.Write(Encoding.UTF8.GetBytes(messageEndpoint)); - return; - } - - JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage!); + _isConnected = true; + await _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken).ConfigureAwait(false); } /// - public ChannelReader MessageReader => _incomingChannel.Reader; + public ChannelReader MessageReader => _incomingChannel.Reader; /// - public ValueTask DisposeAsync() + public async ValueTask DisposeAsync() { - IsConnected = false; + _isConnected = false; _incomingChannel.Writer.TryComplete(); - _outgoingSseChannel.Writer.TryComplete(); - return new ValueTask(_sseWriteTask ?? Task.CompletedTask); + await _sseWriter.DisposeAsync().ConfigureAwait(false); } /// - public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { - if (!IsConnected) - { - throw new InvalidOperationException($"Transport is not connected. Make sure to call {nameof(RunAsync)} first."); - } - - // Emit redundant "event: message" lines for better compatibility with other SDKs. - await _outgoingSseChannel.Writer.WriteAsync(new SseItem(message, SseParser.EventTypeDefault), cancellationToken).ConfigureAwait(false); + await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); } /// @@ -111,34 +86,15 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca /// sequencing of operations in the transport lifecycle. /// /// - public async Task OnMessageReceivedAsync(IJsonRpcMessage message, CancellationToken cancellationToken) + public async Task OnMessageReceivedAsync(JsonRpcMessage message, CancellationToken cancellationToken) { - if (!IsConnected) + Throw.IfNull(message); + + if (!_isConnected) { throw new InvalidOperationException($"Transport is not connected. Make sure to call {nameof(RunAsync)} first."); } await _incomingChannel.Writer.WriteAsync(message, cancellationToken).ConfigureAwait(false); } - - private static Channel CreateBoundedChannel(int capacity = 1) => - Channel.CreateBounded(new BoundedChannelOptions(capacity) - { - SingleReader = true, - SingleWriter = false, - }); - - private Utf8JsonWriter GetUtf8JsonWriter(IBufferWriter writer) - { - if (_jsonWriter is null) - { - _jsonWriter = new Utf8JsonWriter(writer); - } - else - { - _jsonWriter.Reset(writer); - } - - return _jsonWriter; - } } diff --git a/src/ModelContextProtocol/Protocol/Transport/SseWriter.cs b/src/ModelContextProtocol/Protocol/Transport/SseWriter.cs new file mode 100644 index 000000000..a3eb0ce46 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/SseWriter.cs @@ -0,0 +1,120 @@ +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.Buffers; +using System.Net.ServerSentEvents; +using System.Text; +using System.Text.Json; +using System.Threading.Channels; + +namespace ModelContextProtocol.Protocol.Transport; + +internal sealed class SseWriter(string? messageEndpoint = null, BoundedChannelOptions? channelOptions = null) : IAsyncDisposable +{ + private readonly Channel> _messages = Channel.CreateBounded>(channelOptions ?? new BoundedChannelOptions(1) + { + SingleReader = true, + SingleWriter = false, + }); + + private Utf8JsonWriter? _jsonWriter; + private Task? _writeTask; + private CancellationToken? _writeCancellationToken; + + private readonly SemaphoreSlim _disposeLock = new(1, 1); + private bool _disposed; + + public Func>, CancellationToken, IAsyncEnumerable>>? MessageFilter { get; set; } + + public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellationToken) + { + // When messageEndpoint is set, the very first SSE event isn't really an IJsonRpcMessage, but there's no API to write a single + // item of a different type, so we fib and special-case the "endpoint" event type in the formatter. + if (messageEndpoint is not null && !_messages.Writer.TryWrite(new SseItem(null, "endpoint"))) + { + throw new InvalidOperationException("You must call RunAsync before calling SendMessageAsync."); + } + + _writeCancellationToken = cancellationToken; + + var messages = _messages.Reader.ReadAllAsync(cancellationToken); + if (MessageFilter is not null) + { + messages = MessageFilter(messages, cancellationToken); + } + + _writeTask = SseFormatter.WriteAsync(messages, sseResponseStream, WriteJsonRpcMessageToBuffer, cancellationToken); + return _writeTask; + } + + public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + { + Throw.IfNull(message); + + using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); + + if (_disposed) + { + // Don't throw an ODE, because this is disposed internally when the transport disconnects due to an abort + // or sending all the responses for the a give given Streamable HTTP POST request, so the user might not be at fault. + // There's precedence for no-oping here similar to writing to the response body of an aborted request in ASP.NET Core. + return; + } + + // Emit redundant "event: message" lines for better compatibility with other SDKs. + await _messages.Writer.WriteAsync(new SseItem(message, SseParser.EventTypeDefault), cancellationToken).ConfigureAwait(false); + } + + public async ValueTask DisposeAsync() + { + using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); + + if (_disposed) + { + return; + } + + _messages.Writer.Complete(); + try + { + if (_writeTask is not null) + { + await _writeTask.ConfigureAwait(false); + } + } + catch (OperationCanceledException) when (_writeCancellationToken?.IsCancellationRequested == true) + { + // Ignore exceptions caused by intentional cancellation during shutdown. + } + finally + { + _jsonWriter?.Dispose(); + _disposed = true; + } + } + + private void WriteJsonRpcMessageToBuffer(SseItem item, IBufferWriter writer) + { + if (item.EventType == "endpoint" && messageEndpoint is not null) + { + writer.Write(Encoding.UTF8.GetBytes(messageEndpoint)); + return; + } + + JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage!); + } + + private Utf8JsonWriter GetUtf8JsonWriter(IBufferWriter writer) + { + if (_jsonWriter is null) + { + _jsonWriter = new Utf8JsonWriter(writer); + } + else + { + _jsonWriter.Reset(writer); + } + + return _jsonWriter; + } +} diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs index aa92c9d4b..015b74913 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs @@ -33,7 +33,7 @@ public StdioClientSessionTransport(StdioClientTransportOptions options, Process /// /// Thrown when the underlying process has exited or cannot be accessed. /// - public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { Exception? processException = null; bool hasExited = false; diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs index 3c24416bb..6fcdf0a8b 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs @@ -57,7 +57,7 @@ public StreamClientSessionTransport( } /// - public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { if (!IsConnected) { @@ -65,12 +65,12 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio } string id = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) + if (message is JsonRpcMessageWithId messageWithId) { id = messageWithId.Id.ToString(); } - var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage))); + var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))); using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); try @@ -143,11 +143,11 @@ private async Task ProcessMessageAsync(string line, CancellationToken cancellati { try { - var message = (IJsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage))); + var message = (JsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))); if (message != null) { string messageId = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) + if (message is JsonRpcMessageWithId messageWithId) { messageId = messageWithId.Id.ToString(); } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs index c94509955..acf18984a 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs @@ -56,7 +56,7 @@ public StreamServerTransport(Stream inputStream, Stream outputStream, string? se } /// - public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { if (!IsConnected) { @@ -66,14 +66,14 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); string id = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) + if (message is JsonRpcMessageWithId messageWithId) { id = messageWithId.Id.ToString(); } try { - await JsonSerializer.SerializeAsync(_outputStream, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), cancellationToken).ConfigureAwait(false); + await JsonSerializer.SerializeAsync(_outputStream, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)), cancellationToken).ConfigureAwait(false); await _outputStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false); await _outputStream.FlushAsync(cancellationToken).ConfigureAwait(false); } @@ -109,10 +109,10 @@ private async Task ReadMessagesAsync() try { - if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage))) is IJsonRpcMessage message) + if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) is JsonRpcMessage message) { string messageId = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) + if (message is JsonRpcMessageWithId messageWithId) { messageId = messageWithId.Id.ToString(); } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs new file mode 100644 index 000000000..2bd8f2784 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs @@ -0,0 +1,134 @@ +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.Buffers; +using System.IO.Pipelines; +using System.Net.ServerSentEvents; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading.Channels; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Handles processing the request/response body pairs for the Streamable HTTP transport. +/// This is typically used via . +/// +internal sealed class StreamableHttpPostTransport(ChannelWriter? incomingChannel, IDuplexPipe httpBodies) : ITransport +{ + private readonly SseWriter _sseWriter = new(); + private readonly HashSet _pendingRequests = []; + + // REVIEW: Should we introduce a send-only interface for RelatedTransport? + public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.RelatedTransport should only be used for sending messages."); + + /// + /// True, if data was written to the respond body. + /// False, if nothing was written because the request body did not contain any messages to respond to. + /// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario. + /// + public async ValueTask RunAsync(CancellationToken cancellationToken) + { + // The incomingChannel is null to handle the potential client GET request to handle unsolicited JsonRpcMessages. + if (incomingChannel is not null) + { + await OnPostBodyReceivedAsync(httpBodies.Input, cancellationToken).ConfigureAwait(false); + } + + if (_pendingRequests.Count == 0) + { + return false; + } + + _sseWriter.MessageFilter = StopOnFinalResponseFilter; + await _sseWriter.WriteAllAsync(httpBodies.Output.AsStream(), cancellationToken).ConfigureAwait(false); + return true; + } + + public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + { + await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + } + + public async ValueTask DisposeAsync() + { + await _sseWriter.DisposeAsync().ConfigureAwait(false); + } + + private async IAsyncEnumerable> StopOnFinalResponseFilter(IAsyncEnumerable> messages, [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var message in messages.WithCancellation(cancellationToken)) + { + yield return message; + + if (message.Data is JsonRpcResponse response) + { + if (_pendingRequests.Remove(response.Id) && _pendingRequests.Count == 0) + { + // Complete the SSE response stream now that all pending requests have been processed. + break; + } + } + } + } + + private async ValueTask OnPostBodyReceivedAsync(PipeReader streamableHttpRequestBody, CancellationToken cancellationToken) + { + if (!await IsJsonArrayAsync(streamableHttpRequestBody, cancellationToken).ConfigureAwait(false)) + { + var message = await JsonSerializer.DeserializeAsync(streamableHttpRequestBody.AsStream(), McpJsonUtilities.JsonContext.Default.JsonRpcMessage, cancellationToken).ConfigureAwait(false); + await OnMessageReceivedAsync(message, cancellationToken).ConfigureAwait(false); + } + else + { + // Batched JSON-RPC message + var messages = JsonSerializer.DeserializeAsyncEnumerable(streamableHttpRequestBody.AsStream(), McpJsonUtilities.JsonContext.Default.JsonRpcMessage, cancellationToken).ConfigureAwait(false); + await foreach (var message in messages.WithCancellation(cancellationToken)) + { + await OnMessageReceivedAsync(message, cancellationToken).ConfigureAwait(false); + } + } + } + + private async ValueTask OnMessageReceivedAsync(JsonRpcMessage? message, CancellationToken cancellationToken) + { + if (message is null) + { + throw new InvalidOperationException("Received invalid null message."); + } + + if (message is JsonRpcRequest request) + { + _pendingRequests.Add(request.Id); + } + + message.RelatedTransport = this; + + // Really an assertion. This doesn't get called when incomingChannel is null for GET requests. + Throw.IfNull(incomingChannel); + await incomingChannel.WriteAsync(message, cancellationToken).ConfigureAwait(false); + } + + private async ValueTask IsJsonArrayAsync(PipeReader requestBody, CancellationToken cancellationToken) + { + // REVIEW: Should we bother trimming whitespace before checking for '['? + var firstCharacterResult = await requestBody.ReadAtLeastAsync(1, cancellationToken).ConfigureAwait(false); + + try + { + if (firstCharacterResult.Buffer.Length == 0) + { + return false; + } + + Span firstCharBuffer = stackalloc byte[1]; + firstCharacterResult.Buffer.Slice(0, 1).CopyTo(firstCharBuffer); + return firstCharBuffer[0] == (byte)'['; + } + finally + { + // Never consume data when checking for '['. System.Text.Json still needs to consume it. + requestBody.AdvanceTo(firstCharacterResult.Buffer.Start); + } + } +} diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs new file mode 100644 index 000000000..42e3ff70b --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs @@ -0,0 +1,99 @@ +using ModelContextProtocol.Protocol.Messages; +using System.IO.Pipelines; +using System.Threading.Channels; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Provides an implementation using Server-Sent Events (SSE) for server-to-client communication. +/// +/// +/// +/// This transport provides one-way communication from server to client using the SSE protocol over HTTP, +/// while receiving client messages through a separate mechanism. It writes messages as +/// SSE events to a response stream, typically associated with an HTTP response. +/// +/// +/// This transport is used in scenarios where the server needs to push messages to the client in real-time, +/// such as when streaming completion results or providing progress updates during long-running operations. +/// +/// +public sealed class StreamableHttpServerTransport : ITransport +{ + // For JsonRpcMessages without a RelatedTransport, we don't want to block just because the client didn't make a GET request to handle unsolicited messages. + private readonly SseWriter _sseWriter = new(channelOptions: new BoundedChannelOptions(1) + { + SingleReader = true, + SingleWriter = false, + FullMode = BoundedChannelFullMode.DropOldest, + }); + private readonly Channel _incomingChannel = Channel.CreateBounded(new BoundedChannelOptions(1) + { + SingleReader = true, + SingleWriter = false, + }); + private readonly CancellationTokenSource _disposeCts = new(); + + private int _getRequestStarted; + + /// + /// Handles an optional SSE GET request a client using the Streamable HTTP transport might make by + /// writing any unsolicited JSON-RPC messages sent via + /// to the SSE response stream until cancellation is requested or the transport is disposed. + /// + /// The response stream to write MCP JSON-RPC messages as SSE events to. + /// The to monitor for cancellation requests. The default is . + /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. + public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken cancellationToken) + { + if (Interlocked.Exchange(ref _getRequestStarted, 1) == 1) + { + throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session."); + } + + using var getCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, cancellationToken); + await _sseWriter.WriteAllAsync(sseResponseStream, getCts.Token).ConfigureAwait(false); + } + + /// + /// Handles a Streamable HTTP POST request processing both the request body and response body ensuring that + /// and other correlated messages are sent back to the client directly in response + /// to the that initiated the message. + /// + /// The duplex pipe facilitates the reading and writing of HTTP request and response data. + /// This token allows for the operation to be canceled if needed. + /// + /// True, if data was written to the respond body. + /// False, if nothing was written because the request body did not contain any messages to respond to. + /// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario. + /// + public async Task HandlePostRequest(IDuplexPipe httpBodies, CancellationToken cancellationToken) + { + using var postCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, cancellationToken); + await using var postTransport = new StreamableHttpPostTransport(_incomingChannel.Writer, httpBodies); + return await postTransport.RunAsync(postCts.Token).ConfigureAwait(false); + } + + /// + public ChannelReader MessageReader => _incomingChannel.Reader; + + /// + public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + { + await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + } + + /// + public async ValueTask DisposeAsync() + { + _disposeCts.Cancel(); + try + { + await _sseWriter.DisposeAsync().ConfigureAwait(false); + } + finally + { + _disposeCts.Dispose(); + } + } +} diff --git a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs index 0d496d0a6..af9cdaefd 100644 --- a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs +++ b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs @@ -15,13 +15,13 @@ namespace ModelContextProtocol.Protocol.Transport; /// /// /// Custom transport implementations should inherit from this class and implement the abstract -/// and methods +/// and methods /// to handle the specific transport mechanism being used. /// /// public abstract partial class TransportBase : ITransport { - private readonly Channel _messageChannel; + private readonly Channel _messageChannel; private readonly ILogger _logger; private int _isConnected; @@ -34,7 +34,7 @@ protected TransportBase(string name, ILoggerFactory? loggerFactory) _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; // Unbounded channel to prevent blocking on writes - _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions + _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions { SingleReader = true, SingleWriter = true, @@ -56,10 +56,10 @@ protected TransportBase(string name, ILoggerFactory? loggerFactory) public bool IsConnected => _isConnected == 1; /// - public ChannelReader MessageReader => _messageChannel.Reader; + public ChannelReader MessageReader => _messageChannel.Reader; /// - public abstract Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); + public abstract Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default); /// public abstract ValueTask DisposeAsync(); @@ -69,7 +69,7 @@ protected TransportBase(string name, ILoggerFactory? loggerFactory) /// /// The message to write. /// The to monitor for cancellation requests. The default is . - protected async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + protected async Task WriteMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { if (!IsConnected) { diff --git a/src/ModelContextProtocol/Server/DestinationBoundMcpServer.cs b/src/ModelContextProtocol/Server/DestinationBoundMcpServer.cs new file mode 100644 index 000000000..0d86480de --- /dev/null +++ b/src/ModelContextProtocol/Server/DestinationBoundMcpServer.cs @@ -0,0 +1,37 @@ +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Types; +using System.Diagnostics; + +namespace ModelContextProtocol.Server; + +internal sealed class DestinationBoundMcpServer(McpServer server, ITransport? transport) : IMcpServer +{ + public string EndpointName => server.EndpointName; + public ClientCapabilities? ClientCapabilities => server.ClientCapabilities; + public Implementation? ClientInfo => server.ClientInfo; + public McpServerOptions ServerOptions => server.ServerOptions; + public IServiceProvider? Services => server.Services; + public LoggingLevel? LoggingLevel => server.LoggingLevel; + + public ValueTask DisposeAsync() => server.DisposeAsync(); + + public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => server.RegisterNotificationHandler(method, handler); + + // This will throw because the server must already be running for this class to be constructed, but it should give us a good Exception message. + public Task RunAsync(CancellationToken cancellationToken) => server.RunAsync(cancellationToken); + + public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + { + Debug.Assert(message.RelatedTransport is null); + message.RelatedTransport = transport; + return server.SendMessageAsync(message, cancellationToken); + } + + public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + { + Debug.Assert(request.RelatedTransport is null); + request.RelatedTransport = transport; + return server.SendRequestAsync(request, cancellationToken); + } +} diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 2f7b59a90..ae0e7afc5 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -7,8 +7,7 @@ using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; using System.Runtime.CompilerServices; - -#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously +using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.Server; @@ -158,7 +157,7 @@ public override async ValueTask DisposeUnsynchronizedAsync() private void SetPingHandler() { - RequestHandlers.Set(RequestMethods.Ping, + SetHandler(RequestMethods.Ping, async (request, _) => new PingResult(), McpJsonUtilities.JsonContext.Default.JsonNode, McpJsonUtilities.JsonContext.Default.PingResult); @@ -167,7 +166,7 @@ private void SetPingHandler() private void SetInitializeHandler(McpServerOptions options) { RequestHandlers.Set(RequestMethods.Initialize, - async (request, _) => + async (request, _, _) => { ClientCapabilities = request?.Capabilities ?? new(); ClientInfo = request?.ClientInfo; @@ -201,9 +200,9 @@ private void SetCompletionHandler(McpServerOptions options) $"but {nameof(CompletionsCapability.CompleteHandler)} was not specified."); // This capability is not optional, so return an empty result if there is no handler. - RequestHandlers.Set( + SetHandler( RequestMethods.CompletionComplete, - (request, cancellationToken) => InvokeHandlerAsync(completeHandler, request, cancellationToken), + completeHandler, McpJsonUtilities.JsonContext.Default.CompleteRequestParams, McpJsonUtilities.JsonContext.Default.CompleteResult); } @@ -228,22 +227,22 @@ private void SetResourcesHandler(McpServerOptions options) listResourcesHandler ??= static async (_, _) => new ListResourcesResult(); - RequestHandlers.Set( + SetHandler( RequestMethods.ResourcesList, - (request, cancellationToken) => InvokeHandlerAsync(listResourcesHandler, request, cancellationToken), + listResourcesHandler, McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, McpJsonUtilities.JsonContext.Default.ListResourcesResult); - RequestHandlers.Set( + SetHandler( RequestMethods.ResourcesRead, - (request, cancellationToken) => InvokeHandlerAsync(readResourceHandler, request, cancellationToken), + readResourceHandler, McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, McpJsonUtilities.JsonContext.Default.ReadResourceResult); listResourceTemplatesHandler ??= static async (_, _) => new ListResourceTemplatesResult(); - RequestHandlers.Set( + SetHandler( RequestMethods.ResourcesTemplatesList, - (request, cancellationToken) => InvokeHandlerAsync(listResourceTemplatesHandler, request, cancellationToken), + listResourceTemplatesHandler, McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult); @@ -261,15 +260,15 @@ private void SetResourcesHandler(McpServerOptions options) $"but {nameof(ResourcesCapability.SubscribeToResourcesHandler)} or {nameof(ResourcesCapability.UnsubscribeFromResourcesHandler)} was not specified."); } - RequestHandlers.Set( + SetHandler( RequestMethods.ResourcesSubscribe, - (request, cancellationToken) => InvokeHandlerAsync(subscribeHandler, request, cancellationToken), + subscribeHandler, McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, McpJsonUtilities.JsonContext.Default.EmptyResult); - RequestHandlers.Set( + SetHandler( RequestMethods.ResourcesUnsubscribe, - (request, cancellationToken) => InvokeHandlerAsync(unsubscribeHandler, request, cancellationToken), + unsubscribeHandler, McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, McpJsonUtilities.JsonContext.Default.EmptyResult); } @@ -359,15 +358,15 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals } } - RequestHandlers.Set( + SetHandler( RequestMethods.PromptsList, - (request, cancellationToken) => InvokeHandlerAsync(listPromptsHandler, request, cancellationToken), + listPromptsHandler, McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, McpJsonUtilities.JsonContext.Default.ListPromptsResult); - RequestHandlers.Set( + SetHandler( RequestMethods.PromptsGet, - (request, cancellationToken) => InvokeHandlerAsync(getPromptHandler, request, cancellationToken), + getPromptHandler, McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, McpJsonUtilities.JsonContext.Default.GetPromptResult); } @@ -457,15 +456,15 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) } } - RequestHandlers.Set( + SetHandler( RequestMethods.ToolsList, - (request, cancellationToken) => InvokeHandlerAsync(listToolsHandler, request, cancellationToken), + listToolsHandler, McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, McpJsonUtilities.JsonContext.Default.ListToolsResult); - RequestHandlers.Set( + SetHandler( RequestMethods.ToolsCall, - (request, cancellationToken) => InvokeHandlerAsync(callToolHandler, request, cancellationToken), + callToolHandler, McpJsonUtilities.JsonContext.Default.CallToolRequestParams, McpJsonUtilities.JsonContext.Default.CallToolResponse); } @@ -478,7 +477,7 @@ private void SetSetLoggingLevelHandler(McpServerOptions options) RequestHandlers.Set( RequestMethods.LoggingSetLevel, - (request, cancellationToken) => + (request, destinationTransport, cancellationToken) => { // Store the provided level. if (request is not null) @@ -494,7 +493,7 @@ private void SetSetLoggingLevelHandler(McpServerOptions options) // If a handler was provided, now delegate to it. if (setLoggingLevelHandler is not null) { - return InvokeHandlerAsync(setLoggingLevelHandler, request, cancellationToken); + return InvokeHandlerAsync(setLoggingLevelHandler, request, destinationTransport, cancellationToken); } // Otherwise, consider it handled. @@ -507,11 +506,12 @@ private void SetSetLoggingLevelHandler(McpServerOptions options) private ValueTask InvokeHandlerAsync( Func, CancellationToken, ValueTask> handler, TParams? args, - CancellationToken cancellationToken) + ITransport? destinationTransport = null, + CancellationToken cancellationToken = default) { return _servicesScopePerRequest ? InvokeScopedAsync(handler, args, cancellationToken) : - handler(new(this) { Params = args }, cancellationToken); + handler(new(new DestinationBoundMcpServer(this, destinationTransport)) { Params = args }, cancellationToken); async ValueTask InvokeScopedAsync( Func, CancellationToken, ValueTask> handler, @@ -522,7 +522,7 @@ async ValueTask InvokeScopedAsync( try { return await handler( - new RequestContext(this) + new RequestContext(new DestinationBoundMcpServer(this, destinationTransport)) { Services = scope?.ServiceProvider ?? Services, Params = args @@ -539,6 +539,18 @@ async ValueTask InvokeScopedAsync( } } + private void SetHandler( + string method, + Func, CancellationToken, ValueTask> handler, + JsonTypeInfo requestTypeInfo, + JsonTypeInfo responseTypeInfo) + { + RequestHandlers.Set(method, + (request, destinationTransport, cancellationToken) => + InvokeHandlerAsync(handler, request, destinationTransport, cancellationToken), + requestTypeInfo, responseTypeInfo); + } + /// Maps a to a . internal static LoggingLevel ToLoggingLevel(LogLevel level) => level switch @@ -551,4 +563,4 @@ internal static LoggingLevel ToLoggingLevel(LogLevel level) => LogLevel.Critical => Protocol.Types.LoggingLevel.Critical, _ => Protocol.Types.LoggingLevel.Emergency, }; -} \ No newline at end of file +} diff --git a/src/ModelContextProtocol/Shared/McpEndpoint.cs b/src/ModelContextProtocol/Shared/McpEndpoint.cs index 45566cefe..394ccaa7b 100644 --- a/src/ModelContextProtocol/Shared/McpEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpEndpoint.cs @@ -45,7 +45,7 @@ protected McpEndpoint(ILoggerFactory? loggerFactory = null) public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) => GetSessionOrThrow().SendRequestAsync(request, cancellationToken); - public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) => GetSessionOrThrow().SendMessageAsync(message, cancellationToken); public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index 1b51fc8fd..56f0674c9 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -39,7 +39,7 @@ internal sealed partial class McpSession : IDisposable private readonly DistributedContextPropagator _propagator = DistributedContextPropagator.Current; /// Collection of requests sent on this session and waiting for responses. - private readonly ConcurrentDictionary> _pendingRequests = []; + private readonly ConcurrentDictionary> _pendingRequests = []; /// /// Collection of requests received on this session and currently being handled. The value provides a /// that can be used to request cancellation of the in-flight handler. @@ -47,8 +47,9 @@ internal sealed partial class McpSession : IDisposable private readonly ConcurrentDictionary _handlingRequests = new(); private readonly ILogger _logger; - private readonly string _id = Guid.NewGuid().ToString("N"); - private long _nextRequestId; + // This _sessionId is solely used to identify the session in telemetry and logs. + private readonly string _sessionId = Guid.NewGuid().ToString("N"); + private long _lastRequestId; /// /// Initializes a new instance of the class. @@ -74,6 +75,7 @@ public McpSession( StdioClientSessionTransport or StdioServerTransport => "stdio", StreamClientSessionTransport or StreamServerTransport => "stream", SseClientSessionTransport or SseResponseStreamTransport => "sse", + StreamableHttpServerTransport or StreamableHttpPostTransport => "http", _ => "unknownTransport" }; @@ -102,10 +104,11 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken) { LogMessageRead(EndpointName, message.GetType().Name); + // Fire and forget the message handling to avoid blocking the transport. _ = ProcessMessageAsync(); async Task ProcessMessageAsync() { - IJsonRpcMessageWithId? messageWithId = message as IJsonRpcMessageWithId; + JsonRpcMessageWithId? messageWithId = message as JsonRpcMessageWithId; CancellationTokenSource? combinedCts = null; try { @@ -118,10 +121,8 @@ async Task ProcessMessageAsync() _handlingRequests[messageWithId.Id] = combinedCts; } - // Fire and forget the message handling to avoid blocking the transport - // If awaiting the task, the transport will not be able to read more messages, - // which could lead to a deadlock if the handler sends a message back - + // If we await the handler without yielding first, the transport may not be able to read more messages, + // which could lead to a deadlock if the handler sends a message back. #if NET await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); #else @@ -155,18 +156,19 @@ ex is OperationCanceledException && Message = "An error occurred.", }; - await _transport.SendMessageAsync(new JsonRpcError + await SendMessageAsync(new JsonRpcError { Id = request.Id, JsonRpc = "2.0", Error = detail, + RelatedTransport = request.RelatedTransport, }, cancellationToken).ConfigureAwait(false); } else if (ex is not OperationCanceledException) { if (_logger.IsEnabled(LogLevel.Trace)) { - LogMessageHandlerExceptionSensitive(EndpointName, message.GetType().Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage), ex); + LogMessageHandlerExceptionSensitive(EndpointName, message.GetType().Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), ex); } else { @@ -200,7 +202,7 @@ await _transport.SendMessageAsync(new JsonRpcError } } - private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken) + private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken) { Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; string method = GetMethodName(message); @@ -235,7 +237,7 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken await HandleNotification(notification, cancellationToken).ConfigureAwait(false); break; - case IJsonRpcMessageWithId messageWithId: + case JsonRpcMessageWithId messageWithId: HandleMessageWithId(message, messageWithId); break; @@ -279,7 +281,7 @@ private async Task HandleNotification(JsonRpcNotification notification, Cancella await _notificationHandlers.InvokeHandlers(notification.Method, notification, cancellationToken).ConfigureAwait(false); } - private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId messageWithId) + private void HandleMessageWithId(JsonRpcMessage message, JsonRpcMessageWithId messageWithId) { if (_pendingRequests.TryRemove(messageWithId.Id, out var tcs)) { @@ -303,17 +305,17 @@ private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId JsonNode? result = await handler(request, cancellationToken).ConfigureAwait(false); LogRequestHandlerCompleted(EndpointName, request.Method); - await _transport.SendMessageAsync(new JsonRpcResponse + await SendMessageAsync(new JsonRpcResponse { Id = request.Id, - JsonRpc = "2.0", - Result = result + Result = result, + RelatedTransport = request.RelatedTransport, }, cancellationToken).ConfigureAwait(false); return result; } - private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, RequestId requestId) + private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, JsonRpcRequest request) { if (!cancellationToken.CanBeCanceled) { @@ -322,13 +324,14 @@ private CancellationTokenRegistration RegisterCancellation(CancellationToken can return cancellationToken.Register(static objState => { - var state = (Tuple)objState!; + var state = (Tuple)objState!; _ = state.Item1.SendMessageAsync(new JsonRpcNotification { Method = NotificationMethods.CancelledNotification, - Params = JsonSerializer.SerializeToNode(new CancelledNotification { RequestId = state.Item2 }, McpJsonUtilities.JsonContext.Default.CancelledNotification) + Params = JsonSerializer.SerializeToNode(new CancelledNotification { RequestId = state.Item2.Id }, McpJsonUtilities.JsonContext.Default.CancelledNotification), + RelatedTransport = state.Item2.RelatedTransport, }); - }, Tuple.Create(this, requestId)); + }, Tuple.Create(this, request)); } public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) @@ -349,11 +352,6 @@ public IAsyncDisposable RegisterNotificationHandler(string method, FuncA task containing the server's response. public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) { - if (!_transport.IsConnected) - { - throw new InvalidOperationException("Transport is not connected"); - } - cancellationToken.ThrowIfCancellationRequested(); Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; @@ -367,7 +365,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc // Set request ID if (request.Id.Id is null) { - request.Id = new RequestId($"{_id}-{Interlocked.Increment(ref _nextRequestId)}"); + request = request.WithId(new RequestId(Interlocked.Increment(ref _lastRequestId))); } _propagator.InjectActivityContext(activity, request); @@ -375,7 +373,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc TagList tags = default; bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); _pendingRequests[request.Id] = tcs; try { @@ -386,21 +384,21 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc if (_logger.IsEnabled(LogLevel.Trace)) { - LogSendingRequestSensitive(EndpointName, request.Method, JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage)); + LogSendingRequestSensitive(EndpointName, request.Method, JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); } else { LogSendingRequest(EndpointName, request.Method); } - await _transport.SendMessageAsync(request, cancellationToken).ConfigureAwait(false); + await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false); // Now that the request has been sent, register for cancellation. If we registered before, // a cancellation request could arrive before the server knew about that request ID, in which // case the server could ignore it. LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id); - IJsonRpcMessage? response; - using (var registration = RegisterCancellation(cancellationToken, request.Id)) + JsonRpcMessage? response; + using (var registration = RegisterCancellation(cancellationToken, request)) { response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); } @@ -446,15 +444,10 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc } } - public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { Throw.IfNull(message); - if (!_transport.IsConnected) - { - throw new InvalidOperationException("Transport is not connected"); - } - cancellationToken.ThrowIfCancellationRequested(); Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; @@ -480,14 +473,14 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca if (_logger.IsEnabled(LogLevel.Trace)) { - LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage)); + LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); } else { LogSendingMessage(EndpointName); } - await _transport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + await SendToRelatedTransportAsync(message, cancellationToken).ConfigureAwait(false); // If the sent notification was a cancellation notification, cancel the pending request's await, as either the // server won't be sending a response, or per the specification, the response should be ignored. There are inherent @@ -510,6 +503,12 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca } } + // The JsonRpcMessage should be sent over the RelatedTransport if set. This is used to support the + // Streamable HTTP transport where the specification states that the server SHOULD include JSON-RPC responses in + // the HTTP response body for the POST request containing the corresponding JSON-RPC request. + private Task SendToRelatedTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken) + => (message.RelatedTransport ?? _transport).SendMessageAsync(message, cancellationToken); + private static CancelledNotification? GetCancelledNotificationParams(JsonNode? notificationParams) { try @@ -524,7 +523,7 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca private string CreateActivityName(string method) => method; - private static string GetMethodName(IJsonRpcMessage message) => + private static string GetMethodName(JsonRpcMessage message) => message switch { JsonRpcRequest request => request.Method, @@ -532,7 +531,7 @@ private static string GetMethodName(IJsonRpcMessage message) => _ => "unknownMethod" }; - private void AddTags(ref TagList tags, Activity? activity, IJsonRpcMessage message, string method) + private void AddTags(ref TagList tags, Activity? activity, JsonRpcMessage message, string method) { tags.Add("mcp.method.name", method); tags.Add("network.transport", _transportKind); @@ -543,9 +542,9 @@ private void AddTags(ref TagList tags, Activity? activity, IJsonRpcMessage messa if (activity is { IsAllDataRequested: true }) { // session and request id have high cardinality, so not applying to metric tags - activity.AddTag("mcp.session.id", _id); + activity.AddTag("mcp.session.id", _sessionId); - if (message is IJsonRpcMessageWithId withId) + if (message is JsonRpcMessageWithId withId) { activity.AddTag("mcp.request.id", withId.Id.Id?.ToString()); } diff --git a/src/ModelContextProtocol/Shared/RequestHandlers.cs b/src/ModelContextProtocol/Shared/RequestHandlers.cs index 184fd9077..93a3dbbf4 100644 --- a/src/ModelContextProtocol/Shared/RequestHandlers.cs +++ b/src/ModelContextProtocol/Shared/RequestHandlers.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Utils; using System.Text.Json; using System.Text.Json.Nodes; @@ -30,7 +31,7 @@ internal sealed class RequestHandlers : Dictionary public void Set( string method, - Func> handler, + Func> handler, JsonTypeInfo requestTypeInfo, JsonTypeInfo responseTypeInfo) { @@ -42,7 +43,7 @@ public void Set( this[method] = async (request, cancellationToken) => { TRequest? typedRequest = JsonSerializer.Deserialize(request.Params, requestTypeInfo); - object? result = await handler(typedRequest, cancellationToken).ConfigureAwait(false); + object? result = await handler(typedRequest, request.RelatedTransport, cancellationToken).ConfigureAwait(false); return JsonSerializer.SerializeToNode(result, responseTypeInfo); }; } diff --git a/src/ModelContextProtocol/Utils/Json/JsonRpcMessageConverter.cs b/src/ModelContextProtocol/Utils/Json/JsonRpcMessageConverter.cs index 54fd4be05..146185cac 100644 --- a/src/ModelContextProtocol/Utils/Json/JsonRpcMessageConverter.cs +++ b/src/ModelContextProtocol/Utils/Json/JsonRpcMessageConverter.cs @@ -6,7 +6,7 @@ namespace ModelContextProtocol.Utils.Json; /// -/// Provides a for messages, +/// Provides a for messages, /// handling polymorphic deserialization of different message types. /// /// @@ -26,10 +26,10 @@ namespace ModelContextProtocol.Utils.Json; /// /// [EditorBrowsable(EditorBrowsableState.Never)] -public sealed class JsonRpcMessageConverter : JsonConverter +public sealed class JsonRpcMessageConverter : JsonConverter { /// - public override IJsonRpcMessage? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override JsonRpcMessage? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { if (reader.TokenType != JsonTokenType.StartObject) { @@ -87,7 +87,7 @@ public sealed class JsonRpcMessageConverter : JsonConverter } /// - public override void Write(Utf8JsonWriter writer, IJsonRpcMessage value, JsonSerializerOptions options) + public override void Write(Utf8JsonWriter writer, JsonRpcMessage value, JsonSerializerOptions options) { switch (value) { diff --git a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs index 1c6dbd9c0..b759ba975 100644 --- a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs +++ b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs @@ -82,7 +82,8 @@ internal static bool IsValidMcpToolSchema(JsonElement element) NumberHandling = JsonNumberHandling.AllowReadingFromString)] // JSON-RPC - [JsonSerializable(typeof(IJsonRpcMessage))] + [JsonSerializable(typeof(JsonRpcMessage))] + [JsonSerializable(typeof(JsonRpcMessage[]))] [JsonSerializable(typeof(JsonRpcRequest))] [JsonSerializable(typeof(JsonRpcNotification))] [JsonSerializable(typeof(JsonRpcResponse))] diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index f7cb2c8a6..2f4cd7fbe 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -13,7 +13,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; public class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) { - private async Task ConnectAsync(string? path = null) + private async Task ConnectAsync(string? path = "/sse") { var sseClientTransportOptions = new SseClientTransportOptions() { @@ -48,9 +48,7 @@ public async Task Allows_Customizing_Route() [Theory] [InlineData("/a", "/a/sse")] - [InlineData("/a", "/a/")] [InlineData("/a/", "/a/sse")] - [InlineData("/a/", "/a/")] public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePattern, string requestPath) { Builder.Services.AddMcpServer(options => diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index 8cb5cef1f..b659ff172 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -198,7 +198,7 @@ public async Task AdditionalHeaders_AreSent_InGetAndPostRequests() } [Fact] - public async Task EmptyAdditionalHeadersKey_Throws_InvalidOpearionException() + public async Task EmptyAdditionalHeadersKey_Throws_InvalidOperationException() { Builder.Services.AddMcpServer() .WithHttpTransport(); @@ -269,7 +269,7 @@ private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints) await Results.BadRequest("Session not started.").ExecuteAsync(context); return; } - var message = (IJsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), context.RequestAborted); + var message = (JsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)), context.RequestAborted); if (message is null) { await Results.BadRequest("No message in request body.").ExecuteAsync(context); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpTests.cs new file mode 100644 index 000000000..cf3aa4f41 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpTests.cs @@ -0,0 +1,510 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; +using ModelContextProtocol.Utils.Json; +using System.Net; +using System.Net.ServerSentEvents; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class StreamableHttpTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper), IAsyncDisposable +{ + private static McpServerTool[] Tools { get; } = [ + McpServerTool.Create(EchoAsync), + McpServerTool.Create(LongRunningAsync), + McpServerTool.Create(Progress), + McpServerTool.Create(Throw), + ]; + + private WebApplication? _app; + + private async Task StartAsync() + { + AddDefaultHttpClientRequestHeaders(); + + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new Implementation + { + Name = nameof(StreamableHttpTests), + Version = "73", + }; + }).WithTools(Tools).WithHttpTransport(); + + _app = Builder.Build(); + + _app.MapMcp(); + + await _app.StartAsync(TestContext.Current.CancellationToken); + } + + public async ValueTask DisposeAsync() + { + if (_app is not null) + { + await _app.DisposeAsync(); + } + base.Dispose(); + } + + [Fact] + public async Task InitialPostResponse_Includes_McpSessionIdHeader() + { + await StartAsync(); + + using var response = await HttpClient.PostAsync("", JsonContent(InitializeRequest), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Single(response.Headers.GetValues("mcp-session-id")); + Assert.Equal("text/event-stream", Assert.Single(response.Content.Headers.GetValues("content-type"))); + } + + [Fact] + public async Task PostRequest_IsUnsupportedMediaType_WithoutJsonContentType() + { + await StartAsync(); + + using var response = await HttpClient.PostAsync("", new StringContent(InitializeRequest, Encoding.UTF8, "text/javascript"), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.UnsupportedMediaType, response.StatusCode); + } + + [Fact] + public async Task PostRequest_IsNotAcceptable_WithoutApplicationJsonAcceptHeader() + { + await StartAsync(); + + HttpClient.DefaultRequestHeaders.Accept.Clear(); + HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); + + using var response = await HttpClient.PostAsync("", JsonContent(InitializeRequest), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.NotAcceptable, response.StatusCode); + } + + + [Fact] + public async Task PostRequest_IsNotAcceptable_WithoutTextEventStreamAcceptHeader() + { + await StartAsync(); + + HttpClient.DefaultRequestHeaders.Accept.Clear(); + HttpClient.DefaultRequestHeaders.Accept.Add(new("text/json")); + + using var response = await HttpClient.PostAsync("", JsonContent(InitializeRequest), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.NotAcceptable, response.StatusCode); + } + + [Fact] + public async Task GetRequest_IsNotAcceptable_WithoutTextEventStreamAcceptHeader() + { + await StartAsync(); + + HttpClient.DefaultRequestHeaders.Accept.Clear(); + HttpClient.DefaultRequestHeaders.Accept.Add(new("text/json")); + + using var response = await HttpClient.GetAsync("", TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.NotAcceptable, response.StatusCode); + } + + [Fact] + public async Task PostRequest_IsNotFound_WithUnrecognizedSessionId() + { + await StartAsync(); + + using var request = new HttpRequestMessage(HttpMethod.Post, "") + { + Content = JsonContent(EchoRequest), + Headers = + { + { "mcp-session-id", "fakeSession" }, + }, + }; + using var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + } + + [Fact] + public async Task InitializeRequest_Matches_CustomRoute() + { + AddDefaultHttpClientRequestHeaders(); + Builder.Services.AddMcpServer().WithHttpTransport(); + await using var app = Builder.Build(); + + app.MapMcp("/custom-route"); + + await app.StartAsync(TestContext.Current.CancellationToken); + + using var response = await HttpClient.PostAsync("/custom-route", JsonContent(InitializeRequest), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + [Fact] + public async Task PostWithSingleNotification_IsAccepted_WithEmptyResponse() + { + await StartAsync(); + await CallInitializeAndValidateAsync(); + + var response = await HttpClient.PostAsync("", JsonContent(ProgressNotification("1")), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.Accepted, response.StatusCode); + Assert.Equal("", await response.Content.ReadAsStringAsync(TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task InitializeJsonRpcRequest_IsHandled_WithCompleteSseResponse() + { + await StartAsync(); + await CallInitializeAndValidateAsync(); + } + + [Fact] + public async Task BatchedJsonRpcRequests_IsHandled_WithCompleteSseResponse() + { + await StartAsync(); + + using var response = await HttpClient.PostAsync("", JsonContent($"[{InitializeRequest},{EchoRequest}]"), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + var eventCount = 0; + await foreach (var sseEvent in ReadSseAsync(response.Content)) + { + var jsonRpcResponse = JsonSerializer.Deserialize(sseEvent, GetJsonTypeInfo()); + Assert.NotNull(jsonRpcResponse); + var responseId = Assert.IsType(jsonRpcResponse.Id.Id); + + switch (responseId) + { + case 1: + AssertServerInfo(jsonRpcResponse); + break; + case 2: + AssertEchoResponse(jsonRpcResponse); + break; + default: + throw new Exception($"Unexpected response ID: {jsonRpcResponse.Id}"); + } + + eventCount++; + } + + Assert.Equal(2, eventCount); + } + + [Fact] + public async Task SingleJsonRpcRequest_ThatThrowsIsHandled_WithCompleteSseResponse() + { + await StartAsync(); + await CallInitializeAndValidateAsync(); + + var response = await HttpClient.PostAsync("", JsonContent(CallTool("throw")), TestContext.Current.CancellationToken); + var rpcError = await AssertSingleSseResponseAsync(response); + + var error = AssertType(rpcError.Result); + var content = Assert.Single(error.Content); + Assert.Contains("'throw'", content.Text); + } + + [Fact] + public async Task MultipleSerialJsonRpcRequests_IsHandled_OneAtATime() + { + await StartAsync(); + + await CallInitializeAndValidateAsync(); + await CallEchoAndValidateAsync(); + await CallEchoAndValidateAsync(); + } + + [Fact] + public async Task MultipleConcurrentJsonRpcRequests_IsHandled_InParallel() + { + await StartAsync(); + + await CallInitializeAndValidateAsync(); + + var echoTasks = new Task[100]; + for (int i = 0; i < echoTasks.Length; i++) + { + echoTasks[i] = CallEchoAndValidateAsync(); + } + + await Task.WhenAll(echoTasks); + } + + [Fact] + public async Task GetRequest_Receives_UnsolicitedNotifications() + { + IMcpServer? server = null; + + Builder.Services.AddMcpServer() + .WithHttpTransport(options => + { + options.RunSessionHandler = (httpContext, mcpServer, cancellationToken) => + { + server = mcpServer; + return mcpServer.RunAsync(cancellationToken); + }; + }); + + await StartAsync(); + + await CallInitializeAndValidateAsync(); + Assert.NotNull(server); + + // Headers should be sent even before any messages are ready on the GET endpoint. + using var getResponse = await HttpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + async Task GetFirstNotificationAsync() + { + await foreach (var sseEvent in ReadSseAsync(getResponse.Content)) + { + var notification = JsonSerializer.Deserialize(sseEvent, GetJsonTypeInfo()); + Assert.NotNull(notification); + return notification.Method; + } + + throw new Exception("No notifications received."); + } + + await server.SendNotificationAsync("test-method", TestContext.Current.CancellationToken); + Assert.Equal("test-method", await GetFirstNotificationAsync()); + } + + [Fact] + public async Task SecondGetRequests_IsRejected_AsBadRequest() + { + await StartAsync(); + + await CallInitializeAndValidateAsync(); + using var getResponse1 = await HttpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + using var getResponse2 = await HttpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.OK, getResponse1.StatusCode); + Assert.Equal(HttpStatusCode.BadRequest, getResponse2.StatusCode); + } + + [Fact] + public async Task DeleteRequest_CompletesSession_WhichIsNoLongerFound() + { + await StartAsync(); + + await CallInitializeAndValidateAsync(); + await CallEchoAndValidateAsync(); + await HttpClient.DeleteAsync("", TestContext.Current.CancellationToken); + + using var response = await HttpClient.PostAsync("", JsonContent(EchoRequest), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + } + + [Fact] + public async Task DeleteRequest_CompletesSession_WhichCancelsLongRunningToolCalls() + { + await StartAsync(); + + await CallInitializeAndValidateAsync(); + + Task CallLongRunningToolAsync() => + HttpClient.PostAsync("", JsonContent(CallTool("long-running")), TestContext.Current.CancellationToken); + + var longRunningToolTasks = new Task[10]; + for (int i = 0; i < longRunningToolTasks.Length; i++) + { + longRunningToolTasks[i] = CallLongRunningToolAsync(); + Assert.False(longRunningToolTasks[i].IsCompleted); + } + await HttpClient.DeleteAsync("", TestContext.Current.CancellationToken); + + // Currently, the OCE thrown by the canceled session is unhandled and turned into a 500 error by Kestrel. + // The spec suggests sending CancelledNotifications. That would be good, but we can do that later. + // For now, the important thing is that request completes without indicating success. + await Task.WhenAll(longRunningToolTasks); + foreach (var task in longRunningToolTasks) + { + var response = await task; + Assert.False(response.IsSuccessStatusCode); + } + } + + [Fact] + public async Task Progress_IsReported_InSameSseResponseAsRpcResponse() + { + await StartAsync(); + + await CallInitializeAndValidateAsync(); + + using var response = await HttpClient.PostAsync("", JsonContent(CallToolWithProgressToken("progress")), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + var currentSseItem = 0; + await foreach (var sseEvent in ReadSseAsync(response.Content)) + { + currentSseItem++; + + if (currentSseItem <= 10) + { + var notification = JsonSerializer.Deserialize(sseEvent, GetJsonTypeInfo()); + var progressNotification = AssertType(notification?.Params); + Assert.Equal($"Progress {currentSseItem - 1}", progressNotification.Progress.Message); + } + else + { + var rpcResponse = JsonSerializer.Deserialize(sseEvent, GetJsonTypeInfo()); + var callToolResponse = AssertType(rpcResponse?.Result); + var callToolContent = Assert.Single(callToolResponse.Content); + Assert.Equal("text", callToolContent.Type); + Assert.Equal("done", callToolContent.Text); + } + } + + Assert.Equal(11, currentSseItem); + } + + private void AddDefaultHttpClientRequestHeaders() + { + HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json")); + HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); + } + + private static StringContent JsonContent(string json) => new StringContent(json, Encoding.UTF8, "application/json"); + private static JsonTypeInfo GetJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); + + private static T AssertType(JsonNode? jsonNode) + { + var type = JsonSerializer.Deserialize(jsonNode, GetJsonTypeInfo()); + Assert.NotNull(type); + return type; + } + + private static async IAsyncEnumerable ReadSseAsync(HttpContent responseContent) + { + var responseStream = await responseContent.ReadAsStreamAsync(TestContext.Current.CancellationToken); + await foreach (var sseItem in SseParser.Create(responseStream).EnumerateAsync(TestContext.Current.CancellationToken)) + { + Assert.Equal("message", sseItem.EventType); + yield return sseItem.Data; + } + } + + private static async Task AssertSingleSseResponseAsync(HttpResponseMessage response) + { + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal("text/event-stream", response.Content.Headers.ContentType?.MediaType); + + var sseItem = Assert.Single(await ReadSseAsync(response.Content).ToListAsync(TestContext.Current.CancellationToken)); + var jsonRpcResponse = JsonSerializer.Deserialize(sseItem, GetJsonTypeInfo()); + + Assert.NotNull(jsonRpcResponse); + return jsonRpcResponse; + } + + private static string InitializeRequest => """ + {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"IntegrationTestClient","version":"1.0.0"}}} + """; + + private long _lastRequestId = 1; + private string EchoRequest + { + get + { + var id = Interlocked.Increment(ref _lastRequestId); + return $$$$""" + {"jsonrpc":"2.0","id":{{{{id}}}},"method":"tools/call","params":{"name":"echo","arguments":{"message":"Hello world! ({{{{id}}}})"}}} + """; + } + } + + private string ProgressNotification(string progress) + { + return $$$""" + {"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"","progress":{{{progress}}}}} + """; + } + + private string Request(string method, string parameters = "{}") + { + var id = Interlocked.Increment(ref _lastRequestId); + return $$""" + {"jsonrpc":"2.0","id":{{id}},"method":"{{method}}","params":{{parameters}}} + """; + } + + private string CallTool(string toolName, string arguments = "{}") => + Request("tools/call", $$""" + {"name":"{{toolName}}","arguments":{{arguments}}} + """); + + private string CallToolWithProgressToken(string toolName, string arguments = "{}") => + Request("tools/call", $$$""" + {"name":"{{{toolName}}}","arguments":{{{arguments}}}, "_meta":{"progressToken": "abc123"}} + """); + + private static InitializeResult AssertServerInfo(JsonRpcResponse rpcResponse) + { + var initializeResult = AssertType(rpcResponse.Result); + Assert.Equal(nameof(StreamableHttpTests), initializeResult.ServerInfo.Name); + Assert.Equal("73", initializeResult.ServerInfo.Version); + return initializeResult; + } + + private static CallToolResponse AssertEchoResponse(JsonRpcResponse rpcResponse) + { + var callToolResponse = AssertType(rpcResponse.Result); + var callToolContent = Assert.Single(callToolResponse.Content); + Assert.Equal("text", callToolContent.Type); + Assert.Equal($"Hello world! ({rpcResponse.Id})", callToolContent.Text); + return callToolResponse; + } + + private async Task CallInitializeAndValidateAsync() + { + using var response = await HttpClient.PostAsync("", JsonContent(InitializeRequest), TestContext.Current.CancellationToken); + var rpcResponse = await AssertSingleSseResponseAsync(response); + AssertServerInfo(rpcResponse); + + var sessionId = Assert.Single(response.Headers.GetValues("mcp-session-id")); + HttpClient.DefaultRequestHeaders.Add("mcp-session-id", sessionId); + } + + private async Task CallEchoAndValidateAsync() + { + using var response = await HttpClient.PostAsync("", JsonContent(EchoRequest), TestContext.Current.CancellationToken); + var rpcResponse = await AssertSingleSseResponseAsync(response); + AssertEchoResponse(rpcResponse); + } + + [McpServerTool(Name = "echo")] + private static async Task EchoAsync(string message) + { + // McpSession.ProcessMessagesAsync() already yields before calling any handlers, but this makes it even + // more explicit that we're not relying on synchronous execution of the tool. + await Task.Yield(); + return message; + } + + [McpServerTool(Name = "long-running")] + private static async Task LongRunningAsync(CancellationToken cancellation) + { + // McpSession.ProcessMessagesAsync() already yields before calling any handlers, but this makes it even + // more explicit that we're not relying on synchronous execution of the tool. + await Task.Delay(Timeout.Infinite, cancellation); + } + + [McpServerTool(Name = "progress")] + public static string Progress(IProgress progress) + { + for (int i = 0; i < 10; i++) + { + progress.Report(new() { Progress = i, Total = 10, Message = $"Progress {i}" }); + } + + return "done"; + } + + [McpServerTool(Name = "throw")] + private static void Throw() + { + throw new Exception(); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs index b10a0d674..597b39de7 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs @@ -27,7 +27,11 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) var connection = _inMemoryTransport.CreateConnection(); return new(connection.ClientStream); }, - }); + }) + { + BaseAddress = new Uri("http://localhost:5000/"), + Timeout = TimeSpan.FromSeconds(10), + }; } public WebApplicationBuilder Builder { get; } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs index 0c9cb4377..a221b8a38 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs @@ -8,19 +8,19 @@ namespace ModelContextProtocol.Tests.Utils; public class TestServerTransport : ITransport { - private readonly Channel _messageChannel; + private readonly Channel _messageChannel; public bool IsConnected { get; set; } - public ChannelReader MessageReader => _messageChannel; + public ChannelReader MessageReader => _messageChannel; - public List SentMessages { get; } = []; + public List SentMessages { get; } = []; - public Action? OnMessageSent { get; set; } + public Action? OnMessageSent { get; set; } public TestServerTransport() { - _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions + _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions { SingleReader = true, SingleWriter = true, @@ -35,7 +35,7 @@ public ValueTask DisposeAsync() return ValueTask.CompletedTask; } - public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { SentMessages.Add(message); if (message is JsonRpcRequest request) @@ -76,7 +76,7 @@ await WriteMessageAsync(new JsonRpcResponse }, cancellationToken); } - private async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + private async Task WriteMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { await _messageChannel.Writer.WriteAsync(message, cancellationToken); } diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index 454222e6b..4b3fe3a97 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -106,11 +106,11 @@ public async Task CreateAsync_WithCapabilitiesOptions(Type transportType) private class NopTransport : ITransport, IClientTransport { - private readonly Channel _channel = Channel.CreateUnbounded(); + private readonly Channel _channel = Channel.CreateUnbounded(); public bool IsConnected => true; - public ChannelReader MessageReader => _channel.Reader; + public ChannelReader MessageReader => _channel.Reader; public Task ConnectAsync(CancellationToken cancellationToken = default) => Task.FromResult(this); @@ -118,7 +118,7 @@ private class NopTransport : ITransport, IClientTransport public string Name => "Test Nop Transport"; - public virtual Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public virtual Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { switch (message) { @@ -148,7 +148,7 @@ private sealed class FailureTransport : NopTransport { public const string ExpectedMessage = "Something failed"; - public override Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { throw new InvalidOperationException(ExpectedMessage); } diff --git a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs index 1d42c3d8a..0d18667e9 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs @@ -33,7 +33,7 @@ public async Task RegistrationsAreRemovedWhenDisposed() } } - Assert.Equal(Iterations, counter); + Assert.Equal(Iterations, counter); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index d7ccb6e15..958fe124d 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -619,7 +619,7 @@ public Task SendRequestAsync(JsonRpcRequest request, Cancellati public McpServerOptions ServerOptions => throw new NotImplementedException(); public IServiceProvider? Services => throw new NotImplementedException(); public LoggingLevel? LoggingLevel => throw new NotImplementedException(); - public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) => + public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) => throw new NotImplementedException(); public Task RunAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 6cfdd06cb..1f74a9565 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -216,6 +216,7 @@ public async Task DisposeAsync_Should_Dispose_Resources() await session.DisposeAsync(); - Assert.False(session.IsConnected); + var transportBase = Assert.IsAssignableFrom(session); + Assert.False(transportBase.IsConnected); } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs index 570708d54..aa01a894d 100644 --- a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs +++ b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs @@ -9,19 +9,19 @@ namespace ModelContextProtocol.Tests.Utils; public class TestServerTransport : ITransport { - private readonly Channel _messageChannel; + private readonly Channel _messageChannel; public bool IsConnected { get; set; } - public ChannelReader MessageReader => _messageChannel; + public ChannelReader MessageReader => _messageChannel; - public List SentMessages { get; } = []; + public List SentMessages { get; } = []; - public Action? OnMessageSent { get; set; } + public Action? OnMessageSent { get; set; } public TestServerTransport() { - _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions + _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions { SingleReader = true, SingleWriter = true, @@ -36,7 +36,7 @@ public ValueTask DisposeAsync() return default; } - public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { SentMessages.Add(message); if (message is JsonRpcRequest request) @@ -77,7 +77,7 @@ await WriteMessageAsync(new JsonRpcResponse }, cancellationToken); } - private async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + private async Task WriteMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { await _messageChannel.Writer.WriteAsync(message, cancellationToken); }