diff --git a/src/ModelContextProtocol.Core/McpErrorCode.cs b/src/ModelContextProtocol.Core/McpErrorCode.cs index f6cf4f51..096e830b 100644 --- a/src/ModelContextProtocol.Core/McpErrorCode.cs +++ b/src/ModelContextProtocol.Core/McpErrorCode.cs @@ -46,4 +46,12 @@ public enum McpErrorCode /// This error is used when the endpoint encounters an unexpected condition that prevents it from fulfilling the request. /// InternalError = -32603, + + /// +    /// Indicates that the request was cancelled by the client. +    /// +    /// +    /// This error is returned when the CancellationToken passed with the request is cancelled before processing completes. +    /// +    RequestCancelled = -32800, } diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs index fcd7980d..a5782872 100644 --- a/src/ModelContextProtocol.Core/McpSessionHandler.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -170,7 +170,7 @@ async Task ProcessMessageAsync() await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false); } catch (Exception ex) - { + { // Only send responses for request errors that aren't user-initiated cancellation. bool isUserCancellation = ex is OperationCanceledException && @@ -301,8 +301,35 @@ private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken } } + /// + /// Handles inbound JSON-RPC notifications. Special-cases $/cancelRequest + /// to cancel the exact in-flight request, and also supports the SDK's custom + /// for backwards compatibility. + /// private async Task HandleNotification(JsonRpcNotification notification, CancellationToken cancellationToken) { + // Handle JSON-RPC native cancellation: $/cancelRequest + if (notification.Method == NotificationMethods.JsonRpcCancelRequest) + { + try + { + if (TryGetJsonRpcIdFromCancelParams(notification.Params, out var reqId) && + _handlingRequests.TryGetValue(reqId, out var cts)) + { + // Request-specific CTS → cancel the in-flight handler + await cts.CancelAsync().ConfigureAwait(false); + LogRequestCanceled(EndpointName, reqId, reason: "jsonrpc/$/cancelRequest"); + } + } + catch + { + // Per spec, invalid cancel messages should be ignored. + } + + // We do not forward $/cancelRequest to user handlers. + return; + } + // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.) if (notification.Method == NotificationMethods.CancelledNotification) { @@ -567,6 +594,44 @@ private Task SendToRelatedTransportAsync(JsonRpcMessage message, CancellationTok } } + /// + /// Parses the id field from a $/cancelRequest notification's params. + /// Returns only when the id is a valid JSON-RPC request id + /// (string or number). + /// + private static bool TryGetJsonRpcIdFromCancelParams(JsonNode? notificationParams, out RequestId id) + { + id = default; + + if (notificationParams is not JsonObject obj) + return false; + + if (!obj.TryGetPropertyValue("id", out var idNode) || idNode is null) + return false; + + if (idNode.GetValueKind() == System.Text.Json.JsonValueKind.String) + { + id = new RequestId(idNode.GetValue()); + return true; + } + + if (idNode.GetValueKind() == System.Text.Json.JsonValueKind.Number) + { + try + { + var n = idNode.GetValue(); + id = new RequestId(n); + return true; + } + catch + { + return false; + } + } + + return false; + } + private static string CreateActivityName(string method) => method; private static string GetMethodName(JsonRpcMessage message) => diff --git a/src/ModelContextProtocol.Core/Protocol/NotificationMethods.cs b/src/ModelContextProtocol.Core/Protocol/NotificationMethods.cs index 30b7d68a..20a5641e 100644 --- a/src/ModelContextProtocol.Core/Protocol/NotificationMethods.cs +++ b/src/ModelContextProtocol.Core/Protocol/NotificationMethods.cs @@ -131,4 +131,13 @@ public static class NotificationMethods /// /// public const string CancelledNotification = "notifications/cancelled"; + + /// + /// JSON-RPC core cancellation method name ($/cancelRequest). + /// + /// + /// Carries a single id field (string or number) identifying the in-flight + /// request that should be cancelled. + /// + public const string JsonRpcCancelRequest = "$/cancelRequest"; } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/IMcpToolWithTimeout.cs b/src/ModelContextProtocol.Core/Server/IMcpToolWithTimeout.cs new file mode 100644 index 00000000..7f0c4f34 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/IMcpToolWithTimeout.cs @@ -0,0 +1,17 @@ +namespace ModelContextProtocol.Server; + +/// +/// Optional contract for tools that expose a per-tool execution timeout. +/// +/// +/// When specified, this value overrides the server-level +/// for this tool only. +/// +public interface IMcpToolWithTimeout +{ + /// + /// Gets the per-tool timeout. When , the server's + /// default applies (if any). + /// + TimeSpan? Timeout { get; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs index 41408c22..deb24508 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs @@ -2,7 +2,9 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; +using System.Collections.Concurrent; using System.Runtime.CompilerServices; +using System.Text.Json.Nodes; using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.Server; @@ -508,6 +510,12 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals McpJsonUtilities.JsonContext.Default.GetPromptResult); } + /// + /// Wires up tools capability: listing, invocation, DI-provided collections, + /// and the filter pipeline. Invocation enforces per-tool timeouts (when + /// is set) or falls back to + /// when present. + /// private void ConfigureTools(McpServerOptions options) { var listToolsHandler = options.Handlers.ListToolsHandler; @@ -578,8 +586,20 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) request.MatchedPrimitive = tool; } + TimeSpan? effectiveTimeout = null; + try { + // Determine effective timeout: per-tool overrides server default + effectiveTimeout = (request.MatchedPrimitive as IMcpToolWithTimeout)?.Timeout + ?? ServerOptions.DefaultToolTimeout; + + if (effectiveTimeout is { } ts) + { + return await RunWithTimeoutAsync(ts, request, cancellationToken, handler); + } + + // If no timeout is configured, use the original request cancellation token. return await handler(request, cancellationToken); } catch (Exception e) when (e is not OperationCanceledException and not McpProtocolException) @@ -613,6 +633,59 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) McpJsonUtilities.JsonContext.Default.CallToolResult); } + /// + /// Executes with a hard server-side timeout. If the timeout elapses, + /// returns a with IsError=true and machine-readable metadata + /// (Meta.IsTimeout=true, Meta.TimeoutMs). Client-initiated cancellations are not + /// handled here; they are rethrown to be processed by the JSON-RPC layer. + /// + /// Must be greater than . + /// The request context. + /// Outer cancellation (client/network) token. + /// The underlying handler to invoke. + private static async Task RunWithTimeoutAsync( + TimeSpan timeout, + RequestContext request, + CancellationToken requestCancellationToken, + McpRequestHandler next) + { + if (timeout <= TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException(nameof(timeout), "Timeout must be greater than zero."); + } + + using var timeoutCts = new CancellationTokenSource(timeout); + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(requestCancellationToken, timeoutCts.Token); + + try + { + return await next(request, linkedCts.Token).ConfigureAwait(false); + } + // Definitive server-side timeout: the timeout token fired, while the outer (client) token did not. + catch (OperationCanceledException) when (timeoutCts.IsCancellationRequested && !requestCancellationToken.IsCancellationRequested) + { + var ms = (int)Math.Round(timeout.TotalMilliseconds, MidpointRounding.AwayFromZero); + + return new CallToolResult + { + IsError = true, + Meta = new System.Text.Json.Nodes.JsonObject + { + ["IsTimeout"] = true, + ["TimeoutMs"] = ms, + }, + Content = + [ + new TextContentBlock + { + Text = $"Tool '{request.Params?.Name ?? ""}' timed out after {ms}ms." + } + ], + }; + } + } + + private void ConfigureLogging(McpServerOptions options) { // We don't require that the handler be provided, as we always store the provided log level to the server. diff --git a/src/ModelContextProtocol.Core/Server/McpServerOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerOptions.cs index 7b915b94..2af4e691 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerOptions.cs @@ -93,12 +93,12 @@ public sealed class McpServerOptions /// /// Gets or sets the container of handlers used by the server for processing protocol messages. /// - public McpServerHandlers Handlers - { + public McpServerHandlers Handlers + { get => field ??= new(); set - { - Throw.IfNull(value); + { + Throw.IfNull(value); field = value; } } @@ -166,4 +166,15 @@ public McpServerHandlers Handlers /// /// public int MaxSamplingOutputTokens { get; set; } = 1000; + + /// + /// Gets or sets the default timeout applied to tool invocations. + /// + /// + /// When set, the server enforces this timeout for all tools that do not define + /// their own timeout. Tools implementing can + /// override this value on a per-tool basis. When , no + /// server-enforced timeout is applied. + /// + public TimeSpan? DefaultToolTimeout { get; set; } } diff --git a/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs b/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs index 9e71e0ea..0cc93742 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs @@ -177,10 +177,10 @@ public McpServerToolAttribute() /// The default is . /// /// - public bool Destructive + public bool Destructive { - get => _destructive ?? DestructiveDefault; - set => _destructive = value; + get => _destructive ?? DestructiveDefault; + set => _destructive = value; } /// @@ -195,10 +195,10 @@ public bool Destructive /// The default is . /// /// - public bool Idempotent + public bool Idempotent { get => _idempotent ?? IdempotentDefault; - set => _idempotent = value; + set => _idempotent = value; } /// @@ -215,8 +215,8 @@ public bool Idempotent /// public bool OpenWorld { - get => _openWorld ?? OpenWorldDefault; - set => _openWorld = value; + get => _openWorld ?? OpenWorldDefault; + set => _openWorld = value; } /// @@ -235,10 +235,10 @@ public bool OpenWorld /// The default is . /// /// - public bool ReadOnly + public bool ReadOnly { - get => _readOnly ?? ReadOnlyDefault; - set => _readOnly = value; + get => _readOnly ?? ReadOnlyDefault; + set => _readOnly = value; } /// @@ -269,4 +269,10 @@ public bool ReadOnly /// /// public string? IconSource { get; set; } + + /// + /// Optional timeout for this tool in seconds. + /// If null, the global default (if any) applies. + /// + public int? TimeoutSeconds { get; set; } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTimeoutTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTimeoutTests.cs new file mode 100644 index 00000000..58c57ff4 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTimeoutTests.cs @@ -0,0 +1,354 @@ +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using Xunit; + +namespace ModelContextProtocol.Tests.Server; + +// NOTE: Assumes McpServerOptions, McpServer, TestServerTransport, RequestMethods, +// CallToolRequestParams, CallToolResult, JsonRpcMessage, JsonRpcResponse, +// JsonRpcError, McpErrorCode, McpServerTool, Tool, RequestContext, +// TextContentBlock, McpJsonUtilities are available from project references. + +/// +/// A simple test tool that simulates slow work. Used to validate timeout enforcement paths. +/// +public class SlowTool : McpServerTool, IMcpToolWithTimeout +{ + private readonly TimeSpan _workDuration; + private readonly TimeSpan? _toolTimeout; + + public SlowTool(TimeSpan workDuration, TimeSpan? toolTimeout) + { + _workDuration = workDuration; + _toolTimeout = toolTimeout; + } + + public string Name => ProtocolTool.Name; + + /// + public override Tool ProtocolTool => new() + { + Name = "SlowTool", + Description = "A tool that works very slowly.", + // No input parameters; schema must be a non-null empty object. + InputSchema = JsonDocument.Parse("""{"type": "object", "properties": {}}""").RootElement + }; + + /// + public override IReadOnlyList Metadata => Array.Empty(); + + /// + public TimeSpan? Timeout => _toolTimeout; + + /// + /// Simulates long-running work and cooperates with cancellation. + /// + public override async ValueTask InvokeAsync( + RequestContext requestContext, + CancellationToken cancellationToken = default) + { + // If the server injects a timeout-linked token, this will throw on timeout. + await Task.Delay(_workDuration, cancellationToken); + + return new() + { + IsError = false, // <- explicitly success + Content = + [ + new TextContentBlock + { + Text = $"Done after {_workDuration.TotalMilliseconds}ms." + } + ] + }; + } +} + +/// +/// Tests server-side tool timeout enforcement and client-initiated cancellation +/// against a live in-memory transport. +/// +public class McpServerToolTimeoutTests : LoggedTest +{ + public McpServerToolTimeoutTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { } + + private static McpServerOptions CreateOptions(TimeSpan? defaultTimeout = null) + => new() + { + ProtocolVersion = "2024", + InitializationTimeout = TimeSpan.FromSeconds(30), + DefaultToolTimeout = defaultTimeout + }; + + private static async Task InitializeServerAsync( + TestServerTransport transport, + string? protocolVersion, + CancellationToken ct) + { + var initReqId = new RequestId(Guid.NewGuid().ToString("N")); + var initTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + void OnInit(JsonRpcMessage m) + { + if (m is JsonRpcResponse r && r.Id.ToString() == initReqId.ToString()) + initTcs.TrySetResult(r); + if (m is JsonRpcError e && e.Id.ToString() == initReqId.ToString()) + initTcs.TrySetException(new Xunit.Sdk.XunitException( + $"initialize returned error. Code={e.Error.Code}, Message='{e.Error.Message}'")); + } + + transport.OnMessageSent += OnInit; + try + { + var initParams = JsonSerializer.SerializeToNode(new + { + protocolVersion, + clientInfo = new { name = "ModelContextProtocol.Tests", version = "0.0.0" }, + capabilities = new { } + }, McpJsonUtilities.DefaultOptions); + + await transport.SendMessageAsync(new JsonRpcRequest + { + Method = RequestMethods.Initialize, + Id = initReqId, + Params = initParams + }, CancellationToken.None); + + _ = await initTcs.Task.WaitAsync(ct); + } + finally + { + transport.OnMessageSent -= OnInit; + } + + await transport.SendMessageAsync(new JsonRpcNotification + { + Method = NotificationMethods.InitializedNotification, + Params = null + }, CancellationToken.None); + } + + + + private async Task ExecuteCallToolRequest( + McpServerOptions options, + string toolName, + CancellationToken externalCancellationToken = default) + { + // Early guard: ensure the tool exists in options.ToolCollection (clear failure if not). + if (options?.ToolCollection is null || !options.ToolCollection.Any(t => t.ProtocolTool.Name == toolName)) + throw new Xunit.Sdk.XunitException($"Tool '{toolName}' is not registered in options.ToolCollection."); + + await using var transport = new TestServerTransport(); + await using var server = McpServer.Create(transport, options, LoggerFactory); + + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource( + externalCancellationToken, TestContext.Current.CancellationToken); + + var runTask = server.RunAsync(linkedCts.Token); + + // MCP handshake + await InitializeServerAsync(transport, options.ProtocolVersion, linkedCts.Token); + + var reqId = new RequestId(Guid.NewGuid().ToString("N")); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + void OnReply(JsonRpcMessage m) + { + // Success response + if (m is JsonRpcResponse ok && ok.Id.ToString() == reqId.ToString()) + tcs.TrySetResult(ok); + + // Protocol-level error (e.g., "tool not found", validation failures, etc.) + if (m is JsonRpcError err && err.Id.ToString() == reqId.ToString()) + tcs.TrySetException(new Xunit.Sdk.XunitException( + $"Server returned JsonRpcError for tools/call. Code={err.Error.Code}, Message='{err.Error.Message}'")); + } + + transport.OnMessageSent += OnReply; + + try + { + await transport.SendMessageAsync(new JsonRpcRequest + { + Method = RequestMethods.ToolsCall, + Id = reqId, + Params = JsonSerializer.SerializeToNode( + new CallToolRequestParams { Name = toolName }, + McpJsonUtilities.DefaultOptions) + }, externalCancellationToken); + + // This completes for either success (JsonRpcResponse) or error (JsonRpcError). + var obj = await tcs.Task.WaitAsync(externalCancellationToken); + + var response = (JsonRpcResponse)obj; + + // Deserialize a successful response into CallToolResult + return JsonSerializer.Deserialize( + response.Result, McpJsonUtilities.DefaultOptions)!; + } + finally + { + transport.OnMessageSent -= OnReply; + + // Deterministic shutdown + linkedCts.Cancel(); + await transport.DisposeAsync(); + await Task.WhenAny(runTask, Task.Delay(TimeSpan.FromSeconds(5), CancellationToken.None)); + await server.DisposeAsync(); + } + } + + [Fact] + public async Task CallTool_ShouldSucceed_WhenFinishesWithinToolTimeout() + { + // Arrange: 50ms work, 200ms tool timeout → should succeed. + var tool = new SlowTool(TimeSpan.FromMilliseconds(50), TimeSpan.FromMilliseconds(200)); + var options = CreateOptions(); + options.ToolCollection ??= []; + options.ToolCollection.Add(tool); + + // Act + var result = await ExecuteCallToolRequest(options, tool.Name, TestContext.Current.CancellationToken); + + // Assert + Assert.False(result.IsError, "Tool call should succeed when it finishes within the timeout."); + var contentText = result.Content.OfType().Single().Text; + Assert.Contains("Done after 50ms", contentText); + } + + [Fact] + public async Task CallTool_ShouldReturnError_WhenToolTimeoutIsExceeded() + { + // Arrange: 300ms work, 200ms tool timeout → must time out. + var tool = new SlowTool(TimeSpan.FromMilliseconds(300), TimeSpan.FromMilliseconds(200)); + var options = CreateOptions(); + options.ToolCollection ??= []; + options.ToolCollection.Add(tool); + + // Act + var result = await ExecuteCallToolRequest(options, tool.Name, TestContext.Current.CancellationToken); + + // Assert (functional) + Assert.True(result.IsError, "Tool call should fail with IsError=true due to timeout."); + + // Assert (structural): Meta.IsTimeout must be true + Assert.NotNull(result.Meta); + Assert.True( + result.Meta.TryGetPropertyValue("IsTimeout", out var isTimeoutNode), + "Meta must contain 'IsTimeout' property."); + Assert.NotNull(isTimeoutNode); + Assert.True(isTimeoutNode.GetValue(), "'IsTimeout' must be true."); + } + + [Fact] + public async Task CallTool_ShouldReturnError_WhenServerDefaultTimeoutIsExceeded() + { + // Arrange: no per-tool timeout; server default is 100ms; work is 300ms → must time out. + var tool = new SlowTool(TimeSpan.FromMilliseconds(300), toolTimeout: null); + var options = CreateOptions(defaultTimeout: TimeSpan.FromMilliseconds(100)); + options.ToolCollection ??= []; + options.ToolCollection.Add(tool); + + // Act + var result = await ExecuteCallToolRequest(options, tool.Name, TestContext.Current.CancellationToken); + + // Assert (functional) + Assert.True(result.IsError, "Tool call should fail due to the server's default timeout."); + + // Assert (structural): Meta.IsTimeout must be true + Assert.NotNull(result.Meta); + Assert.True( + result.Meta.TryGetPropertyValue("IsTimeout", out var isTimeoutNode), + "Meta must contain 'IsTimeout' property."); + Assert.NotNull(isTimeoutNode); + Assert.True(isTimeoutNode.GetValue(), "'IsTimeout' must be true."); + } + + [Fact] + public async Task CallTool_ShouldNotRespond_WhenClientCancelsViaJsonRpc() + { + // Arrange: no server/tool timeout; user will cancel via $/cancelRequest. + var tool = new SlowTool(TimeSpan.FromSeconds(10), toolTimeout: null); + var options = CreateOptions(defaultTimeout: null); + options.ToolCollection ??= []; + options.ToolCollection.Add(tool); + + await using var transport = new TestServerTransport(); + await using var server = McpServer.Create(transport, options, LoggerFactory); + + using var serverCts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + var runTask = server.RunAsync(serverCts.Token); + + // handshake + await InitializeServerAsync(transport, options.ProtocolVersion, serverCts.Token); + + var requestId = new RequestId(Guid.NewGuid().ToString("N")); + + var anyReply = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + void OnAnyReply(JsonRpcMessage m) + { + if ((m is JsonRpcResponse r && r.Id.ToString() == requestId.ToString()) || + (m is JsonRpcError e && e.Id.ToString() == requestId.ToString())) + { + anyReply.TrySetResult(m); + } + } + + transport.OnMessageSent += OnAnyReply; + + try + { + // 1) send call + await transport.SendMessageAsync( + new JsonRpcRequest + { + Method = RequestMethods.ToolsCall, + Id = requestId, + Params = JsonSerializer.SerializeToNode( + new CallToolRequestParams { Name = tool.Name }, + McpJsonUtilities.DefaultOptions) + }, + CancellationToken.None); + + await Task.Yield(); + await Task.Delay(200, serverCts.Token); + + // 2) send $/cancelRequest + var cancelParams = JsonSerializer.SerializeToNode(new { id = requestId.ToString() }); + await transport.SendMessageAsync( + new JsonRpcNotification + { + Method = NotificationMethods.JsonRpcCancelRequest, + Params = cancelParams + }, + CancellationToken.None); + + // 3) ensure that NO response is emitted for this cancellation + try + { + var _ = await anyReply.Task.WaitAsync(TimeSpan.FromSeconds(2), CancellationToken.None); + throw new Xunit.Sdk.XunitException("Server responded to user-initiated cancellation. Expected: no response."); + } + catch (TimeoutException) + { + // expected → silent cancel path + } + } + finally + { + transport.OnMessageSent -= OnAnyReply; + serverCts.Cancel(); + await transport.DisposeAsync(); + await Task.WhenAny(runTask, Task.Delay(TimeSpan.FromSeconds(5), CancellationToken.None)); + await server.DisposeAsync(); + } + } + +}