Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -618,8 +618,9 @@ private async Task PerformDynamicClientRegistrationAsync(
Scope = GetScopeParameter(protectedResourceMetadata),
};

var requestJson = JsonSerializer.Serialize(registrationRequest, McpJsonUtilities.JsonContext.Default.DynamicClientRegistrationRequest);
using var requestContent = new StringContent(requestJson, Encoding.UTF8, "application/json");
var requestBytes = JsonSerializer.SerializeToUtf8Bytes(registrationRequest, McpJsonUtilities.JsonContext.Default.DynamicClientRegistrationRequest);
using var requestContent = new ByteArrayContent(requestBytes);
requestContent.Headers.ContentType = McpHttpClient.s_applicationJsonContentType;

using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.RegistrationEndpoint)
{
Expand Down
13 changes: 7 additions & 6 deletions src/ModelContextProtocol.Core/Client/McpHttpClient.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
using ModelContextProtocol.Protocol;
using System.Diagnostics;
using System.Net.Http.Headers;

#if NET
using System.Net.Http.Json;
#else
using System.Text;
using System.Text.Json;
#endif

namespace ModelContextProtocol.Client;

internal class McpHttpClient(HttpClient httpClient)
{
internal static readonly MediaTypeHeaderValue s_applicationJsonContentType = new("application/json") { CharSet = "utf-8" };

internal virtual async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, JsonRpcMessage? message, CancellationToken cancellationToken)
{
Debug.Assert(request.Content is null, "The request body should only be supplied as a JsonRpcMessage");
Expand All @@ -32,11 +34,10 @@ internal virtual async Task<HttpResponseMessage> SendAsync(HttpRequestMessage re
#if NET
return JsonContent.Create(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);
#else
return new StringContent(
JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage),
Encoding.UTF8,
"application/json"
);
var bytes = JsonSerializer.SerializeToUtf8Bytes(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);
var content = new ByteArrayContent(bytes);
content.Headers.ContentType = s_applicationJsonContentType;
return content;
#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ namespace ModelContextProtocol.Client;
/// <summary>Provides the client side of a stream-based session transport.</summary>
internal class StreamClientSessionTransport : TransportBase
{
private static readonly byte[] s_newlineBytes = "\n"u8.ToArray();

internal static UTF8Encoding NoBomUtf8Encoding { get; } = new(encoderShouldEmitUTF8Identifier: false);

private readonly TextReader _serverOutput;
private readonly TextWriter _serverInput;
private readonly Stream _serverInputStream;
private readonly SemaphoreSlim _sendLock = new(1, 1);
private CancellationTokenSource? _shutdownCts = new();
private Task? _readTask;
Expand All @@ -20,12 +22,13 @@ internal class StreamClientSessionTransport : TransportBase
/// Initializes a new instance of the <see cref="StreamClientSessionTransport"/> class.
/// </summary>
/// <param name="serverInput">
/// The text writer connected to the server's input stream.
/// Messages written to this writer will be sent to the server.
/// The server's input stream. Messages written to this stream will be sent to the server.
/// </param>
/// <param name="serverOutput">
/// The text reader connected to the server's output stream.
/// Messages read from this reader will be received from the server.
/// The server's output stream. Messages read from this stream will be received from the server.
/// </param>
/// <param name="encoding">
/// The encoding used for reading and writing messages from the input and output streams. Defaults to UTF-8 without BOM if null.
/// </param>
/// <param name="endpointName">
/// A name that identifies this transport endpoint in logs.
Expand All @@ -37,12 +40,18 @@ internal class StreamClientSessionTransport : TransportBase
/// This constructor starts a background task to read messages from the server output stream.
/// The transport will be marked as connected once initialized.
/// </remarks>
public StreamClientSessionTransport(
TextWriter serverInput, TextReader serverOutput, string endpointName, ILoggerFactory? loggerFactory)
public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, Encoding? encoding, string endpointName, ILoggerFactory? loggerFactory)
: base(endpointName, loggerFactory)
{
_serverOutput = serverOutput;
_serverInput = serverInput;
Throw.IfNull(serverInput);
Throw.IfNull(serverOutput);

_serverInputStream = serverInput;
#if NET
_serverOutput = new StreamReader(serverOutput, encoding ?? NoBomUtf8Encoding);
#else
_serverOutput = new CancellableStreamReader(serverOutput, encoding ?? NoBomUtf8Encoding);
#endif

SetConnected();

Expand All @@ -57,43 +66,6 @@ public StreamClientSessionTransport(
readTask.Start();
}

/// <summary>
/// Initializes a new instance of the <see cref="StreamClientSessionTransport"/> class.
/// </summary>
/// <param name="serverInput">
/// The server's input stream. Messages written to this stream will be sent to the server.
/// </param>
/// <param name="serverOutput">
/// The server's output stream. Messages read from this stream will be received from the server.
/// </param>
/// <param name="encoding">
/// The encoding used for reading and writing messages from the input and output streams. Defaults to UTF-8 without BOM if null.
/// </param>
/// <param name="endpointName">
/// A name that identifies this transport endpoint in logs.
/// </param>
/// <param name="loggerFactory">
/// Optional factory for creating loggers. If null, a NullLogger is used.
/// </param>
/// <remarks>
/// This constructor starts a background task to read messages from the server output stream.
/// The transport will be marked as connected once initialized.
/// </remarks>
public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, Encoding? encoding, string endpointName, ILoggerFactory? loggerFactory)
: this(
new StreamWriter(serverInput, encoding ?? NoBomUtf8Encoding),
#if NET
new StreamReader(serverOutput, encoding ?? NoBomUtf8Encoding),
#else
new CancellableStreamReader(serverOutput, encoding ?? NoBomUtf8Encoding),
#endif
endpointName,
loggerFactory)
{
Throw.IfNull(serverInput);
Throw.IfNull(serverOutput);
}

/// <inheritdoc/>
public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
Expand All @@ -103,16 +75,15 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation
id = messageWithId.Id.ToString();
}

var json = JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);

LogTransportSendingMessageSensitive(Name, json);
LogTransportSendingMessageSensitive(message);

using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);
try
{
// Write the message followed by a newline using our UTF-8 writer
await _serverInput.WriteLineAsync(json).ConfigureAwait(false);
await _serverInput.FlushAsync(cancellationToken).ConfigureAwait(false);
var json = JsonSerializer.SerializeToUtf8Bytes(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);
await _serverInputStream.WriteAsync(json, cancellationToken).ConfigureAwait(false);
await _serverInputStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false);
await _serverInputStream.FlushAsync(cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
Expand Down
6 changes: 3 additions & 3 deletions src/ModelContextProtocol.Core/Server/StreamServerTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation

try
{
var json = JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);
LogTransportSendingMessageSensitive(Name, json);
await _outputStream.WriteAsync(Encoding.UTF8.GetBytes(json), cancellationToken).ConfigureAwait(false);
var json = JsonSerializer.SerializeToUtf8Bytes(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);
LogTransportSendingMessageSensitive(message);
await _outputStream.WriteAsync(json, cancellationToken).ConfigureAwait(false);
await _outputStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false);
await _outputStream.FlushAsync(cancellationToken).ConfigureAwait(false);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Tests.Utils;
using System.IO.Pipelines;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.Json;

namespace ModelContextProtocol.Tests.Transport;

Expand Down Expand Up @@ -136,4 +138,54 @@ public async Task EscapesCliArgumentsCorrectly(string? cliArgumentValue)
var content = Assert.IsType<TextContentBlock>(Assert.Single(result.Content));
Assert.Equal(cliArgumentValue ?? "", content.Text);
}

[Fact]
public async Task SendMessageAsync_Should_Use_LF_Not_CRLF()
{
using var serverInput = new MemoryStream();
Pipe serverOutputPipe = new();

var transport = new StreamClientTransport(serverInput, serverOutputPipe.Reader.AsStream(), LoggerFactory);
await using var sessionTransport = await transport.ConnectAsync(TestContext.Current.CancellationToken);

var message = new JsonRpcRequest { Method = "test", Id = new RequestId(44) };

await sessionTransport.SendMessageAsync(message, TestContext.Current.CancellationToken);

byte[] bytes = serverInput.ToArray();

// The output should end with exactly \n (0x0A), not \r\n (0x0D 0x0A).
Assert.True(bytes.Length > 1, "Output should contain message data");
Assert.Equal((byte)'\n', bytes[^1]);
Assert.NotEqual((byte)'\r', bytes[^2]);

// Also verify the JSON content is valid
var json = Encoding.UTF8.GetString(bytes).TrimEnd('\n');
var expected = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions);
Assert.Equal(expected, json);
}

[Fact]
public async Task ReadMessagesAsync_Should_Accept_CRLF_Delimited_Messages()
{
Pipe serverInputPipe = new();
Pipe serverOutputPipe = new();

var transport = new StreamClientTransport(serverInputPipe.Writer.AsStream(), serverOutputPipe.Reader.AsStream(), LoggerFactory);
await using var sessionTransport = await transport.ConnectAsync(TestContext.Current.CancellationToken);

var message = new JsonRpcRequest { Method = "test", Id = new RequestId(44) };
var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions);

// Write a \r\n-delimited message to the server's output (which the client reads)
await serverOutputPipe.Writer.WriteAsync(Encoding.UTF8.GetBytes($"{json}\r\n"), TestContext.Current.CancellationToken);

var canRead = await sessionTransport.MessageReader.WaitToReadAsync(TestContext.Current.CancellationToken);

Assert.True(canRead, "Should be able to read a \\r\\n-delimited message");
Assert.True(sessionTransport.MessageReader.TryPeek(out var readMessage));
Assert.NotNull(readMessage);
Assert.IsType<JsonRpcRequest>(readMessage);
Assert.Equal("44", ((JsonRpcRequest)readMessage).Id.ToString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,54 @@ public async Task SendMessageAsync_Should_Log_At_Trace_Level()
Assert.Contains(traceLogMessages, x => x.Message.Contains("\"method\":\"test\"") && x.Message.Contains("\"id\":44"));
}

[Fact]
public async Task SendMessageAsync_Should_Use_LF_Not_CRLF()
{
using var output = new MemoryStream();

await using var transport = new StreamServerTransport(
new Pipe().Reader.AsStream(),
output,
loggerFactory: LoggerFactory);

var message = new JsonRpcRequest { Method = "test", Id = new RequestId(44) };

await transport.SendMessageAsync(message, TestContext.Current.CancellationToken);

byte[] bytes = output.ToArray();

// The output should end with exactly \n (0x0A), not \r\n (0x0D 0x0A).
Assert.True(bytes.Length > 1, "Output should contain message data");
Assert.Equal((byte)'\n', bytes[^1]);
Assert.NotEqual((byte)'\r', bytes[^2]);
}

[Fact]
public async Task ReadMessagesAsync_Should_Accept_CRLF_Delimited_Messages()
{
var message = new JsonRpcRequest { Method = "test", Id = new RequestId(44) };
var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions);

Pipe pipe = new();
using var input = pipe.Reader.AsStream();

await using var transport = new StreamServerTransport(
input,
Stream.Null,
loggerFactory: LoggerFactory);

// Write the message with \r\n line ending
await pipe.Writer.WriteAsync(Encoding.UTF8.GetBytes($"{json}\r\n"), TestContext.Current.CancellationToken);

var canRead = await transport.MessageReader.WaitToReadAsync(TestContext.Current.CancellationToken);

Assert.True(canRead, "Should be able to read a \\r\\n-delimited message");
Assert.True(transport.MessageReader.TryPeek(out var readMessage));
Assert.NotNull(readMessage);
Assert.IsType<JsonRpcRequest>(readMessage);
Assert.Equal("44", ((JsonRpcRequest)readMessage).Id.ToString());
}

[Fact]
public async Task ReadMessagesAsync_Should_Log_Received_At_Trace_Level()
{
Expand Down
Loading