diff --git a/.gitignore b/.gitignore
index 3fce64294..5b5c2bcba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -74,6 +74,14 @@ nCrunchTemp_*
*.orig
+*.ncrunchsolution
+
+*.lutconfig
+
+.NCrunch_ModelContextProtocol/
+
+*.ncrunchproject
+
# Auto-generated documentation
docs/_site
docs/api
\ No newline at end of file
diff --git a/Directory.Packages.props b/Directory.Packages.props
index 554361cbe..35bb92992 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -15,6 +15,7 @@
+
diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs
index 1f357d32a..06953962e 100644
--- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs
+++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs
@@ -1,8 +1,9 @@
-using ModelContextProtocol.Configuration;
+using Microsoft.Extensions.DependencyInjection;
+
+using ModelContextProtocol.Configuration;
using ModelContextProtocol.Hosting;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Utils;
-using Microsoft.Extensions.DependencyInjection;
namespace ModelContextProtocol;
@@ -11,6 +12,30 @@ namespace ModelContextProtocol;
///
public static partial class McpServerBuilderExtensions
{
+ ///
+ /// Adds a server transport that uses in memory communication.
+ ///
+ /// The builder instance.
+ public static IMcpServerBuilder WithInMemoryServerTransport(this IMcpServerBuilder builder)
+ {
+ Throw.IfNull(builder);
+ builder.Services.AddSingleton();
+ builder.Services.AddSingleton(s =>
+ {
+ var transport = s.GetRequiredService();
+ return transport.ClientTransport;
+ });
+
+ builder.Services.AddSingleton(s =>
+ {
+ var transport = s.GetRequiredService();
+ return transport.ServerTransport;
+ });
+
+ builder.Services.AddHostedService();
+ return builder;
+ }
+
///
/// Adds a server transport that uses stdin/stdout for communication.
///
diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj
index dcee6278f..8caddfe92 100644
--- a/src/ModelContextProtocol/ModelContextProtocol.csproj
+++ b/src/ModelContextProtocol/ModelContextProtocol.csproj
@@ -14,8 +14,9 @@
-
+
+
diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs
new file mode 100644
index 000000000..4a476afdf
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs
@@ -0,0 +1,228 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+
+using ModelContextProtocol.Logging;
+using ModelContextProtocol.Protocol.Messages;
+
+using System.Threading.Channels;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+///
+/// Provides an in-memory implementation of the MCP client transport.
+///
+public sealed class InMemoryClientTransport : TransportBase, IClientTransport
+{
+ private string EndpointName => $"Client (in memory) for ({_serverName})";
+ private readonly ILogger _logger;
+ private readonly string _serverName;
+ private readonly ChannelWriter _outgoingChannel;
+ private readonly ChannelReader _incomingChannel;
+ private CancellationTokenSource? _cancellationTokenSource;
+ private Task? _readTask;
+ private readonly SemaphoreSlim _connectLock = new SemaphoreSlim(1, 1);
+ private volatile bool _disposed;
+
+ ///
+ /// Gets or sets the server transport this client connects to.
+ ///
+ internal InMemoryServerTransport? ServerTransport { get; set; }
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The name of the server.
+ /// Optional logger factory for logging transport operations.
+ /// Channel for sending messages to the server.
+ /// Channel for receiving messages from the server.
+ internal InMemoryClientTransport(
+ string serverName,
+ ILoggerFactory? loggerFactory,
+ ChannelWriter outgoingChannel,
+ ChannelReader incomingChannel)
+ : base(loggerFactory)
+ {
+ _logger = loggerFactory?.CreateLogger()
+ ?? NullLogger.Instance;
+ _serverName = serverName;
+ _outgoingChannel = outgoingChannel;
+ _incomingChannel = incomingChannel;
+ }
+
+
+
+ ///
+ public async Task ConnectAsync(CancellationToken cancellationToken = default)
+ {
+ await _connectLock.WaitAsync(cancellationToken).ConfigureAwait(false);
+ try
+ {
+ ThrowIfDisposed();
+
+ if (IsConnected)
+ {
+ _logger.TransportAlreadyConnected(EndpointName);
+ throw new McpTransportException("Transport is already connected");
+ }
+
+ _logger.TransportConnecting(EndpointName);
+
+ try
+ {
+ // Start the server if it exists and is not already connected
+ if (ServerTransport != null && !ServerTransport.IsConnected)
+ {
+ await ServerTransport.StartListeningAsync(cancellationToken).ConfigureAwait(false);
+ }
+
+ _cancellationTokenSource = new CancellationTokenSource();
+ _readTask = Task.Run(() => ReadMessagesAsync(_cancellationTokenSource.Token).ConfigureAwait(false), CancellationToken.None);
+
+ SetConnected(true);
+ }
+ catch (Exception ex)
+ {
+ _logger.TransportConnectFailed(EndpointName, ex);
+ await CleanupAsync(cancellationToken).ConfigureAwait(false);
+ throw new McpTransportException("Failed to connect transport", ex);
+ }
+ }
+ finally
+ {
+ _connectLock.Release();
+ }
+ }
+
+ ///
+ public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ ThrowIfDisposed();
+
+ if (!IsConnected)
+ {
+ _logger.TransportNotConnected(EndpointName);
+ throw new McpTransportException("Transport is not connected");
+ }
+
+ string id = "(no id)";
+ if (message is IJsonRpcMessageWithId messageWithId)
+ {
+ id = messageWithId.Id.ToString();
+ }
+
+ try
+ {
+ _logger.TransportSendingMessage(EndpointName, id);
+ await _outgoingChannel.WriteAsync(message, cancellationToken).ConfigureAwait(false);
+ _logger.TransportSentMessage(EndpointName, id);
+ }
+ catch (Exception ex)
+ {
+ _logger.TransportSendFailed(EndpointName, id, ex);
+ throw new McpTransportException("Failed to send message", ex);
+ }
+ }
+
+ ///
+ public override async ValueTask DisposeAsync()
+ {
+ await CleanupAsync(CancellationToken.None).ConfigureAwait(false);
+ GC.SuppressFinalize(this);
+ }
+
+ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
+ {
+ try
+ {
+ _logger.TransportEnteringReadMessagesLoop(EndpointName);
+
+ await foreach (var message in _incomingChannel.ReadAllAsync(cancellationToken).ConfigureAwait(false))
+ {
+ var id = "(no id)";
+ if (message is IJsonRpcMessageWithId messageWithId)
+ {
+ id = messageWithId.Id.ToString();
+ }
+
+ _logger.TransportReceivedMessageParsed(EndpointName, id);
+
+ // Write to the base class's message channel that's exposed via MessageReader
+ await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
+
+ _logger.TransportMessageWritten(EndpointName, id);
+ }
+
+ _logger.TransportExitingReadMessagesLoop(EndpointName);
+ }
+ catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
+ {
+ _logger.TransportReadMessagesCancelled(EndpointName);
+ // Normal shutdown
+ }
+ catch (Exception ex)
+ {
+ _logger.TransportReadMessagesFailed(EndpointName, ex);
+ }
+ }
+
+ private async Task CleanupAsync(CancellationToken cancellationToken)
+ {
+ if (_disposed)
+ {
+ return;
+ }
+
+ _disposed = true;
+ _logger.TransportCleaningUp(EndpointName);
+
+ try
+ {
+ if (_cancellationTokenSource != null)
+ {
+ await _cancellationTokenSource.CancelAsync().ConfigureAwait(false);
+ _cancellationTokenSource.Dispose();
+ _cancellationTokenSource = null;
+ }
+
+ if (_readTask != null)
+ {
+ try
+ {
+ _logger.TransportWaitingForReadTask(EndpointName);
+ await _readTask.WaitAsync(TimeSpan.FromSeconds(1), cancellationToken).ConfigureAwait(false);
+ }
+ catch (TimeoutException)
+ {
+ _logger.TransportCleanupReadTaskTimeout(EndpointName);
+ }
+ catch (OperationCanceledException)
+ {
+ _logger.TransportCleanupReadTaskCancelled(EndpointName);
+ }
+ catch (Exception ex)
+ {
+ _logger.TransportCleanupReadTaskFailed(EndpointName, ex);
+ }
+ finally
+ {
+ _readTask = null;
+ }
+ }
+
+ _connectLock.Dispose();
+ }
+ finally
+ {
+ SetConnected(false);
+ _logger.TransportCleanedUp(EndpointName);
+ }
+ }
+
+ private void ThrowIfDisposed()
+ {
+ if (_disposed)
+ {
+ throw new ObjectDisposedException(nameof(InMemoryClientTransport));
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs
new file mode 100644
index 000000000..be5af42fa
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs
@@ -0,0 +1,214 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+
+using ModelContextProtocol.Logging;
+using ModelContextProtocol.Protocol.Messages;
+
+using System.Threading.Channels;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+///
+/// Provides an in-memory implementation of the MCP server transport.
+///
+public sealed class InMemoryServerTransport : TransportBase, IServerTransport
+{
+ private string EndpointName => $"Server (in memory) for ({_serverName})";
+ private readonly ILogger _logger;
+ private readonly ChannelReader _incomingChannel;
+ private readonly ChannelWriter _outgoingChannel;
+ private CancellationTokenSource? _cancellationTokenSource;
+ private Task? _readTask;
+ private SemaphoreSlim _startLock = new SemaphoreSlim(1, 1);
+ private volatile bool _disposed;
+ private readonly string _serverName;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The name of the server.
+ /// Optional logger factory for logging transport operations.
+ /// Channel for receiving messages from the client.
+ /// Channel for sending messages to the client.
+ internal InMemoryServerTransport(
+ string serverName,
+ ILoggerFactory? loggerFactory,
+ ChannelReader incomingChannel,
+ ChannelWriter outgoingChannel)
+ : base(loggerFactory)
+ {
+ _logger = loggerFactory?.CreateLogger()
+ ?? NullLogger.Instance;
+ _incomingChannel = incomingChannel;
+ _outgoingChannel = outgoingChannel;
+ _serverName = serverName;
+ }
+
+ ///
+ public async Task StartListeningAsync(CancellationToken cancellationToken = default)
+ {
+ await _startLock.WaitAsync(cancellationToken).ConfigureAwait(false);
+ try
+ {
+ ThrowIfDisposed();
+
+ if (IsConnected)
+ {
+ return;
+ }
+
+ _logger.TransportConnecting(EndpointName);
+
+ try
+ {
+ _cancellationTokenSource = new CancellationTokenSource();
+ _readTask = Task.Run(() => ReadMessagesAsync(_cancellationTokenSource.Token).ConfigureAwait(false), CancellationToken.None);
+
+ SetConnected(true);
+ }
+ catch (Exception ex)
+ {
+ _logger.TransportConnectFailed(EndpointName, ex);
+ await CleanupAsync(cancellationToken).ConfigureAwait(false);
+ throw new McpTransportException("Failed to connect transport", ex);
+ }
+ }
+ finally
+ {
+ _startLock.Release();
+ }
+ }
+
+ ///
+ public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ ThrowIfDisposed();
+
+ if (!IsConnected)
+ {
+ _logger.TransportNotConnected(EndpointName);
+ throw new McpTransportException("Transport is not connected");
+ }
+
+ string id = "(no id)";
+ if (message is IJsonRpcMessageWithId messageWithId)
+ {
+ id = messageWithId.Id.ToString();
+ }
+
+ try
+ {
+ _logger.TransportSendingMessage(EndpointName, id);
+ await _outgoingChannel.WriteAsync(message, cancellationToken).ConfigureAwait(false);
+ _logger.TransportSentMessage(EndpointName, id);
+ }
+ catch (Exception ex)
+ {
+ _logger.TransportSendFailed(EndpointName, id, ex);
+ throw new McpTransportException("Failed to send message", ex);
+ }
+ }
+
+ ///
+ public override async ValueTask DisposeAsync()
+ {
+ await CleanupAsync(CancellationToken.None).ConfigureAwait(false);
+ GC.SuppressFinalize(this);
+ }
+
+ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
+ {
+ try
+ {
+ _logger.TransportEnteringReadMessagesLoop(EndpointName);
+
+ await foreach (var message in _incomingChannel.ReadAllAsync(cancellationToken).ConfigureAwait(false))
+ {
+ string id = "(no id)";
+ if (message is IJsonRpcMessageWithId messageWithId)
+ {
+ id = messageWithId.Id.ToString();
+ }
+
+ _logger.TransportReceivedMessageParsed(EndpointName, id);
+
+ // Write to the base class's message channel that's exposed via MessageReader
+ await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
+
+ _logger.TransportMessageWritten(EndpointName, id);
+ }
+
+ _logger.TransportExitingReadMessagesLoop(EndpointName);
+ }
+ catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
+ {
+ _logger.TransportReadMessagesCancelled(EndpointName);
+ // Normal shutdown
+ }
+ catch (Exception ex)
+ {
+ _logger.TransportReadMessagesFailed(EndpointName, ex);
+ }
+ }
+
+ private async Task CleanupAsync(CancellationToken cancellationToken)
+ {
+ if (_disposed)
+ {
+ return;
+ }
+
+ _disposed = true;
+ _logger.TransportCleaningUp(EndpointName);
+
+ try
+ {
+ if (_cancellationTokenSource != null)
+ {
+ await _cancellationTokenSource.CancelAsync().ConfigureAwait(false);
+ _cancellationTokenSource.Dispose();
+ _cancellationTokenSource = null;
+ }
+
+ if (_readTask != null)
+ {
+ try
+ {
+ _logger.TransportWaitingForReadTask(EndpointName);
+ await _readTask.WaitAsync(TimeSpan.FromSeconds(1), cancellationToken).ConfigureAwait(false);
+ }
+ catch (TimeoutException)
+ {
+ _logger.TransportCleanupReadTaskTimeout(EndpointName);
+ }
+ catch (OperationCanceledException)
+ {
+ _logger.TransportCleanupReadTaskCancelled(EndpointName);
+ }
+ catch (Exception ex)
+ {
+ _logger.TransportCleanupReadTaskFailed(EndpointName, ex);
+ }
+ finally
+ {
+ _readTask = null;
+ }
+ }
+
+ _startLock.Dispose();
+ }
+ finally
+ {
+ SetConnected(false);
+ _logger.TransportCleanedUp(EndpointName);
+ }
+ }
+
+ private void ThrowIfDisposed()
+ {
+ if (_disposed)
+ {
+ throw new ObjectDisposedException(nameof(InMemoryServerTransport));
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs
new file mode 100644
index 000000000..3119b8680
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs
@@ -0,0 +1,121 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Options;
+
+using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Server;
+using ModelContextProtocol.Utils;
+
+using System.Threading.Channels;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+///
+/// Factory that creates linked in-memory client and server transports for testing purposes.
+///
+public sealed class InMemoryTransport
+{
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The server options.
+ /// Optional logger factory used for logging employed by the transport.
+ /// is or contains a null name.
+
+ public InMemoryTransport(McpServerOptions serverOptions, ILoggerFactory? loggerFactory = null)
+ : this(GetServerName(serverOptions), loggerFactory)
+ {
+ }
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The server options.
+ /// Optional logger factory used for logging employed by the transport.
+ /// is or contains a null name.
+
+ public InMemoryTransport(IOptions serverOptions, ILoggerFactory? loggerFactory = null)
+ : this(GetServerName(serverOptions.Value), loggerFactory)
+ {
+ }
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The name of the server.
+ /// Optional logger factory used for logging employed by the transport.
+ /// is .
+ ///
+ ///
+ /// By default, no logging is performed. If a is supplied, it must not log
+ /// to , as that will interfere with the transport's output.
+ ///
+ ///
+ public InMemoryTransport(string serverName, ILoggerFactory? loggerFactory = null)
+ {
+ var (clientTransport, serverTransport) = Create(serverName, loggerFactory);
+ ServerTransport = serverTransport;
+ ClientTransport = clientTransport;
+ }
+
+ ///
+ /// Gets the client transport.
+ ///
+ public IClientTransport ClientTransport { get; }
+
+ ///
+ /// Gets the server transport.
+ ///
+ public IServerTransport ServerTransport { get; }
+
+
+ private static (InMemoryClientTransport ClientTransport, InMemoryServerTransport ServerTransport) Create(
+ string serverName,
+ ILoggerFactory? loggerFactory = null)
+ {
+ // Configure client-to-server channel - this will be used for:
+ // 1. Client's outgoing channel
+ // 2. Server's MessageReader
+ var clientToServerChannel = Channel.CreateUnbounded(new UnboundedChannelOptions
+ {
+ SingleReader = false, // Both server and the server's MessageReader will read
+ SingleWriter = true, // Client writes
+ AllowSynchronousContinuations = true
+ });
+
+ // Configure server-to-client channel - this will be used for:
+ // 1. Server's outgoing channel
+ // 2. Client's MessageReader
+ var serverToClientChannel = Channel.CreateUnbounded(new UnboundedChannelOptions
+ {
+ SingleReader = false, // Both client and the client's MessageReader will read
+ SingleWriter = true, // Server writes
+ AllowSynchronousContinuations = true
+ });
+
+ // Create the client and server transports - they directly expose the channels through MessageReader
+ var serverTransport = new InMemoryServerTransport(
+ serverName,
+ loggerFactory,
+ clientToServerChannel.Reader, // incoming: reads messages from client
+ serverToClientChannel.Writer); // outgoing: writes messages to client
+
+ var clientTransport = new InMemoryClientTransport(
+ serverName,
+ loggerFactory,
+ clientToServerChannel.Writer, // outgoing: writes messages to server
+ serverToClientChannel.Reader); // incoming: reads messages from server
+
+ // Link the transports together
+ clientTransport.ServerTransport = serverTransport;
+
+ return (clientTransport, serverTransport);
+ }
+
+ private static string GetServerName(McpServerOptions serverOptions)
+ {
+ Throw.IfNull(serverOptions);
+ Throw.IfNull(serverOptions.ServerInfo);
+
+ return serverOptions.ServerInfo.Name;
+ }
+}
\ No newline at end of file
diff --git a/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs b/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs
index 02de0a6b5..8a8c9f569 100644
--- a/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs
@@ -14,4 +14,9 @@ public static class TransportTypes
/// The name of the ServerSideEvents transport.
///
public const string Sse = "sse";
+
+ ///
+ /// The name of the InMemory transport.
+ ///
+ public const string InMemory = "inmemory";
}
diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs
index 7bff56642..8a676259c 100644
--- a/src/ModelContextProtocol/Server/McpServerExtensions.cs
+++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs
@@ -1,7 +1,13 @@
-using ModelContextProtocol.Protocol.Messages;
+using Microsoft.Extensions.AI;
+using Microsoft.Extensions.DependencyInjection;
+
+using ModelContextProtocol.Client;
+using ModelContextProtocol.Configuration;
+using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Utils;
-using Microsoft.Extensions.AI;
+
using System.Runtime.CompilerServices;
using System.Text;
@@ -10,6 +16,27 @@ namespace ModelContextProtocol.Server;
///
public static class McpServerExtensions
{
+ ///
+ /// Gets an in-memory client for the server.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static async Task GetInMemoryClientAsync(this IMcpServer server, CancellationToken cancellationToke = default)
+ {
+ var client = await McpClientFactory.CreateAsync(
+ new McpServerConfig
+ {
+ Id = server.ServerOptions.ServerInfo.Name,
+ Name = server.ServerOptions.ServerInfo.Name,
+ TransportType = TransportTypes.InMemory,
+ },
+ createTransportFunc: (_, _) => server.Services?.GetRequiredService() ?? throw new InvalidOperationException(),
+ cancellationToken: cancellationToke);
+ return client;
+ }
+
///
/// Requests to sample an LLM via the client.
///
@@ -42,7 +69,7 @@ public static Task RequestSamplingAsync(
/// is .
/// The client does not support sampling.
public static async Task RequestSamplingAsync(
- this IMcpServer server,
+ this IMcpServer server,
IEnumerable messages, ChatOptions? options = default, CancellationToken cancellationToken = default)
{
Throw.IfNull(server);
@@ -112,14 +139,14 @@ public static async Task RequestSamplingAsync(
}
var result = await server.RequestSamplingAsync(new()
- {
- Messages = samplingMessages,
- MaxTokens = options?.MaxOutputTokens,
- StopSequences = options?.StopSequences?.ToArray(),
- SystemPrompt = systemPrompt?.ToString(),
- Temperature = options?.Temperature,
- ModelPreferences = modelPreferences,
- }, cancellationToken).ConfigureAwait(false);
+ {
+ Messages = samplingMessages,
+ MaxTokens = options?.MaxOutputTokens,
+ StopSequences = options?.StopSequences?.ToArray(),
+ SystemPrompt = systemPrompt?.ToString(),
+ Temperature = options?.Temperature,
+ ModelPreferences = modelPreferences,
+ }, cancellationToken).ConfigureAwait(false);
return new(new ChatMessage(new(result.Role), [result.Content.ToAIContent()]))
{
diff --git a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs
new file mode 100644
index 000000000..73fa31014
--- /dev/null
+++ b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs
@@ -0,0 +1,158 @@
+using Microsoft.Extensions.DependencyInjection;
+
+using ModelContextProtocol.Client;
+using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Protocol.Transport;
+using ModelContextProtocol.Server;
+using ModelContextProtocol.Tests.Configuration;
+using ModelContextProtocol.Tests.Utils;
+
+using System.Text.Json;
+
+namespace ModelContextProtocol.Tests.Transport;
+
+public class InMemoryTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper)
+{
+
+ [Fact]
+ public async Task SendMessageAsync_Throws_Exception_If_Not_Connected()
+ {
+ var transport = new InMemoryTransport("test", LoggerFactory);
+ var serverTransport = transport.ServerTransport;
+ var clientTransport = transport.ClientTransport;
+
+ var message = new JsonRpcRequest { Method = "test" };
+
+ await Assert.ThrowsAsync(() => serverTransport.SendMessageAsync(message, TestContext.Current.CancellationToken));
+ await Assert.ThrowsAsync(() => clientTransport.SendMessageAsync(message, TestContext.Current.CancellationToken));
+ }
+
+ [Fact]
+ public async Task DisposeAsync_Should_Dispose_Resources()
+ {
+ var transport = new InMemoryTransport("test", LoggerFactory);
+ var serverTransport = transport.ServerTransport;
+ var clientTransport = transport.ClientTransport;
+
+ await serverTransport.DisposeAsync();
+ await clientTransport.DisposeAsync();
+
+ Assert.False(serverTransport.IsConnected);
+ Assert.False(clientTransport.IsConnected);
+ }
+
+ [Fact]
+ public async Task TransportPair_Should_Create_Valid_Transports()
+ {
+ var transport = new InMemoryTransport("test", LoggerFactory);
+ var serverTransport = transport.ServerTransport;
+ var clientTransport = transport.ClientTransport;
+
+ Assert.NotNull(clientTransport);
+ Assert.NotNull(serverTransport);
+ Assert.False(clientTransport.IsConnected);
+ Assert.False(serverTransport.IsConnected);
+
+ await clientTransport.DisposeAsync();
+ await serverTransport.DisposeAsync();
+ }
+
+
+ [Fact]
+ public async Task Message_Should_Flow_From_Client_To_Server()
+ {
+ var transport = new InMemoryTransport("test", LoggerFactory);
+ var serverTransport = transport.ServerTransport;
+ var clientTransport = transport.ClientTransport;
+
+ await clientTransport.ConnectAsync(TestContext.Current.CancellationToken);
+
+ var message = new JsonRpcRequest
+ {
+ Method = "test",
+ Id = RequestId.FromNumber(123),
+ Params = new Dictionary { ["text"] = "Hello, World!" }
+ };
+
+
+ await clientTransport.SendMessageAsync(message, TestContext.Current.CancellationToken);
+ await Task.Delay(2, TestContext.Current.CancellationToken);
+
+
+ Assert.True(serverTransport.MessageReader.TryRead(out var receivedMessage));
+ Assert.NotNull(receivedMessage);
+ Assert.IsType(receivedMessage);
+
+ var request = (JsonRpcRequest)receivedMessage;
+ Assert.Equal(123, request.Id.AsNumber);
+ Assert.Equal("test", request.Method);
+
+ var requestParams = (Dictionary)request.Params!;
+ Assert.Equal("Hello, World!", requestParams["text"]);
+
+ // Cleanup
+ await clientTransport.DisposeAsync();
+ await serverTransport.DisposeAsync();
+ }
+
+ [Fact]
+ public async Task Message_Should_Flow_From_Server_To_Client()
+ {
+ var transport = new InMemoryTransport("test", LoggerFactory);
+ var serverTransport = transport.ServerTransport;
+ var clientTransport = transport.ClientTransport;
+
+ await clientTransport.ConnectAsync(TestContext.Current.CancellationToken);
+
+ var message = new JsonRpcResponse
+ {
+ Id = RequestId.FromNumber(456),
+ Result = new Dictionary { ["text"] = "Response from server" }
+ };
+
+
+ await serverTransport.SendMessageAsync(message, TestContext.Current.CancellationToken);
+ await Task.Delay(2, TestContext.Current.CancellationToken);
+
+ Assert.True(clientTransport.MessageReader.TryRead(out var receivedMessage));
+ Assert.NotNull(receivedMessage);
+ Assert.IsType(receivedMessage);
+
+ var response = (JsonRpcResponse)receivedMessage;
+ Assert.Equal(456, response.Id.AsNumber);
+
+ var responseResult = (Dictionary)response.Result!;
+ Assert.Equal("Response from server", responseResult["text"]);
+
+ // Cleanup
+ await clientTransport.DisposeAsync();
+ await serverTransport.DisposeAsync();
+ }
+
+ [Fact]
+ public async Task Can_List_Registered_Tools()
+ {
+ ServiceCollection sc = new();
+ var builder = sc.AddMcpServer().WithTools().WithInMemoryServerTransport();
+ var server = sc.BuildServiceProvider().GetRequiredService();
+ await server.StartAsync(TestContext.Current.CancellationToken);
+
+ IMcpClient client = await server.GetInMemoryClientAsync(TestContext.Current.CancellationToken);
+
+
+ var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
+ Assert.Equal(11, tools.Count);
+
+ McpClientTool echoTool = tools.First(t => t.Name == "Echo");
+ Assert.Equal("Echo", echoTool.Name);
+ Assert.Equal("Echoes the input back to the client.", echoTool.Description);
+ Assert.Equal("object", echoTool.JsonSchema.GetProperty("type").GetString());
+ Assert.Equal(JsonValueKind.Object, echoTool.JsonSchema.GetProperty("properties").GetProperty("message").ValueKind);
+ Assert.Equal("the echoes message", echoTool.JsonSchema.GetProperty("properties").GetProperty("message").GetProperty("description").GetString());
+ Assert.Equal(1, echoTool.JsonSchema.GetProperty("required").GetArrayLength());
+
+ McpClientTool doubleEchoTool = tools.First(t => t.Name == "double_echo");
+ Assert.Equal("double_echo", doubleEchoTool.Name);
+ Assert.Equal("Echoes the input back to the client.", doubleEchoTool.Description);
+ }
+}
diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs
index 94ae62728..4c4941ed7 100644
--- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs
+++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs
@@ -4,6 +4,7 @@
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
using ModelContextProtocol.Utils.Json;
+
using System.IO.Pipelines;
using System.Text;
using System.Text.Json;
@@ -70,12 +71,12 @@ public async Task SendMessageAsync_Should_Send_Message()
new Pipe().Reader.AsStream(),
output,
LoggerFactory);
-
+
await transport.StartListeningAsync(TestContext.Current.CancellationToken);
-
+
// Ensure transport is fully initialized
await Task.Delay(100, TestContext.Current.CancellationToken);
-
+
// Verify transport is connected
Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync");
@@ -124,12 +125,12 @@ public async Task ReadMessagesAsync_Should_Read_Messages()
input,
Stream.Null,
LoggerFactory);
-
+
await transport.StartListeningAsync(TestContext.Current.CancellationToken);
-
+
// Ensure transport is fully initialized
await Task.Delay(100, TestContext.Current.CancellationToken);
-
+
// Verify transport is connected
Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync");
@@ -163,24 +164,24 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters()
using var output = new MemoryStream();
await using var transport = new StdioServerTransport(
- _serverOptions.ServerInfo.Name,
+ _serverOptions.ServerInfo.Name,
new Pipe().Reader.AsStream(),
- output,
+ output,
LoggerFactory);
-
+
await transport.StartListeningAsync(TestContext.Current.CancellationToken);
-
+
// Ensure transport is fully initialized
await Task.Delay(100, TestContext.Current.CancellationToken);
-
+
// Verify transport is connected
Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync");
// Test 1: Chinese characters (BMP Unicode)
var chineseText = "上下文伺服器"; // "Context Server" in Chinese
- var chineseMessage = new JsonRpcRequest
- {
- Method = "test",
+ var chineseMessage = new JsonRpcRequest
+ {
+ Method = "test",
Id = RequestId.FromNumber(44),
Params = new Dictionary
{
@@ -191,18 +192,18 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters()
// Clear output and send message
output.SetLength(0);
await transport.SendMessageAsync(chineseMessage, TestContext.Current.CancellationToken);
-
+
// Verify Chinese characters preserved but encoded
var chineseResult = Encoding.UTF8.GetString(output.ToArray()).Trim();
var expectedChinese = JsonSerializer.Serialize(chineseMessage, McpJsonUtilities.DefaultOptions);
Assert.Equal(expectedChinese, chineseResult);
Assert.Contains(JsonSerializer.Serialize(chineseText), chineseResult);
-
+
// Test 2: Emoji (non-BMP Unicode using surrogate pairs)
var emojiText = "🔍 🚀 👍"; // Magnifying glass, rocket, thumbs up
- var emojiMessage = new JsonRpcRequest
- {
- Method = "test",
+ var emojiMessage = new JsonRpcRequest
+ {
+ Method = "test",
Id = RequestId.FromNumber(45),
Params = new Dictionary
{
@@ -213,23 +214,23 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters()
// Clear output and send message
output.SetLength(0);
await transport.SendMessageAsync(emojiMessage, TestContext.Current.CancellationToken);
-
+
// Verify emoji preserved - might be as either direct characters or escape sequences
var emojiResult = Encoding.UTF8.GetString(output.ToArray()).Trim();
var expectedEmoji = JsonSerializer.Serialize(emojiMessage, McpJsonUtilities.DefaultOptions);
Assert.Equal(expectedEmoji, emojiResult);
-
+
// Verify surrogate pairs in different possible formats
// Magnifying glass emoji: 🔍 (U+1F50D)
- bool magnifyingGlassFound =
- emojiResult.Contains("🔍") ||
+ bool magnifyingGlassFound =
+ emojiResult.Contains("🔍") ||
emojiResult.Contains("\\ud83d\\udd0d", StringComparison.OrdinalIgnoreCase);
-
+
// Rocket emoji: 🚀 (U+1F680)
- bool rocketFound =
- emojiResult.Contains("🚀") ||
+ bool rocketFound =
+ emojiResult.Contains("🚀") ||
emojiResult.Contains("\\ud83d\\ude80", StringComparison.OrdinalIgnoreCase);
-
+
Assert.True(magnifyingGlassFound, "Magnifying glass emoji not found in result");
Assert.True(rocketFound, "Rocket emoji not found in result");
}