diff --git a/.gitignore b/.gitignore index 3fce64294..5b5c2bcba 100644 --- a/.gitignore +++ b/.gitignore @@ -74,6 +74,14 @@ nCrunchTemp_* *.orig +*.ncrunchsolution + +*.lutconfig + +.NCrunch_ModelContextProtocol/ + +*.ncrunchproject + # Auto-generated documentation docs/_site docs/api \ No newline at end of file diff --git a/Directory.Packages.props b/Directory.Packages.props index 554361cbe..35bb92992 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -15,6 +15,7 @@ + diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs index 1f357d32a..06953962e 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs @@ -1,8 +1,9 @@ -using ModelContextProtocol.Configuration; +using Microsoft.Extensions.DependencyInjection; + +using ModelContextProtocol.Configuration; using ModelContextProtocol.Hosting; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Utils; -using Microsoft.Extensions.DependencyInjection; namespace ModelContextProtocol; @@ -11,6 +12,30 @@ namespace ModelContextProtocol; /// public static partial class McpServerBuilderExtensions { + /// + /// Adds a server transport that uses in memory communication. + /// + /// The builder instance. + public static IMcpServerBuilder WithInMemoryServerTransport(this IMcpServerBuilder builder) + { + Throw.IfNull(builder); + builder.Services.AddSingleton(); + builder.Services.AddSingleton(s => + { + var transport = s.GetRequiredService(); + return transport.ClientTransport; + }); + + builder.Services.AddSingleton(s => + { + var transport = s.GetRequiredService(); + return transport.ServerTransport; + }); + + builder.Services.AddHostedService(); + return builder; + } + /// /// Adds a server transport that uses stdin/stdout for communication. /// diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj index dcee6278f..8caddfe92 100644 --- a/src/ModelContextProtocol/ModelContextProtocol.csproj +++ b/src/ModelContextProtocol/ModelContextProtocol.csproj @@ -14,8 +14,9 @@ - + + diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs new file mode 100644 index 000000000..4a476afdf --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs @@ -0,0 +1,228 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; + +using System.Threading.Channels; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Provides an in-memory implementation of the MCP client transport. +/// +public sealed class InMemoryClientTransport : TransportBase, IClientTransport +{ + private string EndpointName => $"Client (in memory) for ({_serverName})"; + private readonly ILogger _logger; + private readonly string _serverName; + private readonly ChannelWriter _outgoingChannel; + private readonly ChannelReader _incomingChannel; + private CancellationTokenSource? _cancellationTokenSource; + private Task? _readTask; + private readonly SemaphoreSlim _connectLock = new SemaphoreSlim(1, 1); + private volatile bool _disposed; + + /// + /// Gets or sets the server transport this client connects to. + /// + internal InMemoryServerTransport? ServerTransport { get; set; } + + /// + /// Initializes a new instance of the class. + /// + /// The name of the server. + /// Optional logger factory for logging transport operations. + /// Channel for sending messages to the server. + /// Channel for receiving messages from the server. + internal InMemoryClientTransport( + string serverName, + ILoggerFactory? loggerFactory, + ChannelWriter outgoingChannel, + ChannelReader incomingChannel) + : base(loggerFactory) + { + _logger = loggerFactory?.CreateLogger() + ?? NullLogger.Instance; + _serverName = serverName; + _outgoingChannel = outgoingChannel; + _incomingChannel = incomingChannel; + } + + + + /// + public async Task ConnectAsync(CancellationToken cancellationToken = default) + { + await _connectLock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + ThrowIfDisposed(); + + if (IsConnected) + { + _logger.TransportAlreadyConnected(EndpointName); + throw new McpTransportException("Transport is already connected"); + } + + _logger.TransportConnecting(EndpointName); + + try + { + // Start the server if it exists and is not already connected + if (ServerTransport != null && !ServerTransport.IsConnected) + { + await ServerTransport.StartListeningAsync(cancellationToken).ConfigureAwait(false); + } + + _cancellationTokenSource = new CancellationTokenSource(); + _readTask = Task.Run(() => ReadMessagesAsync(_cancellationTokenSource.Token).ConfigureAwait(false), CancellationToken.None); + + SetConnected(true); + } + catch (Exception ex) + { + _logger.TransportConnectFailed(EndpointName, ex); + await CleanupAsync(cancellationToken).ConfigureAwait(false); + throw new McpTransportException("Failed to connect transport", ex); + } + } + finally + { + _connectLock.Release(); + } + } + + /// + public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + ThrowIfDisposed(); + + if (!IsConnected) + { + _logger.TransportNotConnected(EndpointName); + throw new McpTransportException("Transport is not connected"); + } + + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + try + { + _logger.TransportSendingMessage(EndpointName, id); + await _outgoingChannel.WriteAsync(message, cancellationToken).ConfigureAwait(false); + _logger.TransportSentMessage(EndpointName, id); + } + catch (Exception ex) + { + _logger.TransportSendFailed(EndpointName, id, ex); + throw new McpTransportException("Failed to send message", ex); + } + } + + /// + public override async ValueTask DisposeAsync() + { + await CleanupAsync(CancellationToken.None).ConfigureAwait(false); + GC.SuppressFinalize(this); + } + + private async Task ReadMessagesAsync(CancellationToken cancellationToken) + { + try + { + _logger.TransportEnteringReadMessagesLoop(EndpointName); + + await foreach (var message in _incomingChannel.ReadAllAsync(cancellationToken).ConfigureAwait(false)) + { + var id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + _logger.TransportReceivedMessageParsed(EndpointName, id); + + // Write to the base class's message channel that's exposed via MessageReader + await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + + _logger.TransportMessageWritten(EndpointName, id); + } + + _logger.TransportExitingReadMessagesLoop(EndpointName); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + _logger.TransportReadMessagesCancelled(EndpointName); + // Normal shutdown + } + catch (Exception ex) + { + _logger.TransportReadMessagesFailed(EndpointName, ex); + } + } + + private async Task CleanupAsync(CancellationToken cancellationToken) + { + if (_disposed) + { + return; + } + + _disposed = true; + _logger.TransportCleaningUp(EndpointName); + + try + { + if (_cancellationTokenSource != null) + { + await _cancellationTokenSource.CancelAsync().ConfigureAwait(false); + _cancellationTokenSource.Dispose(); + _cancellationTokenSource = null; + } + + if (_readTask != null) + { + try + { + _logger.TransportWaitingForReadTask(EndpointName); + await _readTask.WaitAsync(TimeSpan.FromSeconds(1), cancellationToken).ConfigureAwait(false); + } + catch (TimeoutException) + { + _logger.TransportCleanupReadTaskTimeout(EndpointName); + } + catch (OperationCanceledException) + { + _logger.TransportCleanupReadTaskCancelled(EndpointName); + } + catch (Exception ex) + { + _logger.TransportCleanupReadTaskFailed(EndpointName, ex); + } + finally + { + _readTask = null; + } + } + + _connectLock.Dispose(); + } + finally + { + SetConnected(false); + _logger.TransportCleanedUp(EndpointName); + } + } + + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(InMemoryClientTransport)); + } + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs new file mode 100644 index 000000000..be5af42fa --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs @@ -0,0 +1,214 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; + +using System.Threading.Channels; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Provides an in-memory implementation of the MCP server transport. +/// +public sealed class InMemoryServerTransport : TransportBase, IServerTransport +{ + private string EndpointName => $"Server (in memory) for ({_serverName})"; + private readonly ILogger _logger; + private readonly ChannelReader _incomingChannel; + private readonly ChannelWriter _outgoingChannel; + private CancellationTokenSource? _cancellationTokenSource; + private Task? _readTask; + private SemaphoreSlim _startLock = new SemaphoreSlim(1, 1); + private volatile bool _disposed; + private readonly string _serverName; + + /// + /// Initializes a new instance of the class. + /// + /// The name of the server. + /// Optional logger factory for logging transport operations. + /// Channel for receiving messages from the client. + /// Channel for sending messages to the client. + internal InMemoryServerTransport( + string serverName, + ILoggerFactory? loggerFactory, + ChannelReader incomingChannel, + ChannelWriter outgoingChannel) + : base(loggerFactory) + { + _logger = loggerFactory?.CreateLogger() + ?? NullLogger.Instance; + _incomingChannel = incomingChannel; + _outgoingChannel = outgoingChannel; + _serverName = serverName; + } + + /// + public async Task StartListeningAsync(CancellationToken cancellationToken = default) + { + await _startLock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + ThrowIfDisposed(); + + if (IsConnected) + { + return; + } + + _logger.TransportConnecting(EndpointName); + + try + { + _cancellationTokenSource = new CancellationTokenSource(); + _readTask = Task.Run(() => ReadMessagesAsync(_cancellationTokenSource.Token).ConfigureAwait(false), CancellationToken.None); + + SetConnected(true); + } + catch (Exception ex) + { + _logger.TransportConnectFailed(EndpointName, ex); + await CleanupAsync(cancellationToken).ConfigureAwait(false); + throw new McpTransportException("Failed to connect transport", ex); + } + } + finally + { + _startLock.Release(); + } + } + + /// + public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + ThrowIfDisposed(); + + if (!IsConnected) + { + _logger.TransportNotConnected(EndpointName); + throw new McpTransportException("Transport is not connected"); + } + + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + try + { + _logger.TransportSendingMessage(EndpointName, id); + await _outgoingChannel.WriteAsync(message, cancellationToken).ConfigureAwait(false); + _logger.TransportSentMessage(EndpointName, id); + } + catch (Exception ex) + { + _logger.TransportSendFailed(EndpointName, id, ex); + throw new McpTransportException("Failed to send message", ex); + } + } + + /// + public override async ValueTask DisposeAsync() + { + await CleanupAsync(CancellationToken.None).ConfigureAwait(false); + GC.SuppressFinalize(this); + } + + private async Task ReadMessagesAsync(CancellationToken cancellationToken) + { + try + { + _logger.TransportEnteringReadMessagesLoop(EndpointName); + + await foreach (var message in _incomingChannel.ReadAllAsync(cancellationToken).ConfigureAwait(false)) + { + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + _logger.TransportReceivedMessageParsed(EndpointName, id); + + // Write to the base class's message channel that's exposed via MessageReader + await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + + _logger.TransportMessageWritten(EndpointName, id); + } + + _logger.TransportExitingReadMessagesLoop(EndpointName); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + _logger.TransportReadMessagesCancelled(EndpointName); + // Normal shutdown + } + catch (Exception ex) + { + _logger.TransportReadMessagesFailed(EndpointName, ex); + } + } + + private async Task CleanupAsync(CancellationToken cancellationToken) + { + if (_disposed) + { + return; + } + + _disposed = true; + _logger.TransportCleaningUp(EndpointName); + + try + { + if (_cancellationTokenSource != null) + { + await _cancellationTokenSource.CancelAsync().ConfigureAwait(false); + _cancellationTokenSource.Dispose(); + _cancellationTokenSource = null; + } + + if (_readTask != null) + { + try + { + _logger.TransportWaitingForReadTask(EndpointName); + await _readTask.WaitAsync(TimeSpan.FromSeconds(1), cancellationToken).ConfigureAwait(false); + } + catch (TimeoutException) + { + _logger.TransportCleanupReadTaskTimeout(EndpointName); + } + catch (OperationCanceledException) + { + _logger.TransportCleanupReadTaskCancelled(EndpointName); + } + catch (Exception ex) + { + _logger.TransportCleanupReadTaskFailed(EndpointName, ex); + } + finally + { + _readTask = null; + } + } + + _startLock.Dispose(); + } + finally + { + SetConnected(false); + _logger.TransportCleanedUp(EndpointName); + } + } + + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(InMemoryServerTransport)); + } + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs new file mode 100644 index 000000000..3119b8680 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs @@ -0,0 +1,121 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Server; +using ModelContextProtocol.Utils; + +using System.Threading.Channels; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Factory that creates linked in-memory client and server transports for testing purposes. +/// +public sealed class InMemoryTransport +{ + /// + /// Initializes a new instance of the class. + /// + /// The server options. + /// Optional logger factory used for logging employed by the transport. + /// is or contains a null name. + + public InMemoryTransport(McpServerOptions serverOptions, ILoggerFactory? loggerFactory = null) + : this(GetServerName(serverOptions), loggerFactory) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The server options. + /// Optional logger factory used for logging employed by the transport. + /// is or contains a null name. + + public InMemoryTransport(IOptions serverOptions, ILoggerFactory? loggerFactory = null) + : this(GetServerName(serverOptions.Value), loggerFactory) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The name of the server. + /// Optional logger factory used for logging employed by the transport. + /// is . + /// + /// + /// By default, no logging is performed. If a is supplied, it must not log + /// to , as that will interfere with the transport's output. + /// + /// + public InMemoryTransport(string serverName, ILoggerFactory? loggerFactory = null) + { + var (clientTransport, serverTransport) = Create(serverName, loggerFactory); + ServerTransport = serverTransport; + ClientTransport = clientTransport; + } + + /// + /// Gets the client transport. + /// + public IClientTransport ClientTransport { get; } + + /// + /// Gets the server transport. + /// + public IServerTransport ServerTransport { get; } + + + private static (InMemoryClientTransport ClientTransport, InMemoryServerTransport ServerTransport) Create( + string serverName, + ILoggerFactory? loggerFactory = null) + { + // Configure client-to-server channel - this will be used for: + // 1. Client's outgoing channel + // 2. Server's MessageReader + var clientToServerChannel = Channel.CreateUnbounded(new UnboundedChannelOptions + { + SingleReader = false, // Both server and the server's MessageReader will read + SingleWriter = true, // Client writes + AllowSynchronousContinuations = true + }); + + // Configure server-to-client channel - this will be used for: + // 1. Server's outgoing channel + // 2. Client's MessageReader + var serverToClientChannel = Channel.CreateUnbounded(new UnboundedChannelOptions + { + SingleReader = false, // Both client and the client's MessageReader will read + SingleWriter = true, // Server writes + AllowSynchronousContinuations = true + }); + + // Create the client and server transports - they directly expose the channels through MessageReader + var serverTransport = new InMemoryServerTransport( + serverName, + loggerFactory, + clientToServerChannel.Reader, // incoming: reads messages from client + serverToClientChannel.Writer); // outgoing: writes messages to client + + var clientTransport = new InMemoryClientTransport( + serverName, + loggerFactory, + clientToServerChannel.Writer, // outgoing: writes messages to server + serverToClientChannel.Reader); // incoming: reads messages from server + + // Link the transports together + clientTransport.ServerTransport = serverTransport; + + return (clientTransport, serverTransport); + } + + private static string GetServerName(McpServerOptions serverOptions) + { + Throw.IfNull(serverOptions); + Throw.IfNull(serverOptions.ServerInfo); + + return serverOptions.ServerInfo.Name; + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs b/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs index 02de0a6b5..8a8c9f569 100644 --- a/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs +++ b/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs @@ -14,4 +14,9 @@ public static class TransportTypes /// The name of the ServerSideEvents transport. /// public const string Sse = "sse"; + + /// + /// The name of the InMemory transport. + /// + public const string InMemory = "inmemory"; } diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index 7bff56642..8a676259c 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -1,7 +1,13 @@ -using ModelContextProtocol.Protocol.Messages; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; + +using ModelContextProtocol.Client; +using ModelContextProtocol.Configuration; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils; -using Microsoft.Extensions.AI; + using System.Runtime.CompilerServices; using System.Text; @@ -10,6 +16,27 @@ namespace ModelContextProtocol.Server; /// public static class McpServerExtensions { + /// + /// Gets an in-memory client for the server. + /// + /// + /// + /// + /// + public static async Task GetInMemoryClientAsync(this IMcpServer server, CancellationToken cancellationToke = default) + { + var client = await McpClientFactory.CreateAsync( + new McpServerConfig + { + Id = server.ServerOptions.ServerInfo.Name, + Name = server.ServerOptions.ServerInfo.Name, + TransportType = TransportTypes.InMemory, + }, + createTransportFunc: (_, _) => server.Services?.GetRequiredService() ?? throw new InvalidOperationException(), + cancellationToken: cancellationToke); + return client; + } + /// /// Requests to sample an LLM via the client. /// @@ -42,7 +69,7 @@ public static Task RequestSamplingAsync( /// is . /// The client does not support sampling. public static async Task RequestSamplingAsync( - this IMcpServer server, + this IMcpServer server, IEnumerable messages, ChatOptions? options = default, CancellationToken cancellationToken = default) { Throw.IfNull(server); @@ -112,14 +139,14 @@ public static async Task RequestSamplingAsync( } var result = await server.RequestSamplingAsync(new() - { - Messages = samplingMessages, - MaxTokens = options?.MaxOutputTokens, - StopSequences = options?.StopSequences?.ToArray(), - SystemPrompt = systemPrompt?.ToString(), - Temperature = options?.Temperature, - ModelPreferences = modelPreferences, - }, cancellationToken).ConfigureAwait(false); + { + Messages = samplingMessages, + MaxTokens = options?.MaxOutputTokens, + StopSequences = options?.StopSequences?.ToArray(), + SystemPrompt = systemPrompt?.ToString(), + Temperature = options?.Temperature, + ModelPreferences = modelPreferences, + }, cancellationToken).ConfigureAwait(false); return new(new ChatMessage(new(result.Role), [result.Content.ToAIContent()])) { diff --git a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs new file mode 100644 index 000000000..73fa31014 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs @@ -0,0 +1,158 @@ +using Microsoft.Extensions.DependencyInjection; + +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Configuration; +using ModelContextProtocol.Tests.Utils; + +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Transport; + +public class InMemoryTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) +{ + + [Fact] + public async Task SendMessageAsync_Throws_Exception_If_Not_Connected() + { + var transport = new InMemoryTransport("test", LoggerFactory); + var serverTransport = transport.ServerTransport; + var clientTransport = transport.ClientTransport; + + var message = new JsonRpcRequest { Method = "test" }; + + await Assert.ThrowsAsync(() => serverTransport.SendMessageAsync(message, TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync(() => clientTransport.SendMessageAsync(message, TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task DisposeAsync_Should_Dispose_Resources() + { + var transport = new InMemoryTransport("test", LoggerFactory); + var serverTransport = transport.ServerTransport; + var clientTransport = transport.ClientTransport; + + await serverTransport.DisposeAsync(); + await clientTransport.DisposeAsync(); + + Assert.False(serverTransport.IsConnected); + Assert.False(clientTransport.IsConnected); + } + + [Fact] + public async Task TransportPair_Should_Create_Valid_Transports() + { + var transport = new InMemoryTransport("test", LoggerFactory); + var serverTransport = transport.ServerTransport; + var clientTransport = transport.ClientTransport; + + Assert.NotNull(clientTransport); + Assert.NotNull(serverTransport); + Assert.False(clientTransport.IsConnected); + Assert.False(serverTransport.IsConnected); + + await clientTransport.DisposeAsync(); + await serverTransport.DisposeAsync(); + } + + + [Fact] + public async Task Message_Should_Flow_From_Client_To_Server() + { + var transport = new InMemoryTransport("test", LoggerFactory); + var serverTransport = transport.ServerTransport; + var clientTransport = transport.ClientTransport; + + await clientTransport.ConnectAsync(TestContext.Current.CancellationToken); + + var message = new JsonRpcRequest + { + Method = "test", + Id = RequestId.FromNumber(123), + Params = new Dictionary { ["text"] = "Hello, World!" } + }; + + + await clientTransport.SendMessageAsync(message, TestContext.Current.CancellationToken); + await Task.Delay(2, TestContext.Current.CancellationToken); + + + Assert.True(serverTransport.MessageReader.TryRead(out var receivedMessage)); + Assert.NotNull(receivedMessage); + Assert.IsType(receivedMessage); + + var request = (JsonRpcRequest)receivedMessage; + Assert.Equal(123, request.Id.AsNumber); + Assert.Equal("test", request.Method); + + var requestParams = (Dictionary)request.Params!; + Assert.Equal("Hello, World!", requestParams["text"]); + + // Cleanup + await clientTransport.DisposeAsync(); + await serverTransport.DisposeAsync(); + } + + [Fact] + public async Task Message_Should_Flow_From_Server_To_Client() + { + var transport = new InMemoryTransport("test", LoggerFactory); + var serverTransport = transport.ServerTransport; + var clientTransport = transport.ClientTransport; + + await clientTransport.ConnectAsync(TestContext.Current.CancellationToken); + + var message = new JsonRpcResponse + { + Id = RequestId.FromNumber(456), + Result = new Dictionary { ["text"] = "Response from server" } + }; + + + await serverTransport.SendMessageAsync(message, TestContext.Current.CancellationToken); + await Task.Delay(2, TestContext.Current.CancellationToken); + + Assert.True(clientTransport.MessageReader.TryRead(out var receivedMessage)); + Assert.NotNull(receivedMessage); + Assert.IsType(receivedMessage); + + var response = (JsonRpcResponse)receivedMessage; + Assert.Equal(456, response.Id.AsNumber); + + var responseResult = (Dictionary)response.Result!; + Assert.Equal("Response from server", responseResult["text"]); + + // Cleanup + await clientTransport.DisposeAsync(); + await serverTransport.DisposeAsync(); + } + + [Fact] + public async Task Can_List_Registered_Tools() + { + ServiceCollection sc = new(); + var builder = sc.AddMcpServer().WithTools().WithInMemoryServerTransport(); + var server = sc.BuildServiceProvider().GetRequiredService(); + await server.StartAsync(TestContext.Current.CancellationToken); + + IMcpClient client = await server.GetInMemoryClientAsync(TestContext.Current.CancellationToken); + + + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + Assert.Equal(11, tools.Count); + + McpClientTool echoTool = tools.First(t => t.Name == "Echo"); + Assert.Equal("Echo", echoTool.Name); + Assert.Equal("Echoes the input back to the client.", echoTool.Description); + Assert.Equal("object", echoTool.JsonSchema.GetProperty("type").GetString()); + Assert.Equal(JsonValueKind.Object, echoTool.JsonSchema.GetProperty("properties").GetProperty("message").ValueKind); + Assert.Equal("the echoes message", echoTool.JsonSchema.GetProperty("properties").GetProperty("message").GetProperty("description").GetString()); + Assert.Equal(1, echoTool.JsonSchema.GetProperty("required").GetArrayLength()); + + McpClientTool doubleEchoTool = tools.First(t => t.Name == "double_echo"); + Assert.Equal("double_echo", doubleEchoTool.Name); + Assert.Equal("Echoes the input back to the client.", doubleEchoTool.Description); + } +} diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 94ae62728..4c4941ed7 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -4,6 +4,7 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.Utils.Json; + using System.IO.Pipelines; using System.Text; using System.Text.Json; @@ -70,12 +71,12 @@ public async Task SendMessageAsync_Should_Send_Message() new Pipe().Reader.AsStream(), output, LoggerFactory); - + await transport.StartListeningAsync(TestContext.Current.CancellationToken); - + // Ensure transport is fully initialized await Task.Delay(100, TestContext.Current.CancellationToken); - + // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); @@ -124,12 +125,12 @@ public async Task ReadMessagesAsync_Should_Read_Messages() input, Stream.Null, LoggerFactory); - + await transport.StartListeningAsync(TestContext.Current.CancellationToken); - + // Ensure transport is fully initialized await Task.Delay(100, TestContext.Current.CancellationToken); - + // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); @@ -163,24 +164,24 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() using var output = new MemoryStream(); await using var transport = new StdioServerTransport( - _serverOptions.ServerInfo.Name, + _serverOptions.ServerInfo.Name, new Pipe().Reader.AsStream(), - output, + output, LoggerFactory); - + await transport.StartListeningAsync(TestContext.Current.CancellationToken); - + // Ensure transport is fully initialized await Task.Delay(100, TestContext.Current.CancellationToken); - + // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); // Test 1: Chinese characters (BMP Unicode) var chineseText = "上下文伺服器"; // "Context Server" in Chinese - var chineseMessage = new JsonRpcRequest - { - Method = "test", + var chineseMessage = new JsonRpcRequest + { + Method = "test", Id = RequestId.FromNumber(44), Params = new Dictionary { @@ -191,18 +192,18 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() // Clear output and send message output.SetLength(0); await transport.SendMessageAsync(chineseMessage, TestContext.Current.CancellationToken); - + // Verify Chinese characters preserved but encoded var chineseResult = Encoding.UTF8.GetString(output.ToArray()).Trim(); var expectedChinese = JsonSerializer.Serialize(chineseMessage, McpJsonUtilities.DefaultOptions); Assert.Equal(expectedChinese, chineseResult); Assert.Contains(JsonSerializer.Serialize(chineseText), chineseResult); - + // Test 2: Emoji (non-BMP Unicode using surrogate pairs) var emojiText = "🔍 🚀 👍"; // Magnifying glass, rocket, thumbs up - var emojiMessage = new JsonRpcRequest - { - Method = "test", + var emojiMessage = new JsonRpcRequest + { + Method = "test", Id = RequestId.FromNumber(45), Params = new Dictionary { @@ -213,23 +214,23 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() // Clear output and send message output.SetLength(0); await transport.SendMessageAsync(emojiMessage, TestContext.Current.CancellationToken); - + // Verify emoji preserved - might be as either direct characters or escape sequences var emojiResult = Encoding.UTF8.GetString(output.ToArray()).Trim(); var expectedEmoji = JsonSerializer.Serialize(emojiMessage, McpJsonUtilities.DefaultOptions); Assert.Equal(expectedEmoji, emojiResult); - + // Verify surrogate pairs in different possible formats // Magnifying glass emoji: 🔍 (U+1F50D) - bool magnifyingGlassFound = - emojiResult.Contains("🔍") || + bool magnifyingGlassFound = + emojiResult.Contains("🔍") || emojiResult.Contains("\\ud83d\\udd0d", StringComparison.OrdinalIgnoreCase); - + // Rocket emoji: 🚀 (U+1F680) - bool rocketFound = - emojiResult.Contains("🚀") || + bool rocketFound = + emojiResult.Contains("🚀") || emojiResult.Contains("\\ud83d\\ude80", StringComparison.OrdinalIgnoreCase); - + Assert.True(magnifyingGlassFound, "Magnifying glass emoji not found in result"); Assert.True(rocketFound, "Rocket emoji not found in result"); }