From 0f060ce3efae76151f5c3a081d79ea53c71fc7ad Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 18 Nov 2025 09:30:12 -0500 Subject: [PATCH] Add tools support to sampling --- samples/ChatWithTools/Program.cs | 1 + .../EverythingServer/Tools/SampleLlmTool.cs | 4 +- .../Tools/SampleLlmTool.cs | 4 +- .../AIContentExtensions.cs | 212 ++++++++++++++- .../Client/McpClient.Methods.cs | 112 -------- .../Client/McpClientExtensions.cs | 122 --------- .../Client/McpClientHandlers.cs | 2 +- src/ModelContextProtocol.Core/Diagnostics.cs | 2 +- .../Protocol/ContentBlock.cs | 248 +++++++++++++----- .../Protocol/ContextInclusion.cs | 15 ++ .../Protocol/CreateMessageRequestParams.cs | 20 ++ .../Protocol/CreateMessageResult.cs | 26 +- .../Protocol/SamplingCapability.cs | 25 +- .../Protocol/SamplingContextCapability.cs | 6 + .../Protocol/SamplingMessage.cs | 25 +- .../Protocol/SamplingToolsCapability.cs | 6 + .../Protocol/SingleItemOrListConverter.cs | 67 +++++ .../Protocol/ToolChoice.cs | 32 +++ .../Server/AIFunctionMcpServerTool.cs | 4 +- .../Server/McpServer.Methods.cs | 98 ++++--- .../Server/McpServerTool.cs | 4 +- .../Server/McpServerToolAttribute.cs | 4 +- tests/Common/Utils/TestServerTransport.cs | 4 +- .../HttpServerIntegrationTests.cs | 5 +- .../MapMcpTests.cs | 12 +- .../Program.cs | 6 +- .../Program.cs | 8 +- .../AIContentExtensionsTests.cs | 123 +++++++++ .../Client/McpClientCreationTests.cs | 4 +- .../Client/McpClientExtensionsTests.cs | 9 +- .../Client/McpClientTests.cs | 139 +++++++++- .../Client/McpClientToolTests.cs | 9 +- .../ClientIntegrationTests.cs | 4 +- .../DockerEverythingServerTests.cs | 4 +- .../Protocol/ContentBlockTests.cs | 102 ++++++- .../CreateMessageRequestParamsTests.cs | 174 ++++++++++++ .../Protocol/CreateMessageResultTests.cs | 247 +++++++++++++++++ .../Protocol/SamplingMessageTests.cs | 111 ++++++++ .../Protocol/ToolChoiceTests.cs | 30 +++ .../Server/McpServerExtensionsTests.cs | 10 +- .../Server/McpServerTests.cs | 8 +- 41 files changed, 1623 insertions(+), 425 deletions(-) create mode 100644 src/ModelContextProtocol.Core/Protocol/SamplingContextCapability.cs create mode 100644 src/ModelContextProtocol.Core/Protocol/SamplingToolsCapability.cs create mode 100644 src/ModelContextProtocol.Core/Protocol/SingleItemOrListConverter.cs create mode 100644 src/ModelContextProtocol.Core/Protocol/ToolChoice.cs create mode 100644 tests/ModelContextProtocol.Tests/Protocol/CreateMessageRequestParamsTests.cs create mode 100644 tests/ModelContextProtocol.Tests/Protocol/CreateMessageResultTests.cs create mode 100644 tests/ModelContextProtocol.Tests/Protocol/SamplingMessageTests.cs create mode 100644 tests/ModelContextProtocol.Tests/Protocol/ToolChoiceTests.cs diff --git a/samples/ChatWithTools/Program.cs b/samples/ChatWithTools/Program.cs index c6fca0493..c5870cdc3 100644 --- a/samples/ChatWithTools/Program.cs +++ b/samples/ChatWithTools/Program.cs @@ -1,5 +1,6 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; +using ModelContextProtocol; using ModelContextProtocol.Client; using OpenAI; using OpenTelemetry; diff --git a/samples/EverythingServer/Tools/SampleLlmTool.cs b/samples/EverythingServer/Tools/SampleLlmTool.cs index 6bbe6e51d..48c5184b3 100644 --- a/samples/EverythingServer/Tools/SampleLlmTool.cs +++ b/samples/EverythingServer/Tools/SampleLlmTool.cs @@ -17,7 +17,7 @@ public static async Task SampleLLM( var samplingParams = CreateRequestSamplingParams(prompt ?? string.Empty, "sampleLLM", maxTokens); var sampleResult = await server.SampleAsync(samplingParams, cancellationToken); - return $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}"; + return $"LLM sampling result: {sampleResult.Content.OfType().FirstOrDefault()?.Text}"; } private static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100) @@ -27,7 +27,7 @@ private static CreateMessageRequestParams CreateRequestSamplingParams(string con Messages = [new SamplingMessage { Role = Role.User, - Content = new TextContentBlock { Text = $"Resource {uri} context: {context}" }, + Content = [new TextContentBlock { Text = $"Resource {uri} context: {context}" }], }], SystemPrompt = "You are a helpful test server.", MaxTokens = maxTokens, diff --git a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs index 2c96b8c35..7d4c61784 100644 --- a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs +++ b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs @@ -20,7 +20,7 @@ public static async Task SampleLLM( var samplingParams = CreateRequestSamplingParams(prompt ?? string.Empty, "sampleLLM", maxTokens); var sampleResult = await thisServer.SampleAsync(samplingParams, cancellationToken); - return $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}"; + return $"LLM sampling result: {sampleResult.Content.OfType().FirstOrDefault()?.Text}"; } private static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100) @@ -30,7 +30,7 @@ private static CreateMessageRequestParams CreateRequestSamplingParams(string con Messages = [new SamplingMessage { Role = Role.User, - Content = new TextContentBlock { Text = $"Resource {uri} context: {context}" }, + Content = [new TextContentBlock { Text = $"Resource {uri} context: {context}" }], }], SystemPrompt = "You are a helpful test server.", MaxTokens = maxTokens, diff --git a/src/ModelContextProtocol.Core/AIContentExtensions.cs b/src/ModelContextProtocol.Core/AIContentExtensions.cs index 8686b7b6a..374f00555 100644 --- a/src/ModelContextProtocol.Core/AIContentExtensions.cs +++ b/src/ModelContextProtocol.Core/AIContentExtensions.cs @@ -1,9 +1,11 @@ using Microsoft.Extensions.AI; +using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; #if !NET using System.Runtime.InteropServices; #endif using System.Text.Json; +using System.Text.Json.Nodes; namespace ModelContextProtocol; @@ -16,6 +18,140 @@ namespace ModelContextProtocol; /// public static class AIContentExtensions { + /// + /// Creates a sampling handler for use with that will + /// satisfy sampling requests using the specified . + /// + /// The with which to satisfy sampling requests. + /// The created handler delegate that can be assigned to . + /// + /// + /// This method creates a function that converts MCP message requests into chat client calls, enabling + /// an MCP client to generate text or other content using an actual AI model via the provided chat client. + /// + /// + /// The handler can process text messages, image messages, resource messages, and tool use/results as defined in the + /// Model Context Protocol. + /// + /// + /// is . + public static Func, CancellationToken, ValueTask> CreateSamplingHandler( + this IChatClient chatClient) + { + Throw.IfNull(chatClient); + + return async (requestParams, progress, cancellationToken) => + { + Throw.IfNull(requestParams); + + var (messages, options) = ToChatClientArguments(requestParams); + var progressToken = requestParams.ProgressToken; + + List updates = []; + await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + { + updates.Add(update); + + if (progressToken is not null) + { + progress.Report(new() { Progress = updates.Count }); + } + } + + ChatResponse? chatResponse = updates.ToChatResponse(); + ChatMessage? lastMessage = chatResponse.Messages.LastOrDefault(); + + IList? contents = lastMessage?.Contents.Select(c => c.ToContentBlock()).ToList(); + if (contents is not { Count: > 0 }) + { + (contents ??= []).Add(new TextContentBlock() { Text = "" }); + } + + return new() + { + Model = chatResponse.ModelId ?? "", + StopReason = + chatResponse.FinishReason == ChatFinishReason.Stop ? CreateMessageResult.StopReasonEndTurn : + chatResponse.FinishReason == ChatFinishReason.Length ? CreateMessageResult.StopReasonMaxTokens : + chatResponse.FinishReason == ChatFinishReason.ToolCalls ? CreateMessageResult.StopReasonToolUse : + chatResponse.FinishReason.ToString(), + Meta = chatResponse.AdditionalProperties?.ToJsonObject(), + Role = lastMessage?.Role == ChatRole.User ? Role.User : Role.Assistant, + Content = contents, + }; + + static (IList Messages, ChatOptions? Options) ToChatClientArguments(CreateMessageRequestParams requestParams) + { + ChatOptions? options = null; + + if (requestParams.MaxTokens is int maxTokens) + { + (options ??= new()).MaxOutputTokens = maxTokens; + } + + if (requestParams.Temperature is float temperature) + { + (options ??= new()).Temperature = temperature; + } + + if (requestParams.StopSequences is { } stopSequences) + { + (options ??= new()).StopSequences = stopSequences.ToArray(); + } + + if (requestParams.SystemPrompt is { } systemPrompt) + { + (options ??= new()).Instructions = systemPrompt; + } + + if (requestParams.Tools is { } tools) + { + foreach (var tool in tools) + { + ((options ??= new()).Tools ??= []).Add(new ToolAIFunctionDeclaration(tool)); + } + + if (options.Tools is { Count: > 0 } && requestParams.ToolChoice is { } toolChoice) + { + options.ToolMode = toolChoice.Mode switch + { + ToolChoice.ModeAuto => ChatToolMode.Auto, + ToolChoice.ModeRequired => ChatToolMode.RequireAny, + ToolChoice.ModeNone => ChatToolMode.None, + _ => null, + }; + } + } + + List messages = []; + foreach (var sm in requestParams.Messages) + { + if (sm.Content?.Select(b => b.ToAIContent()).OfType().ToList() is { Count: > 0 } aiContents) + { + messages.Add(new ChatMessage(sm.Role is Role.Assistant ? ChatRole.Assistant : ChatRole.User, aiContents)); + } + } + + return (messages, options); + } + }; + } + + /// Converts the specified dictionary to a . + internal static JsonObject? ToJsonObject(this IReadOnlyDictionary properties) => + JsonSerializer.SerializeToNode(properties, McpJsonUtilities.JsonContext.Default.IReadOnlyDictionaryStringObject) as JsonObject; + + internal static AdditionalPropertiesDictionary ToAdditionalProperties(this JsonObject obj) + { + AdditionalPropertiesDictionary d = []; + foreach (var kvp in obj) + { + d.Add(kvp.Key, kvp.Value); + } + + return d; + } + /// /// Converts a to a object. /// @@ -99,7 +235,7 @@ public static IList ToPromptMessages(this ChatMessage chatMessage { if (content is TextContent or DataContent) { - messages.Add(new PromptMessage { Role = r, Content = content.ToContent() }); + messages.Add(new PromptMessage { Role = r, Content = content.ToContentBlock() }); } } @@ -122,13 +258,31 @@ public static IList ToPromptMessages(this ChatMessage chatMessage AIContent? ac = content switch { TextContentBlock textContent => new TextContent(textContent.Text), + ImageContentBlock imageContent => new DataContent(Convert.FromBase64String(imageContent.Data), imageContent.MimeType), + AudioContentBlock audioContent => new DataContent(Convert.FromBase64String(audioContent.Data), audioContent.MimeType), + EmbeddedResourceBlock resourceContent => resourceContent.Resource.ToAIContent(), + + ToolUseContentBlock toolUse => FunctionCallContent.CreateFromParsedArguments(toolUse.Input, toolUse.Id, toolUse.Name, + static json => JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.IDictionaryStringObject)), + + ToolResultContentBlock toolResult => new FunctionResultContent( + toolResult.ToolUseId, + toolResult.Content.Count == 1 ? toolResult.Content[0].ToAIContent() : toolResult.Content.Select(c => c.ToAIContent()).OfType().ToList()) + { + Exception = toolResult.IsError is true ? new() : null, + }, + _ => null, }; - ac?.RawRepresentation = content; + if (ac is not null) + { + ac.RawRepresentation = content; + ac.AdditionalProperties = content.Meta?.ToAdditionalProperties(); + } return ac; } @@ -200,8 +354,12 @@ public static IList ToAIContents(this IEnumerable c return [.. contents.Select(ToAIContent)]; } - internal static ContentBlock ToContent(this AIContent content) => - content switch + /// Creates a new from the content of an . + /// The to convert. + /// The created . + public static ContentBlock ToContentBlock(this AIContent content) + { + ContentBlock contentBlock = content switch { TextContent textContent => new TextContentBlock { @@ -230,9 +388,55 @@ internal static ContentBlock ToContent(this AIContent content) => } }, + FunctionCallContent callContent => new ToolUseContentBlock() + { + Id = callContent.CallId, + Name = callContent.Name, + Input = JsonSerializer.SerializeToElement(callContent.Arguments, McpJsonUtilities.DefaultOptions.GetTypeInfo>()!), + }, + + FunctionResultContent resultContent => new ToolResultContentBlock() + { + ToolUseId = resultContent.CallId, + IsError = resultContent.Exception is not null, + Content = + resultContent.Result is AIContent c ? [c.ToContentBlock()] : + resultContent.Result is IEnumerable ec ? [.. ec.Select(c => c.ToContentBlock())] : + [new TextContentBlock { Text = JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions.GetTypeInfo()) }], + StructuredContent = resultContent.Result is JsonElement je ? je : null, + }, + _ => new TextContentBlock { Text = JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object))), } }; + + contentBlock.Meta = content.AdditionalProperties?.ToJsonObject(); + + return contentBlock; + } + + private sealed class ToolAIFunctionDeclaration(Tool tool) : AIFunctionDeclaration + { + public override string Name => tool.Name; + + public override string Description => tool.Description ?? ""; + + public override IReadOnlyDictionary AdditionalProperties => + field ??= tool.Meta is { } meta ? meta.ToDictionary(p => p.Key, p => (object?)p.Value) : []; + + public override JsonElement JsonSchema => tool.InputSchema; + + public override JsonElement? ReturnJsonSchema => tool.OutputSchema; + + public override object? GetService(Type serviceType, object? serviceKey = null) + { + Throw.IfNull(serviceType); + + return + serviceKey is null && serviceType.IsInstanceOfType(tool) ? tool : + base.GetService(serviceType, serviceKey); + } + } } diff --git a/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs b/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs index 5550e786e..6397d8e78 100644 --- a/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs +++ b/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs @@ -555,118 +555,6 @@ async ValueTask SendRequestWithProgressAsync( } } - /// - /// Converts the contents of a into a pair of - /// and instances to use - /// as inputs into a operation. - /// - /// - /// The created pair of messages and options. - /// is . - internal static (IList Messages, ChatOptions? Options) ToChatClientArguments( - CreateMessageRequestParams requestParams) - { - Throw.IfNull(requestParams); - - ChatOptions? options = null; - - if (requestParams.MaxTokens is int maxTokens) - { - (options ??= new()).MaxOutputTokens = maxTokens; - } - - if (requestParams.Temperature is float temperature) - { - (options ??= new()).Temperature = temperature; - } - - if (requestParams.StopSequences is { } stopSequences) - { - (options ??= new()).StopSequences = stopSequences.ToArray(); - } - - List messages = - (from sm in requestParams.Messages - let aiContent = sm.Content.ToAIContent() - where aiContent is not null - select new ChatMessage(sm.Role == Role.Assistant ? ChatRole.Assistant : ChatRole.User, [aiContent])) - .ToList(); - - return (messages, options); - } - - /// Converts the contents of a into a . - /// The whose contents should be extracted. - /// The created . - /// is . - internal static CreateMessageResult ToCreateMessageResult(ChatResponse chatResponse) - { - Throw.IfNull(chatResponse); - - // The ChatResponse can include multiple messages, of varying modalities, but CreateMessageResult supports - // only either a single blob of text or a single image. Heuristically, we'll use an image if there is one - // in any of the response messages, or we'll use all the text from them concatenated, otherwise. - - ChatMessage? lastMessage = chatResponse.Messages.LastOrDefault(); - - ContentBlock? content = null; - if (lastMessage is not null) - { - foreach (var lmc in lastMessage.Contents) - { - if (lmc is DataContent dc && (dc.HasTopLevelMediaType("image") || dc.HasTopLevelMediaType("audio"))) - { - content = dc.ToContent(); - } - } - } - - return new() - { - Content = content ?? new TextContentBlock { Text = lastMessage?.Text ?? string.Empty }, - Model = chatResponse.ModelId ?? "unknown", - Role = lastMessage?.Role == ChatRole.User ? Role.User : Role.Assistant, - StopReason = chatResponse.FinishReason == ChatFinishReason.Length ? "maxTokens" : "endTurn", - }; - } - - /// - /// Creates a sampling handler for use with that will - /// satisfy sampling requests using the specified . - /// - /// The with which to satisfy sampling requests. - /// The created handler delegate that can be assigned to . - /// is . - public static Func, CancellationToken, ValueTask> CreateSamplingHandler( - IChatClient chatClient) - { - Throw.IfNull(chatClient); - - return async (requestParams, progress, cancellationToken) => - { - Throw.IfNull(requestParams); - - var (messages, options) = ToChatClientArguments(requestParams); - var progressToken = requestParams.ProgressToken; - - List updates = []; - await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) - { - updates.Add(update); - - if (progressToken is not null) - { - progress.Report(new() - { - Progress = updates.Count, - }); - } - } - - return ToCreateMessageResult(updates.ToChatResponse()); - }; - } - /// /// Sets the logging level for the server to control which log messages are sent to the client. /// diff --git a/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs b/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs index f0cd3c4f9..de2d0071b 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs @@ -19,53 +19,6 @@ namespace ModelContextProtocol.Client; /// public static class McpClientExtensions { - /// - /// Creates a sampling handler for use with that will - /// satisfy sampling requests using the specified . - /// - /// The with which to satisfy sampling requests. - /// The created handler delegate that can be assigned to . - /// - /// - /// This method creates a function that converts MCP message requests into chat client calls, enabling - /// an MCP client to generate text or other content using an actual AI model via the provided chat client. - /// - /// - /// The handler can process text messages, image messages, and resource messages as defined in the - /// Model Context Protocol. - /// - /// - /// is . - public static Func, CancellationToken, ValueTask> CreateSamplingHandler( - this IChatClient chatClient) - { - Throw.IfNull(chatClient); - - return async (requestParams, progress, cancellationToken) => - { - Throw.IfNull(requestParams); - - var (messages, options) = requestParams.ToChatClientArguments(); - var progressToken = requestParams.ProgressToken; - - List updates = []; - await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) - { - updates.Add(update); - - if (progressToken is not null) - { - progress.Report(new() - { - Progress = updates.Count, - }); - } - } - - return updates.ToChatResponse().ToCreateMessageResult(); - }; - } - /// /// Sends a ping request to verify server connectivity. /// @@ -654,79 +607,4 @@ static void ThrowInvalidEndpointType(string memberName) $"'{nameof(McpClientExtensions)}.{memberName}' is obsolete and will be " + $"removed in the future."); } - - /// - /// Converts the contents of a into a pair of - /// and instances to use - /// as inputs into a operation. - /// - /// - /// The created pair of messages and options. - /// is . - internal static (IList Messages, ChatOptions? Options) ToChatClientArguments( - this CreateMessageRequestParams requestParams) - { - Throw.IfNull(requestParams); - - ChatOptions? options = null; - - if (requestParams.MaxTokens is int maxTokens) - { - (options ??= new()).MaxOutputTokens = maxTokens; - } - - if (requestParams.Temperature is float temperature) - { - (options ??= new()).Temperature = temperature; - } - - if (requestParams.StopSequences is { } stopSequences) - { - (options ??= new()).StopSequences = stopSequences.ToArray(); - } - - List messages = - (from sm in requestParams.Messages - let aiContent = sm.Content.ToAIContent() - where aiContent is not null - select new ChatMessage(sm.Role == Role.Assistant ? ChatRole.Assistant : ChatRole.User, [aiContent])) - .ToList(); - - return (messages, options); - } - - /// Converts the contents of a into a . - /// The whose contents should be extracted. - /// The created . - /// is . - internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chatResponse) - { - Throw.IfNull(chatResponse); - - // The ChatResponse can include multiple messages, of varying modalities, but CreateMessageResult supports - // only either a single blob of text or a single image. Heuristically, we'll use an image if there is one - // in any of the response messages, or we'll use all the text from them concatenated, otherwise. - - ChatMessage? lastMessage = chatResponse.Messages.LastOrDefault(); - - ContentBlock? content = null; - if (lastMessage is not null) - { - foreach (var lmc in lastMessage.Contents) - { - if (lmc is DataContent dc && (dc.HasTopLevelMediaType("image") || dc.HasTopLevelMediaType("audio"))) - { - content = dc.ToContent(); - } - } - } - - return new() - { - Content = content ?? new TextContentBlock { Text = lastMessage?.Text ?? string.Empty }, - Model = chatResponse.ModelId ?? "unknown", - Role = lastMessage?.Role == ChatRole.User ? Role.User : Role.Assistant, - StopReason = chatResponse.FinishReason == ChatFinishReason.Length ? "maxTokens" : "endTurn", - }; - } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/McpClientHandlers.cs b/src/ModelContextProtocol.Core/Client/McpClientHandlers.cs index fecb83299..f6abf0de0 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientHandlers.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientHandlers.cs @@ -80,7 +80,7 @@ public class McpClientHandlers /// generated content. /// /// - /// You can create a handler using the extension + /// You can create a handler using the extension /// method with any implementation of . /// /// diff --git a/src/ModelContextProtocol.Core/Diagnostics.cs b/src/ModelContextProtocol.Core/Diagnostics.cs index bed648868..083422d9c 100644 --- a/src/ModelContextProtocol.Core/Diagnostics.cs +++ b/src/ModelContextProtocol.Core/Diagnostics.cs @@ -13,7 +13,7 @@ internal static class Diagnostics internal static Meter Meter { get; } = new("Experimental.ModelContextProtocol"); internal static Histogram CreateDurationHistogram(string name, string description, bool longBuckets) => - Meter.CreateHistogram(name, "s", description, advice: longBuckets ? LongSecondsBucketBoundaries : ShortSecondsBucketBoundaries); + Meter.CreateHistogram(name, "s", description, advice: longBuckets ? LongSecondsBucketBoundaries : ShortSecondsBucketBoundaries); /// /// Follows boundaries from http.server.request.duration/http.client.request.duration diff --git a/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs b/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs index ccc5e9623..633097ff1 100644 --- a/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs +++ b/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs @@ -1,4 +1,3 @@ -using Microsoft.Extensions.AI; using System.ComponentModel; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; @@ -37,7 +36,7 @@ private protected ContentBlock() /// When overridden in a derived class, gets the type of content. /// /// - /// This determines the structure of the content object. Valid values include "image", "audio", "text", "resource", and "resource_link". + /// This determines the structure of the content object. Valid values include "image", "audio", "text", "resource", "resource_link", "tool_use", and "tool_result". /// [JsonPropertyName("type")] public abstract string Type { get; } @@ -52,6 +51,15 @@ private protected ContentBlock() [JsonPropertyName("annotations")] public Annotations? Annotations { get; set; } + /// + /// Gets or sets metadata reserved by MCP for protocol-level metadata. + /// + /// + /// Implementations must not make assumptions about its contents. + /// + [JsonPropertyName("_meta")] + public JsonObject? Meta { get; set; } + /// /// Provides a for . /// @@ -82,6 +90,12 @@ public class Converter : JsonConverter ResourceContents? resource = null; Annotations? annotations = null; JsonObject? meta = null; + string? id = null; + JsonElement? input = null; + string? toolUseId = null; + List? content = null; + JsonElement? structuredContent = null; + bool? isError = null; while (reader.Read() && reader.TokenType != JsonTokenType.EndObject) { @@ -140,42 +154,71 @@ public class Converter : JsonConverter meta = JsonSerializer.Deserialize(ref reader, McpJsonUtilities.JsonContext.Default.JsonObject); break; + case "id": + id = reader.GetString(); + break; + + case "input": + input = JsonSerializer.Deserialize(ref reader, McpJsonUtilities.JsonContext.Default.JsonElement); + break; + + case "toolUseId": + toolUseId = reader.GetString(); + break; + + case "content": + if (reader.TokenType == JsonTokenType.StartArray) + { + content = []; + while (reader.Read() && reader.TokenType != JsonTokenType.EndArray) + { + content.Add(Read(ref reader, typeof(ContentBlock), options) ?? + throw new JsonException("Unexpected null item in content array.")); + } + } + else + { + content = [Read(ref reader, typeof(ContentBlock), options) ?? + throw new JsonException("Unexpected null content item.")]; + } + break; + + case "structuredContent": + structuredContent = JsonSerializer.Deserialize(ref reader, McpJsonUtilities.JsonContext.Default.JsonElement); + break; + + case "isError": + isError = reader.GetBoolean(); + break; + default: reader.Skip(); break; } } - return type switch + ContentBlock block = type switch { "text" => new TextContentBlock { Text = text ?? throw new JsonException("Text contents must be provided for 'text' type."), - Annotations = annotations, - Meta = meta, }, "image" => new ImageContentBlock { Data = data ?? throw new JsonException("Image data must be provided for 'image' type."), MimeType = mimeType ?? throw new JsonException("MIME type must be provided for 'image' type."), - Annotations = annotations, - Meta = meta, }, "audio" => new AudioContentBlock { Data = data ?? throw new JsonException("Audio data must be provided for 'audio' type."), MimeType = mimeType ?? throw new JsonException("MIME type must be provided for 'audio' type."), - Annotations = annotations, - Meta = meta, }, "resource" => new EmbeddedResourceBlock { Resource = resource ?? throw new JsonException("Resource contents must be provided for 'resource' type."), - Annotations = annotations, - Meta = meta, }, "resource_link" => new ResourceLinkBlock @@ -185,11 +228,30 @@ public class Converter : JsonConverter Description = description, MimeType = mimeType, Size = size, - Annotations = annotations, + }, + + "tool_use" => new ToolUseContentBlock + { + Id = id ?? throw new JsonException("ID must be provided for 'tool_use' type."), + Name = name ?? throw new JsonException("Name must be provided for 'tool_use' type."), + Input = input ?? throw new JsonException("Input must be provided for 'tool_use' type."), + }, + + "tool_result" => new ToolResultContentBlock + { + ToolUseId = toolUseId ?? throw new JsonException("ToolUseId must be provided for 'tool_result' type."), + Content = content ?? throw new JsonException("Content must be provided for 'tool_result' type."), + StructuredContent = structuredContent, + IsError = isError, }, _ => throw new JsonException($"Unknown content type: '{type}'"), }; + + block.Annotations = annotations; + block.Meta = meta; + + return block; } /// @@ -209,41 +271,21 @@ public override void Write(Utf8JsonWriter writer, ContentBlock value, JsonSerial { case TextContentBlock textContent: writer.WriteString("text", textContent.Text); - if (textContent.Meta is not null) - { - writer.WritePropertyName("_meta"); - JsonSerializer.Serialize(writer, textContent.Meta, McpJsonUtilities.JsonContext.Default.JsonObject); - } break; case ImageContentBlock imageContent: writer.WriteString("data", imageContent.Data); writer.WriteString("mimeType", imageContent.MimeType); - if (imageContent.Meta is not null) - { - writer.WritePropertyName("_meta"); - JsonSerializer.Serialize(writer, imageContent.Meta, McpJsonUtilities.JsonContext.Default.JsonObject); - } break; case AudioContentBlock audioContent: writer.WriteString("data", audioContent.Data); writer.WriteString("mimeType", audioContent.MimeType); - if (audioContent.Meta is not null) - { - writer.WritePropertyName("_meta"); - JsonSerializer.Serialize(writer, audioContent.Meta, McpJsonUtilities.JsonContext.Default.JsonObject); - } break; case EmbeddedResourceBlock embeddedResource: writer.WritePropertyName("resource"); JsonSerializer.Serialize(writer, embeddedResource.Resource, McpJsonUtilities.JsonContext.Default.ResourceContents); - if (embeddedResource.Meta is not null) - { - writer.WritePropertyName("_meta"); - JsonSerializer.Serialize(writer, embeddedResource.Meta, McpJsonUtilities.JsonContext.Default.JsonObject); - } break; case ResourceLinkBlock resourceLink: @@ -262,6 +304,33 @@ public override void Write(Utf8JsonWriter writer, ContentBlock value, JsonSerial writer.WriteNumber("size", resourceLink.Size.Value); } break; + + case ToolUseContentBlock toolUse: + writer.WriteString("id", toolUse.Id); + writer.WriteString("name", toolUse.Name); + writer.WritePropertyName("input"); + JsonSerializer.Serialize(writer, toolUse.Input, McpJsonUtilities.JsonContext.Default.JsonElement); + break; + + case ToolResultContentBlock toolResult: + writer.WriteString("toolUseId", toolResult.ToolUseId); + writer.WritePropertyName("content"); + writer.WriteStartArray(); + foreach (var item in toolResult.Content) + { + Write(writer, item, options); + } + writer.WriteEndArray(); + if (toolResult.StructuredContent.HasValue) + { + writer.WritePropertyName("structuredContent"); + JsonSerializer.Serialize(writer, toolResult.StructuredContent.Value, McpJsonUtilities.JsonContext.Default.JsonElement); + } + if (toolResult.IsError.HasValue) + { + writer.WriteBoolean("isError", toolResult.IsError.Value); + } + break; } if (value.Annotations is { } annotations) @@ -270,6 +339,12 @@ public override void Write(Utf8JsonWriter writer, ContentBlock value, JsonSerial JsonSerializer.Serialize(writer, annotations, McpJsonUtilities.JsonContext.Default.Annotations); } + if (value.Meta is not null) + { + writer.WritePropertyName("_meta"); + JsonSerializer.Serialize(writer, value.Meta, McpJsonUtilities.JsonContext.Default.JsonObject); + } + writer.WriteEndObject(); } } @@ -286,15 +361,6 @@ public sealed class TextContentBlock : ContentBlock /// [JsonPropertyName("text")] public required string Text { get; set; } - - /// - /// Gets or sets metadata reserved by MCP for protocol-level metadata. - /// - /// - /// Implementations must not make assumptions about its contents. - /// - [JsonPropertyName("_meta")] - public JsonObject? Meta { get; set; } } /// Represents an image provided to or from an LLM. @@ -319,15 +385,6 @@ public sealed class ImageContentBlock : ContentBlock /// [JsonPropertyName("mimeType")] public required string MimeType { get; set; } - - /// - /// Gets or sets metadata reserved by MCP for protocol-level metadata. - /// - /// - /// Implementations must not make assumptions about its contents. - /// - [JsonPropertyName("_meta")] - public JsonObject? Meta { get; set; } } /// Represents audio provided to or from an LLM. @@ -352,15 +409,6 @@ public sealed class AudioContentBlock : ContentBlock /// [JsonPropertyName("mimeType")] public required string MimeType { get; set; } - - /// - /// Gets or sets metadata reserved by MCP for protocol-level metadata. - /// - /// - /// Implementations must not make assumptions about its contents. - /// - [JsonPropertyName("_meta")] - public JsonObject? Meta { get; set; } } /// Represents the contents of a resource, embedded into a prompt or tool call result. @@ -384,15 +432,6 @@ public sealed class EmbeddedResourceBlock : ContentBlock /// [JsonPropertyName("resource")] public required ResourceContents Resource { get; set; } - - /// - /// Gets or sets metadata reserved by MCP for protocol-level metadata. - /// - /// - /// Implementations must not make assumptions about its contents. - /// - [JsonPropertyName("_meta")] - public JsonObject? Meta { get; set; } } /// Represents a resource that the server is capable of reading, included in a prompt or tool call result. @@ -461,3 +500,76 @@ public sealed class ResourceLinkBlock : ContentBlock [JsonPropertyName("size")] public long? Size { get; set; } } + +/// Represents a request from the assistant to call a tool. +public sealed class ToolUseContentBlock : ContentBlock +{ + /// + public override string Type => "tool_use"; + + /// + /// Gets or sets a unique identifier for this tool use. + /// + /// + /// This ID is used to match tool results to their corresponding tool uses. + /// + [JsonPropertyName("id")] + public required string Id { get; set; } + + /// + /// Gets or sets the name of the tool to call. + /// + [JsonPropertyName("name")] + public required string Name { get; set; } + + /// + /// Gets or sets the arguments to pass to the tool, conforming to the tool's input schema. + /// + [JsonPropertyName("input")] + public required JsonElement Input { get; set; } +} + +/// Represents the result of a tool use, provided by the user back to the assistant. +public sealed class ToolResultContentBlock : ContentBlock +{ + /// + public override string Type => "tool_result"; + + /// + /// Gets or sets the ID of the tool use this result corresponds to. + /// + /// + /// This must match the ID from a previous . + /// + [JsonPropertyName("toolUseId")] + public required string ToolUseId { get; set; } + + /// + /// Gets or sets the unstructured result content of the tool use. + /// + /// + /// This has the same format as CallToolResult.Content and can include text, images, + /// audio, resource links, and embedded resources. + /// + [JsonPropertyName("content")] + public required List Content { get; set; } + + /// + /// Gets or sets an optional structured result object. + /// + /// + /// If the tool defined an outputSchema, this should conform to that schema. + /// + [JsonPropertyName("structuredContent")] + public JsonElement? StructuredContent { get; set; } + + /// + /// Gets or sets whether the tool use resulted in an error. + /// + /// + /// If true, the content typically describes the error that occurred. + /// Default: false + /// + [JsonPropertyName("isError")] + public bool? IsError { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Protocol/ContextInclusion.cs b/src/ModelContextProtocol.Core/Protocol/ContextInclusion.cs index 8894a3b3a..ec1bc2977 100644 --- a/src/ModelContextProtocol.Core/Protocol/ContextInclusion.cs +++ b/src/ModelContextProtocol.Core/Protocol/ContextInclusion.cs @@ -6,7 +6,14 @@ namespace ModelContextProtocol.Protocol; /// Specifies the context inclusion options for a request in the Model Context Protocol (MCP). /// /// +/// /// See the schema for details. +/// +/// +/// , and in particular and , are deprecated. +/// Servers should only use these values if the client declares with +/// set. These values may be removed in future spec releases. +/// /// [JsonConverter(typeof(JsonStringEnumConverter))] public enum ContextInclusion @@ -20,12 +27,20 @@ public enum ContextInclusion /// /// Indicates that context from the server that sent the request should be included. /// + /// + /// This value is soft-deprecated. Servers should only use this value if the client + /// declares ClientCapabilities.Sampling.Context. + /// [JsonStringEnumMemberName("thisServer")] ThisServer, /// /// Indicates that context from all servers that the client is connected to should be included. /// + /// + /// This value is soft-deprecated. Servers should only use this value if the client + /// declares ClientCapabilities.Sampling.Context. + /// [JsonStringEnumMemberName("allServers")] AllServers } diff --git a/src/ModelContextProtocol.Core/Protocol/CreateMessageRequestParams.cs b/src/ModelContextProtocol.Core/Protocol/CreateMessageRequestParams.cs index d3086c0be..c910053fb 100644 --- a/src/ModelContextProtocol.Core/Protocol/CreateMessageRequestParams.cs +++ b/src/ModelContextProtocol.Core/Protocol/CreateMessageRequestParams.cs @@ -16,7 +16,15 @@ public sealed class CreateMessageRequestParams : RequestParams /// Gets or sets an indication as to which server contexts should be included in the prompt. /// /// + /// /// The client may ignore this request. + /// + /// + /// , and in particular and + /// , are deprecated. Servers should only use these values if the client + /// declares with set. + /// These values may be removed in future spec releases. + /// /// [JsonPropertyName("includeContext")] public ContextInclusion? IncludeContext { get; set; } @@ -100,4 +108,16 @@ public sealed class CreateMessageRequestParams : RequestParams /// [JsonPropertyName("temperature")] public float? Temperature { get; set; } + + /// + /// Gets or sets tools that the model may use during generation. + /// + [JsonPropertyName("tools")] + public IList? Tools { get; set; } + + /// + /// Gets or sets controls for how the model uses tools. + /// + [JsonPropertyName("toolChoice")] + public ToolChoice? ToolChoice { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/CreateMessageResult.cs b/src/ModelContextProtocol.Core/Protocol/CreateMessageResult.cs index 7fada6399..d6891c747 100644 --- a/src/ModelContextProtocol.Core/Protocol/CreateMessageResult.cs +++ b/src/ModelContextProtocol.Core/Protocol/CreateMessageResult.cs @@ -11,10 +11,14 @@ namespace ModelContextProtocol.Protocol; public sealed class CreateMessageResult : Result { /// - /// Gets or sets the content of the message. + /// Gets or sets the content of the assistant's response. /// + /// + /// In the corresponding JSON, this may be a single content block or an array of content blocks. + /// [JsonPropertyName("content")] - public required ContentBlock Content { get; set; } + [JsonConverter(typeof(SingleItemOrListConverter))] + public required IList Content { get; set; } /// /// Gets or sets the name of the model that generated the message. @@ -35,12 +39,14 @@ public sealed class CreateMessageResult : Result /// Gets or sets the reason why message generation (sampling) stopped, if known. /// /// - /// Common values include: + /// Standard values include: /// /// endTurnThe model naturally completed its response. /// maxTokensThe response was truncated due to reaching token limits. /// stopSequenceA specific stop sequence was encountered during generation. + /// toolUseThe model wants to use one or more tools. /// + /// This field is an open string to allow for provider-specific stop reasons. /// [JsonPropertyName("stopReason")] public string? StopReason { get; set; } @@ -49,5 +55,17 @@ public sealed class CreateMessageResult : Result /// Gets or sets the role of the user who generated the message. /// [JsonPropertyName("role")] - public required Role Role { get; set; } + public Role Role { get; set; } = Role.Assistant; + + /// The stop reason "endTurn". + internal const string StopReasonEndTurn = "endTurn"; + + /// The stop reason "maxTokens". + internal const string StopReasonMaxTokens = "maxTokens"; + + /// The stop reason "stopSequence". + internal const string StopReasonStopSequence = "stopSequence"; + + /// The stop reason "toolUse". + internal const string StopReasonToolUse = "toolUse"; } diff --git a/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs b/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs index 8ddc7ecf8..aba24a77c 100644 --- a/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs @@ -16,11 +16,6 @@ namespace ModelContextProtocol.Protocol; /// When this capability is enabled, an MCP server can request the client to generate content /// using an AI model. The client must set a to process these requests. /// -/// -/// This class is intentionally empty as the Model Context Protocol specification does not -/// currently define additional properties for sampling capabilities. Future versions of the -/// specification may extend this capability with additional configuration options. -/// /// public sealed class SamplingCapability { @@ -38,7 +33,7 @@ public sealed class SamplingCapability /// generated content. /// /// - /// You can create a handler using the extension + /// You can create a handler using the extension /// method with any implementation of . /// /// @@ -46,4 +41,20 @@ public sealed class SamplingCapability [Obsolete($"Use {nameof(McpClientOptions.Handlers.SamplingHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 [EditorBrowsable(EditorBrowsableState.Never)] public Func, CancellationToken, ValueTask>? SamplingHandler { get; set; } -} \ No newline at end of file + + /// + /// Gets or sets whether the client supports context inclusion via includeContext parameter. + /// + /// + /// If not declared, servers should only use includeContext: "none". + /// + [JsonPropertyName("context")] + public SamplingContextCapability? Context { get; set; } + + /// + /// Gets or sets whether the client supports tool use via tools and toolChoice parameters. + /// + [JsonPropertyName("tools")] + public SamplingToolsCapability? Tools { get; set; } +} + diff --git a/src/ModelContextProtocol.Core/Protocol/SamplingContextCapability.cs b/src/ModelContextProtocol.Core/Protocol/SamplingContextCapability.cs new file mode 100644 index 000000000..bae960f3a --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/SamplingContextCapability.cs @@ -0,0 +1,6 @@ +namespace ModelContextProtocol.Protocol; + +/// +/// Represents the sampling context capability. +/// +public sealed class SamplingContextCapability; \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/SamplingMessage.cs b/src/ModelContextProtocol.Core/Protocol/SamplingMessage.cs index 60db179cc..093824e47 100644 --- a/src/ModelContextProtocol.Core/Protocol/SamplingMessage.cs +++ b/src/ModelContextProtocol.Core/Protocol/SamplingMessage.cs @@ -1,3 +1,4 @@ +using System.Text.Json.Nodes; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol; @@ -8,7 +9,8 @@ namespace ModelContextProtocol.Protocol; /// /// /// A encapsulates content sent to or received from AI models in the Model Context Protocol. -/// Each message has a specific role ( or ) and contains content which can be text or images. +/// The message has a role ( or ) and content which can be text, images, +/// audio, tool uses, or tool results. /// /// /// objects are typically used in collections within @@ -16,8 +18,9 @@ namespace ModelContextProtocol.Protocol; /// within the Model Context Protocol. /// /// -/// While similar to , the is focused on direct LLM sampling -/// operations rather than the enhanced resource embedding capabilities provided by . +/// If content contains any , then all content items +/// must be . Tool results cannot be mixed with text, image, or +/// audio content in the same message. /// /// /// See the schema for details. @@ -29,11 +32,21 @@ public sealed class SamplingMessage /// Gets or sets the content of the message. /// [JsonPropertyName("content")] - public required ContentBlock Content { get; set; } + [JsonConverter(typeof(SingleItemOrListConverter))] + public required IList Content { get; set; } /// - /// Gets or sets the role of the message sender, indicating whether it's from a "user" or an "assistant". + /// Gets or sets the role of the message sender. /// [JsonPropertyName("role")] - public required Role Role { get; set; } + public Role Role { get; set; } = Role.User; + + /// + /// Gets or sets metadata reserved by MCP for protocol-level metadata. + /// + /// + /// Implementations must not make assumptions about its contents. + /// + [JsonPropertyName("_meta")] + public JsonObject? Meta { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/SamplingToolsCapability.cs b/src/ModelContextProtocol.Core/Protocol/SamplingToolsCapability.cs new file mode 100644 index 000000000..f93b79725 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/SamplingToolsCapability.cs @@ -0,0 +1,6 @@ +namespace ModelContextProtocol.Protocol; + +/// +/// Represents the sampling tools capability. +/// +public sealed class SamplingToolsCapability; \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/SingleItemOrListConverter.cs b/src/ModelContextProtocol.Core/Protocol/SingleItemOrListConverter.cs new file mode 100644 index 000000000..497062e7e --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/SingleItemOrListConverter.cs @@ -0,0 +1,67 @@ +using System.ComponentModel; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// JSON converter for that handles both array and single object representations. +/// +[EditorBrowsable(EditorBrowsableState.Never)] +public sealed class SingleItemOrListConverter : JsonConverter> + where T : class +{ + /// + public override IList? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType == JsonTokenType.Null) + { + return null; + } + + if (reader.TokenType == JsonTokenType.StartArray) + { + List list = []; + while (reader.Read() && reader.TokenType != JsonTokenType.EndArray) + { + if (JsonSerializer.Deserialize(ref reader, options.GetTypeInfo(typeof(T))) is T item) + { + list.Add(item); + } + } + + return list; + } + + if (reader.TokenType == JsonTokenType.StartObject) + { + return JsonSerializer.Deserialize(ref reader, options.GetTypeInfo(typeof(T))) is T item ? [item] : []; + } + + throw new JsonException($"Unexpected token type: {reader.TokenType}. Expected StartArray or StartObject."); + } + + /// + public override void Write(Utf8JsonWriter writer, IList value, JsonSerializerOptions options) + { + switch (value) + { + case null: + writer.WriteNullValue(); + return; + + case { Count: 1 }: + JsonSerializer.Serialize(writer, value[0], options.GetTypeInfo(typeof(object))); + return; + + default: + writer.WriteStartArray(); + foreach (var item in value) + { + JsonSerializer.Serialize(writer, item, options.GetTypeInfo(typeof(object))); + } + writer.WriteEndArray(); + return; + } + } +} diff --git a/src/ModelContextProtocol.Core/Protocol/ToolChoice.cs b/src/ModelContextProtocol.Core/Protocol/ToolChoice.cs new file mode 100644 index 000000000..41b387494 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/ToolChoice.cs @@ -0,0 +1,32 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Controls tool selection behavior for sampling requests. +/// +public sealed class ToolChoice +{ + /// + /// Gets or sets the mode controlling which tools the model can call. + /// + /// + /// + /// "auto"Model decides whether to call tools (default) + /// "required"Model must call at least one tool + /// "none"Model must not call any tools + /// + /// + [JsonPropertyName("mode")] + public string? Mode { get; set; } + + /// The mode value "auto". + internal const string ModeAuto = "auto"; + + /// The mode value "required". + internal const string ModeRequired = "required"; + + /// The mode value "none". + internal const string ModeNone = "none"; +} + diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs index e3d1271d0..f714bb184 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs @@ -249,7 +249,7 @@ public override async ValueTask InvokeAsync( { AIContent aiContent => new() { - Content = [aiContent.ToContent()], + Content = [aiContent.ToContentBlock()], StructuredContent = structuredContent, IsError = aiContent is ErrorContent }, @@ -491,7 +491,7 @@ private static CallToolResult ConvertAIContentEnumerableToCallToolResult(IEnumer foreach (var item in contentItems) { - contentList.Add(item.ToContent()); + contentList.Add(item.ToContentBlock()); hasAny = true; if (allErrorContent && item is not ErrorContent) diff --git a/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs b/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs index 609da53c1..3b0eadf30 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs @@ -106,42 +106,26 @@ public async Task SampleAsync( continue; } - if (message.Role == ChatRole.User || message.Role == ChatRole.Assistant) - { - Role role = message.Role == ChatRole.User ? Role.User : Role.Assistant; + Role role = message.Role == ChatRole.Assistant ? Role.Assistant : Role.User; - foreach (var content in message.Contents) + // Group all content blocks from this message into a single SamplingMessage + List contentBlocks = []; + foreach (var content in message.Contents) + { + if (content.ToContentBlock() is { } contentBlock) { - switch (content) - { - case TextContent textContent: - samplingMessages.Add(new() - { - Role = role, - Content = new TextContentBlock { Text = textContent.Text }, - }); - break; - - case DataContent dataContent when dataContent.HasTopLevelMediaType("image") || dataContent.HasTopLevelMediaType("audio"): - samplingMessages.Add(new() - { - Role = role, - Content = dataContent.HasTopLevelMediaType("image") ? - new ImageContentBlock - { - MimeType = dataContent.MediaType, - Data = dataContent.Base64Data.ToString(), - } : - new AudioContentBlock - { - MimeType = dataContent.MediaType, - Data = dataContent.Base64Data.ToString(), - }, - }); - break; - } + contentBlocks.Add(contentBlock); } } + + if (contentBlocks.Count > 0) + { + samplingMessages.Add(new() + { + Role = role, + Content = contentBlocks, + }); + } } ModelPreferences? modelPreferences = null; @@ -150,25 +134,63 @@ public async Task SampleAsync( modelPreferences = new() { Hints = [new() { Name = modelId }] }; } + IList? tools = null; + if (options?.Tools is { Count: > 0 }) + { + foreach (var tool in options.Tools) + { + if (tool is AIFunctionDeclaration af) + { + (tools ??= []).Add(new() + { + Name = af.Name, + Description = af.Description, + InputSchema = af.JsonSchema, + Meta = af.AdditionalProperties.ToJsonObject(), + }); + } + } + } + + ToolChoice? toolChoice = options?.ToolMode switch + { + NoneChatToolMode => new() { Mode = ToolChoice.ModeNone }, + AutoChatToolMode => new() { Mode = ToolChoice.ModeAuto }, + RequiredChatToolMode => new() { Mode = ToolChoice.ModeRequired }, + _ => null, + }; + var result = await SampleAsync(new() { - Messages = samplingMessages, MaxTokens = options?.MaxOutputTokens ?? ServerOptions.MaxSamplingOutputTokens, + Messages = samplingMessages, + ModelPreferences = modelPreferences, StopSequences = options?.StopSequences?.ToArray(), SystemPrompt = systemPrompt?.ToString(), Temperature = options?.Temperature, - ModelPreferences = modelPreferences, + ToolChoice = toolChoice, + Tools = tools, + Meta = options?.AdditionalProperties?.ToJsonObject(), }, cancellationToken).ConfigureAwait(false); - AIContent? responseContent = result.Content.ToAIContent(); + List responseContents = []; + foreach (var block in result.Content) + { + if (block.ToAIContent() is { } content) + { + responseContents.Add(content); + } + } - return new(new ChatMessage(result.Role is Role.User ? ChatRole.User : ChatRole.Assistant, responseContent is not null ? [responseContent] : [])) + return new(new ChatMessage(result.Role is Role.User ? ChatRole.User : ChatRole.Assistant, responseContents)) { ModelId = result.Model, FinishReason = result.StopReason switch { - "maxTokens" => ChatFinishReason.Length, - "endTurn" or "stopSequence" or _ => ChatFinishReason.Stop, + CreateMessageResult.StopReasonMaxTokens => ChatFinishReason.Length, + CreateMessageResult.StopReasonToolUse => ChatFinishReason.ToolCalls, + CreateMessageResult.StopReasonEndTurn or CreateMessageResult.StopReasonStopSequence => ChatFinishReason.Stop, + _ => null, } }; } diff --git a/src/ModelContextProtocol.Core/Server/McpServerTool.cs b/src/ModelContextProtocol.Core/Server/McpServerTool.cs index 6948ea912..987424b0d 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerTool.cs @@ -99,7 +99,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// Converted to a single object using . +/// Converted to a single object using . /// /// /// @@ -111,7 +111,7 @@ namespace ModelContextProtocol.Server; /// /// /// of -/// Each is converted to a object using . +/// Each is converted to a object using . /// /// /// of diff --git a/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs b/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs index 9e71e0eab..cf5455587 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs @@ -90,7 +90,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// Converted to a single object using . +/// Converted to a single object using . /// /// /// @@ -106,7 +106,7 @@ namespace ModelContextProtocol.Server; /// /// /// of -/// Each is converted to a object using . +/// Each is converted to a object using . /// /// /// of diff --git a/tests/Common/Utils/TestServerTransport.cs b/tests/Common/Utils/TestServerTransport.cs index f875fe504..51682ba60 100644 --- a/tests/Common/Utils/TestServerTransport.cs +++ b/tests/Common/Utils/TestServerTransport.cs @@ -1,4 +1,4 @@ -using ModelContextProtocol.Protocol; +using ModelContextProtocol.Protocol; using System.Text.Json; using System.Threading.Channels; @@ -74,7 +74,7 @@ private async Task SamplingAsync(JsonRpcRequest request, CancellationToken cance await WriteMessageAsync(new JsonRpcResponse { Id = request.Id, - Result = JsonSerializer.SerializeToNode(new CreateMessageResult { Content = new TextContentBlock { Text = "" }, Model = "model", Role = Role.User }, McpJsonUtilities.DefaultOptions), + Result = JsonSerializer.SerializeToNode(new CreateMessageResult { Content = [new TextContentBlock { Text = "" }], Model = "model"}, McpJsonUtilities.DefaultOptions), }, cancellationToken); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index 78acaeb5e..acf5d469e 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -1,4 +1,4 @@ -using ModelContextProtocol.Client; +using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Tests.Utils; @@ -259,8 +259,7 @@ public async Task Sampling_Sse_TestServer() return new CreateMessageResult { Model = "test-model", - Role = Role.Assistant, - Content = new TextContentBlock { Text = "Test response" }, + Content = [new TextContentBlock { Text = "Test response" }], }; }; await using var client = await GetClientAsync(options); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 0c71e56e3..2899851ef 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -1,4 +1,4 @@ -using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; @@ -143,7 +143,7 @@ public async Task ClaimsPrincipal_CanBeInjected_IntoToolMethod() } [Fact] - public async Task Sampling_DoesNotCloseStream_Prematurely() + public async Task Sampling_DoesNotCloseStreamPrematurely() { Assert.SkipWhen(Stateless, "Sampling is not supported in stateless mode."); @@ -172,14 +172,14 @@ public async Task Sampling_DoesNotCloseStream_Prematurely() Assert.NotNull(parameters?.Messages); var message = Assert.Single(parameters.Messages); Assert.Equal(Role.User, message.Role); - Assert.Equal("Test prompt for sampling", Assert.IsType(message.Content).Text); + Assert.Equal("Test prompt for sampling", Assert.IsType(Assert.Single(message.Content)).Text); sampleCount++; return new CreateMessageResult { Model = "test-model", Role = Role.Assistant, - Content = new TextContentBlock { Text = "Sampling response from client" }, + Content = [new TextContentBlock { Text = "Sampling response from client" }], }; } } @@ -285,7 +285,7 @@ public static async Task SamplingToolAsync(McpServer server, string prom new SamplingMessage { Role = Role.User, - Content = new TextContentBlock { Text = prompt }, + Content = [new TextContentBlock { Text = prompt }], } ], MaxTokens = 1000 @@ -294,7 +294,7 @@ public static async Task SamplingToolAsync(McpServer server, string prom await server.SampleAsync(samplingRequest, cancellationToken); var samplingResult = await server.SampleAsync(samplingRequest, cancellationToken); - return $"Sampling completed successfully. Client responded: {Assert.IsType(samplingResult.Content).Text}"; + return $"Sampling completed successfully. Client responded: {Assert.IsType(Assert.Single(samplingResult.Content)).Text}"; } } } diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 9a54ed71d..e4b797ee6 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -1,4 +1,4 @@ -using System.Collections.Concurrent; +using System.Collections.Concurrent; using System.Text; using System.Text.Json; using Microsoft.Extensions.Logging; @@ -197,7 +197,7 @@ private static void ConfigureTools(McpServerOptions options, string? cliArg) return new CallToolResult { - Content = [new TextContentBlock { Text = $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}" }] + Content = [new TextContentBlock { Text = $"LLM sampling result: {sampleResult.Content.OfType().FirstOrDefault()?.Text}" }] }; } else if (request.Params?.Name == "echoCliArg") @@ -521,7 +521,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st Messages = [new SamplingMessage { Role = Role.User, - Content = new TextContentBlock { Text = $"Resource {uri} context: {context}" }, + Content = [new TextContentBlock { Text = $"Resource {uri} context: {context}" }], }], SystemPrompt = "You are a helpful test server.", MaxTokens = maxTokens, diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 183a64e7e..53537c2b8 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -1,4 +1,4 @@ -using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using Serilog; @@ -46,7 +46,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st Messages = [new SamplingMessage { Role = Role.User, - Content = new TextContentBlock { Text = $"Resource {uri} context: {context}" }, + Content = [new TextContentBlock { Text = $"Resource {uri} context: {context}" }], }], SystemPrompt = "You are a helpful test server.", MaxTokens = maxTokens, @@ -191,7 +191,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st return new CallToolResult { - Content = [new TextContentBlock { Text = $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}" }] + Content = [new TextContentBlock { Text = $"LLM sampling result: {sampleResult.Content.OfType().FirstOrDefault()?.Text}" }] }; } else @@ -339,7 +339,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st }); messages.Add(new PromptMessage { - Role = Role.Assistant, + Role = Role.User, Content = new TextContentBlock { Text = "I understand. You've provided a complex prompt with temperature and style arguments. How would you like me to proceed?" }, }); messages.Add(new PromptMessage diff --git a/tests/ModelContextProtocol.Tests/AIContentExtensionsTests.cs b/tests/ModelContextProtocol.Tests/AIContentExtensionsTests.cs index ec603c63f..3a57a07c6 100644 --- a/tests/ModelContextProtocol.Tests/AIContentExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/AIContentExtensionsTests.cs @@ -25,4 +25,127 @@ public void CallToolResult_ToChatMessage_ProducesExpectedAIContent() JsonElement result = Assert.IsType(frc.Result); Assert.Contains("This is a test message.", result.ToString()); } + + [Fact] + public void ToAIContent_ConvertsToolUseContentBlock() + { + Dictionary inputDict = new() { ["city"] = "Paris", ["units"] = "metric" }; + ToolUseContentBlock toolUse = new() + { + Id = "call_abc123", + Name = "get_weather", + Input = JsonSerializer.SerializeToElement(inputDict, McpJsonUtilities.DefaultOptions) + }; + + AIContent? aiContent = toolUse.ToAIContent(); + + var functionCall = Assert.IsType(aiContent); + Assert.Equal("call_abc123", functionCall.CallId); + Assert.Equal("get_weather", functionCall.Name); + Assert.NotNull(functionCall.Arguments); + + var cityArg = Assert.IsType(functionCall.Arguments["city"]); + Assert.Equal("Paris", cityArg.GetString()); + var unitsArg = Assert.IsType(functionCall.Arguments["units"]); + Assert.Equal("metric", unitsArg.GetString()); + } + + [Fact] + public void ToAIContent_ConvertsToolResultContentBlock() + { + ToolResultContentBlock toolResult = new() + { + ToolUseId = "call_abc123", + Content = [new TextContentBlock { Text = "Weather: 18°C" }], + IsError = false + }; + + AIContent? aiContent = toolResult.ToAIContent(); + + var functionResult = Assert.IsType(aiContent); + Assert.Equal("call_abc123", functionResult.CallId); + Assert.Null(functionResult.Exception); + Assert.NotNull(functionResult.Result); + } + + [Fact] + public void ToAIContent_ConvertsToolResultContentBlockWithError() + { + ToolResultContentBlock toolResult = new() + { + ToolUseId = "call_abc123", + Content = [new TextContentBlock { Text = "Error: Invalid city" }], + IsError = true + }; + + AIContent? aiContent = toolResult.ToAIContent(); + + var functionResult = Assert.IsType(aiContent); + Assert.Equal("call_abc123", functionResult.CallId); + Assert.NotNull(functionResult.Exception); + } + + [Fact] + public void ToAIContent_ConvertsToolResultWithMultipleContent() + { + ToolResultContentBlock toolResult = new() + { + ToolUseId = "call_123", + Content = + [ + new TextContentBlock { Text = "Text result" }, + new ImageContentBlock { Data = Convert.ToBase64String([1, 2, 3]), MimeType = "image/png" } + ] + }; + + AIContent? aiContent = toolResult.ToAIContent(); + + var functionResult = Assert.IsType(aiContent); + Assert.Equal("call_123", functionResult.CallId); + + var resultList = Assert.IsAssignableFrom>(functionResult.Result); + Assert.Equal(2, resultList.Count); + Assert.IsType(resultList[0]); + Assert.IsType(resultList[1]); + } + + [Fact] + public void ToAIContent_ToolUseToFunctionCallRoundTrip() + { + Dictionary inputDict = new() { ["param1"] = "value1", ["param2"] = 42 }; + ToolUseContentBlock original = new() + { + Id = "call_123", + Name = "test_tool", + Input = JsonSerializer.SerializeToElement(inputDict, McpJsonUtilities.DefaultOptions) + }; + + var functionCall = Assert.IsType(original.ToAIContent()); + + Assert.Equal("call_123", functionCall.CallId); + Assert.Equal("test_tool", functionCall.Name); + Assert.NotNull(functionCall.Arguments); + + var param1 = Assert.IsType(functionCall.Arguments["param1"]); + Assert.Equal("value1", param1.GetString()); + var param2 = Assert.IsType(functionCall.Arguments["param2"]); + Assert.Equal(42, param2.GetInt32()); + } + + [Fact] + public void ToAIContent_ToolResultToFunctionResultRoundTrip() + { + ToolResultContentBlock original = new() + { + ToolUseId = "call_123", + Content = [new TextContentBlock { Text = "Result" }, new TextContentBlock { Text = "More data" }], + IsError = false + }; + + var functionResult = Assert.IsType(original.ToAIContent()); + + Assert.Equal("call_123", functionResult.CallId); + Assert.False(functionResult.Exception != null); + Assert.NotNull(functionResult.Result); + } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs index 0eb84262b..504b52e21 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs @@ -73,10 +73,10 @@ public async Task CreateAsync_WithCapabilitiesOptions(Type transportType) RootsHandler = async (t, r) => new ListRootsResult { Roots = [] }, SamplingHandler = async (c, p, t) => new CreateMessageResult { - Content = new TextContentBlock { Text = "result" }, + Content = [new TextContentBlock { Text = "result" }], Model = "test-model", Role = Role.User, - StopReason = "endTurn" + StopReason = "endTurn", } } }; diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 8fb7d2203..ac7b75f8c 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -220,7 +220,14 @@ public async Task GetPromptAsync_Forwards_To_McpClient_SendRequestAsync() { var mockClient = new Mock { CallBase = true }; - var resultPayload = new GetPromptResult { Messages = [new PromptMessage { Role = Role.User, Content = new TextContentBlock { Text = "hi" } }] }; + var resultPayload = new GetPromptResult + { + Messages = [new() + { + Role = Role.User, + Content = new TextContentBlock { Text = "hi" } + }] + }; mockClient .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs index 09be7385e..604c31b8b 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs @@ -5,6 +5,7 @@ using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using Moq; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Serialization.Metadata; using System.Threading.Channels; @@ -106,10 +107,10 @@ public async Task CreateSamplingHandler_ShouldHandleTextMessages(float? temperat { Messages = [ - new SamplingMessage + new SamplingMessage { Role = Role.User, - Content = new TextContentBlock { Text = "Hello" } + Content = [new TextContentBlock { Text = "Hello" }] } ], Temperature = temperature, @@ -134,14 +135,14 @@ public async Task CreateSamplingHandler_ShouldHandleTextMessages(float? temperat .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) .Returns(expectedResponse); - var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + var handler = mockChatClient.Object.CreateSamplingHandler(); // Act var result = await handler(requestParams, Mock.Of>(), cancellationToken); // Assert Assert.NotNull(result); - Assert.Equal("Hello, World!", (result.Content as TextContentBlock)?.Text); + Assert.Equal("Hello, World!", result.Content.OfType().FirstOrDefault()?.Text); Assert.Equal("test-model", result.Model); Assert.Equal(Role.Assistant, result.Role); Assert.Equal("endTurn", result.StopReason); @@ -156,14 +157,14 @@ public async Task CreateSamplingHandler_ShouldHandleImageMessages() { Messages = [ - new SamplingMessage + new SamplingMessage { Role = Role.User, - Content = new ImageContentBlock + Content = [new ImageContentBlock { MimeType = "image/png", Data = Convert.ToBase64String(new byte[] { 1, 2, 3 }) - } + }], } ], MaxTokens = 100 @@ -188,14 +189,14 @@ public async Task CreateSamplingHandler_ShouldHandleImageMessages() .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) .Returns(expectedResponse); - var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + var handler = mockChatClient.Object.CreateSamplingHandler(); // Act var result = await handler(requestParams, Mock.Of>(), cancellationToken); // Assert Assert.NotNull(result); - Assert.Equal(expectedData, (result.Content as ImageContentBlock)?.Data); + Assert.Equal(expectedData, result.Content.OfType().FirstOrDefault()?.Data); Assert.Equal("test-model", result.Model); Assert.Equal(Role.Assistant, result.Role); Assert.Equal("endTurn", result.StopReason); @@ -222,7 +223,7 @@ public async Task CreateSamplingHandler_ShouldHandleResourceMessages() new SamplingMessage { Role = Role.User, - Content = new EmbeddedResourceBlock { Resource = resource }, + Content = [new EmbeddedResourceBlock { Resource = resource }], } ], MaxTokens = 100 @@ -247,7 +248,7 @@ public async Task CreateSamplingHandler_ShouldHandleResourceMessages() .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) .Returns(expectedResponse); - var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + var handler = mockChatClient.Object.CreateSamplingHandler(); // Act var result = await handler(requestParams, Mock.Of>(), cancellationToken); @@ -542,4 +543,120 @@ public async Task ReturnsNegotiatedProtocolVersion(string? protocolVersion) await using McpClient client = await CreateMcpClientForServer(new() { ProtocolVersion = protocolVersion }); Assert.Equal(protocolVersion ?? "2025-06-18", client.NegotiatedProtocolVersion); } + + [Fact] + public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionInvocation_ClientHandlesSamplingWithIChatClient() + { + int getWeatherToolCallCount = 0; + int askClientToolCallCount = 0; + + Server.ServerOptions.ToolCollection?.Add(McpServerTool.Create( + async (McpServer server, string query, CancellationToken cancellationToken) => + { + askClientToolCallCount++; + + var weatherTool = AIFunctionFactory.Create( + (string location) => + { + getWeatherToolCallCount++; + return $"Weather in {location}: sunny, 22°C"; + }, + "get_weather", "Gets the weather for a location"); + + var response = await server + .AsSamplingChatClient() + .AsBuilder() + .UseFunctionInvocation() + .Build() + .GetResponseAsync(query, new ChatOptions { Tools = [weatherTool] }, cancellationToken); + + return response.Text ?? "No response"; + }, + new() { Name = "ask_client", Description = "Asks the client a question using sampling" })); + + int samplingCallCount = 0; + TestChatClient testChatClient = new((messages, options, ct) => + { + int currentCall = samplingCallCount++; + var lastMessage = messages.LastOrDefault(); + + // First call: Return a tool call request for get_weather + if (currentCall == 0) + { + return Task.FromResult(new([ + new ChatMessage(ChatRole.User, messages.First().Contents), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call_weather_123", "get_weather", new Dictionary { ["location"] = "Paris" })]) + ]) + { + ModelId = "test-model", + FinishReason = ChatFinishReason.ToolCalls + }); + } + // Second call (after tool result): Return final text response + else + { + var toolResult = lastMessage?.Contents.OfType().FirstOrDefault(); + Assert.NotNull(toolResult); + Assert.Equal("call_weather_123", toolResult.CallId); + + string resultText = toolResult.Result?.ToString() ?? string.Empty; + Assert.Contains("Weather in Paris: sunny", resultText); + + return Task.FromResult(new([ + new ChatMessage(ChatRole.User, messages.First().Contents), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call_weather_123", "get_weather", new Dictionary { ["location"] = "Paris" })]), + new ChatMessage(ChatRole.User, [toolResult]), + new ChatMessage(ChatRole.Assistant, [new TextContent($"Based on the weather data: {resultText}")]) + ]) + { + ModelId = "test-model", + FinishReason = ChatFinishReason.Stop + }); + } + }); + + await using McpClient client = await CreateMcpClientForServer(new() + { + Handlers = new() { SamplingHandler = testChatClient.CreateSamplingHandler() }, + }); + + var result = await client.CallToolAsync( + "ask_client", + new Dictionary { ["query"] = "What's the weather in Paris?" }, + cancellationToken: TestContext.Current.CancellationToken); + Assert.NotNull(result); + Assert.Null(result.IsError); + + var textContent = result.Content.OfType().FirstOrDefault(); + Assert.NotNull(textContent); + Assert.Contains("Weather in Paris: sunny, 22", textContent.Text); + Assert.Equal(1, getWeatherToolCallCount); + Assert.Equal(1, askClientToolCallCount); + Assert.Equal(2, samplingCallCount); + } + + /// Simple test IChatClient implementation for testing. + private sealed class TestChatClient(Func, ChatOptions?, CancellationToken, Task> getResponse) : IChatClient + { + public Task GetResponseAsync( + IEnumerable messages, + ChatOptions? options = null, + CancellationToken cancellationToken = default) => + getResponse(messages, options, cancellationToken); + + async IAsyncEnumerable IChatClient.GetStreamingResponseAsync( + IEnumerable messages, + ChatOptions? options, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + var response = await GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + foreach (var update in response.ToChatResponseUpdates()) + { + yield return update; + } + } + + object? IChatClient.GetService(Type serviceType, object? serviceKey) => null; + void IDisposable.Dispose() { } + } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientToolTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientToolTests.cs index 45d8a467f..7f1cc3689 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientToolTests.cs @@ -352,13 +352,10 @@ public async Task ResourceLinkTool_ReturnsJsonElement() Assert.IsType(result); var jsonElement = (JsonElement)result!; - Assert.True(jsonElement.TryGetProperty("content", out var contentArray)); - Assert.Equal(JsonValueKind.Array, contentArray.ValueKind); - Assert.Equal(1, contentArray.GetArrayLength()); + Assert.True(jsonElement.TryGetProperty("content", out var contentValue)); + Assert.Equal(JsonValueKind.Array, contentValue.ValueKind); - var firstContent = contentArray[0]; - Assert.True(firstContent.TryGetProperty("type", out var typeProperty)); - Assert.Equal("resource_link", typeProperty.GetString()); + Assert.Equal(1, contentValue.GetArrayLength()); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 16fad124a..ff6f56e24 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Tests.Utils; @@ -383,7 +383,7 @@ public async Task Sampling_Stdio(string clientId) { Model = "test-model", Role = Role.Assistant, - Content = new TextContentBlock { Text = "Test response" }, + Content = [new TextContentBlock { Text = "Test response" }], }; } } diff --git a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs index 2d5ef5f2d..31a8236f2 100644 --- a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs +++ b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs @@ -1,4 +1,4 @@ -using ModelContextProtocol.Client; +using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Tests.Utils; @@ -81,7 +81,7 @@ public async Task Sampling_Sse_EverythingServer() { Model = "test-model", Role = Role.Assistant, - Content = new TextContentBlock { Text = "Test response" }, + Content = [new TextContentBlock { Text = "Test response" }], }; } } diff --git a/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs index 3d8d8ff18..0113b77f3 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs @@ -1,4 +1,3 @@ -using Microsoft.Extensions.AI; using ModelContextProtocol.Protocol; using System.Text.Json; @@ -125,4 +124,105 @@ public void Deserialize_IgnoresUnknownObjectProperties() var textBlock = Assert.IsType(contentBlock); Assert.Contains("Sample text", textBlock.Text); } + + [Fact] + public void ToolResultContentBlock_WithError_SerializationRoundtrips() + { + ToolResultContentBlock toolResult = new() + { + ToolUseId = "call_123", + Content = [new TextContentBlock { Text = "Error: City not found" }], + IsError = true + }; + + var json = JsonSerializer.Serialize(toolResult, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + var result = Assert.IsType(deserialized); + Assert.Equal("call_123", result.ToolUseId); + Assert.True(result.IsError); + Assert.Single(result.Content); + var textBlock = Assert.IsType(result.Content[0]); + Assert.Equal("Error: City not found", textBlock.Text); + } + + [Fact] + public void ToolResultContentBlock_WithStructuredContent_SerializationRoundtrips() + { + ToolResultContentBlock toolResult = new() + { + ToolUseId = "call_123", + Content = + [ + new TextContentBlock { Text = "Result data" } + ], + StructuredContent = JsonElement.Parse("""{"temperature":18,"condition":"cloudy"}"""), + IsError = false + }; + + var json = JsonSerializer.Serialize(toolResult, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + var result = Assert.IsType(deserialized); + Assert.Equal("call_123", result.ToolUseId); + Assert.Single(result.Content); + var textBlock = Assert.IsType(result.Content[0]); + Assert.Equal("Result data", textBlock.Text); + Assert.NotNull(result.StructuredContent); + Assert.Equal(18, result.StructuredContent.Value.GetProperty("temperature").GetInt32()); + Assert.Equal("cloudy", result.StructuredContent.Value.GetProperty("condition").GetString()); + Assert.False(result.IsError); + } + + [Fact] + public void ToolResultContentBlock_SerializationRoundTrip() + { + ToolResultContentBlock toolResult = new() + { + ToolUseId = "call_123", + Content = + [ + new TextContentBlock { Text = "Result data" }, + new ImageContentBlock { Data = "base64data", MimeType = "image/png" } + ], + StructuredContent = JsonElement.Parse("""{"temperature":18,"condition":"cloudy"}"""), + IsError = false + }; + + var json = JsonSerializer.Serialize(toolResult, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + var result = Assert.IsType(deserialized); + Assert.Equal("call_123", result.ToolUseId); + Assert.Equal(2, result.Content.Count); + var textBlock = Assert.IsType(result.Content[0]); + Assert.Equal("Result data", textBlock.Text); + var imageBlock = Assert.IsType(result.Content[1]); + Assert.Equal("base64data", imageBlock.Data); + Assert.Equal("image/png", imageBlock.MimeType); + Assert.NotNull(result.StructuredContent); + Assert.Equal(18, result.StructuredContent.Value.GetProperty("temperature").GetInt32()); + Assert.Equal("cloudy", result.StructuredContent.Value.GetProperty("condition").GetString()); + Assert.False(result.IsError); + } + + [Fact] + public void ToolUseContentBlock_SerializationRoundTrip() + { + ToolUseContentBlock toolUse = new() + { + Id = "call_abc123", + Name = "get_weather", + Input = JsonElement.Parse("""{"city":"Paris","units":"metric"}""") + }; + + var json = JsonSerializer.Serialize(toolUse, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + var result = Assert.IsType(deserialized); + Assert.Equal("call_abc123", result.Id); + Assert.Equal("get_weather", result.Name); + Assert.Equal("Paris", result.Input.GetProperty("city").GetString()); + Assert.Equal("metric", result.Input.GetProperty("units").GetString()); + } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Protocol/CreateMessageRequestParamsTests.cs b/tests/ModelContextProtocol.Tests/Protocol/CreateMessageRequestParamsTests.cs new file mode 100644 index 000000000..f57faf1d8 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/CreateMessageRequestParamsTests.cs @@ -0,0 +1,174 @@ +using ModelContextProtocol.Protocol; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Protocol; + +public class CreateMessageRequestParamsTests +{ + [Fact] + public void WithTools_SerializationRoundtrips() + { + CreateMessageRequestParams requestParams = new() + { + MaxTokens = 1000, + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "What's the weather in Paris?" }] + } + ], + Tools = + [ + new Tool + { + Name = "get_weather", + Description = "Get weather for a city", + InputSchema = JsonElement.Parse(""" + { + "type": "object", + "properties": { + "city": { "type": "string" } + }, + "required": ["city"] + } + """) + } + ], + ToolChoice = new ToolChoice { Mode = "auto" } + }; + + var json = JsonSerializer.Serialize(requestParams, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(1000, deserialized.MaxTokens); + Assert.NotNull(deserialized.Messages); + Assert.Single(deserialized.Messages); + Assert.Equal(Role.User, deserialized.Messages[0].Role); + Assert.Single(deserialized.Messages[0].Content); + var textContent = Assert.IsType(deserialized.Messages[0].Content[0]); + Assert.Equal("What's the weather in Paris?", textContent.Text); + Assert.NotNull(deserialized.Tools); + Assert.Single(deserialized.Tools); + Assert.Equal("get_weather", deserialized.Tools[0].Name); + Assert.Equal("Get weather for a city", deserialized.Tools[0].Description); + Assert.Equal("object", deserialized.Tools[0].InputSchema.GetProperty("type").GetString()); + Assert.True(deserialized.Tools[0].InputSchema.GetProperty("properties").TryGetProperty("city", out var cityProp)); + Assert.Equal("string", cityProp.GetProperty("type").GetString()); + Assert.Single(deserialized.Tools[0].InputSchema.GetProperty("required").EnumerateArray()); + Assert.Equal("city", deserialized.Tools[0].InputSchema.GetProperty("required")[0].GetString()); + Assert.NotNull(deserialized.ToolChoice); + Assert.Equal("auto", deserialized.ToolChoice.Mode); + } + + [Fact] + public void WithToolChoiceRequired_SerializationRoundtrips() + { + CreateMessageRequestParams requestParams = new() + { + MaxTokens = 1000, + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "What's the weather?" }] + } + ], + Tools = + [ + new Tool + { + Name = "get_weather", + Description = "Get weather for a city", + InputSchema = JsonElement.Parse(""" + { + "type": "object", + "properties": { "city": { "type": "string" } }, + "required": ["city"] + } + """) + } + ], + ToolChoice = new ToolChoice { Mode = "required" } + }; + + var json = JsonSerializer.Serialize(requestParams, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(1000, deserialized.MaxTokens); + Assert.NotNull(deserialized.Messages); + Assert.Single(deserialized.Messages); + Assert.Equal(Role.User, deserialized.Messages[0].Role); + Assert.Single(deserialized.Messages[0].Content); + var textContent = Assert.IsType(deserialized.Messages[0].Content[0]); + Assert.Equal("What's the weather?", textContent.Text); + Assert.NotNull(deserialized.Tools); + Assert.Single(deserialized.Tools); + Assert.Equal("get_weather", deserialized.Tools[0].Name); + Assert.Equal("Get weather for a city", deserialized.Tools[0].Description); + Assert.Equal("object", deserialized.Tools[0].InputSchema.GetProperty("type").GetString()); + Assert.NotNull(deserialized.ToolChoice); + Assert.Equal("required", deserialized.ToolChoice.Mode); + } + + [Fact] + public void WithToolChoiceNone_SerializationRoundtrips() + { + CreateMessageRequestParams requestParams = new() + { + MaxTokens = 1000, + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "What's the weather in Paris?" }] + } + ], + Tools = + [ + new Tool + { + Name = "get_weather", + Description = "Get weather for a city", + InputSchema = JsonElement.Parse(""" + { + "type": "object", + "properties": { "city": { "type": "string" } }, + "required": ["city"] + } + """) + } + ], + ToolChoice = new ToolChoice { Mode = "none" } + }; + + var json = JsonSerializer.Serialize(requestParams, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(1000, deserialized.MaxTokens); + Assert.NotNull(deserialized.Messages); + Assert.Single(deserialized.Messages); + Assert.Equal(Role.User, deserialized.Messages[0].Role); + Assert.Single(deserialized.Messages[0].Content); + var textContent = Assert.IsType(deserialized.Messages[0].Content[0]); + Assert.Equal("What's the weather in Paris?", textContent.Text); + Assert.NotNull(deserialized.Tools); + Assert.Single(deserialized.Tools); + Assert.Equal("get_weather", deserialized.Tools[0].Name); + Assert.Equal("Get weather for a city", deserialized.Tools[0].Description); + Assert.Equal("object", deserialized.Tools[0].InputSchema.GetProperty("type").GetString()); + Assert.NotNull(deserialized.ToolChoice); + Assert.Equal("none", deserialized.ToolChoice.Mode); + } +} + + + + + diff --git a/tests/ModelContextProtocol.Tests/Protocol/CreateMessageResultTests.cs b/tests/ModelContextProtocol.Tests/Protocol/CreateMessageResultTests.cs new file mode 100644 index 000000000..67ab5f4f9 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/CreateMessageResultTests.cs @@ -0,0 +1,247 @@ +using ModelContextProtocol.Protocol; +using System.Text.Json; +using System.Text.Json.Nodes; + +namespace ModelContextProtocol.Tests.Protocol; + +public class CreateMessageResultTests +{ + [Fact] + public void CreateMessageResult_WithSingleContent_Serializes() + { + CreateMessageResult result = new() + { + Role = Role.Assistant, + Model = "test-model", + Content = [new TextContentBlock { Text = "Hello" }], + StopReason = "endTurn" + }; + + var json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Single(deserialized.Content); + Assert.IsType(deserialized.Content[0]); + } + + [Fact] + public void CreateMessageResult_WithMultipleToolUses_Serializes() + { + CreateMessageResult result = new() + { + Role = Role.Assistant, + Model = "test-model", + Content = + [ + new ToolUseContentBlock + { + Id = "call_1", + Name = "tool1", + Input = JsonElement.Parse("""{}""") + }, + new ToolUseContentBlock + { + Id = "call_2", + Name = "tool2", + Input = JsonElement.Parse("""{}""") + } + ], + StopReason = "toolUse" + }; + + var json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(2, deserialized.Content.Count); + Assert.All(deserialized.Content, c => Assert.IsType(c)); + Assert.Equal("call_1", ((ToolUseContentBlock)deserialized.Content[0]).Id); + Assert.Equal("call_2", ((ToolUseContentBlock)deserialized.Content[1]).Id); + } + + [Fact] + public void CreateMessageResult_WithMixedContent_Serializes() + { + CreateMessageResult result = new() + { + Role = Role.Assistant, + Model = "test-model", + Content = + [ + new TextContentBlock { Text = "Let me check that." }, + new ToolUseContentBlock + { + Id = "call_1", + Name = "tool1", + Input = JsonElement.Parse("""{}""") + } + ], + StopReason = "toolUse" + }; + + var json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(2, deserialized.Content.Count); + Assert.IsType(deserialized.Content[0]); + Assert.IsType(deserialized.Content[1]); + } + + [Fact] + public void CreateMessageResult_EmptyContent_AllowedButUnusual() + { + CreateMessageResult result = new() + { + Role = Role.Assistant, + Model = "test-model", + Content = [], + StopReason = "endTurn" + }; + + var json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Empty(deserialized.Content); + } + + [Fact] + public void CreateMessageResult_WithImageContent_Serializes() + { + CreateMessageResult result = new() + { + Role = Role.Assistant, + Model = "test-model", + Content = + [ + new ImageContentBlock + { + Data = Convert.ToBase64String([1, 2, 3, 4, 5]), + MimeType = "image/png" + } + ], + StopReason = "endTurn" + }; + + var json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Single(deserialized.Content); + var imageBlock = Assert.IsType(deserialized.Content[0]); + Assert.Equal("image/png", imageBlock.MimeType); + } + + [Fact] + public void CreateMessageResult_RoundTripWithAllFields() + { + CreateMessageResult original = new() + { + Role = Role.Assistant, + Model = "claude-3-sonnet", + Content = + [ + new TextContentBlock { Text = "I'll help you with that." }, + new ToolUseContentBlock + { + Id = "call_xyz", + Name = "calculator", + Input = JsonElement.Parse("""{"operation":"add","a":5,"b":3}""") + } + ], + StopReason = "toolUse", + Meta = (JsonObject)JsonNode.Parse("""{"custom":"metadata"}""")! + }; + + var json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(Role.Assistant, deserialized.Role); + Assert.Equal("claude-3-sonnet", deserialized.Model); + Assert.Equal(2, deserialized.Content.Count); + Assert.Equal("toolUse", deserialized.StopReason); + Assert.NotNull(deserialized.Meta); + } + + [Fact] + public void CreateMessageResult_WithToolUse_SerializationRoundtrips() + { + CreateMessageResult result = new() + { + Role = Role.Assistant, + Model = "test-model", + Content = + [ + new ToolUseContentBlock + { + Id = "call_123", + Name = "get_weather", + Input = JsonElement.Parse("""{"city":"Paris"}""") + } + ], + StopReason = "toolUse" + }; + + var json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(Role.Assistant, deserialized.Role); + Assert.Equal("test-model", deserialized.Model); + Assert.Equal("toolUse", deserialized.StopReason); + Assert.Single(deserialized.Content); + + var toolUse = Assert.IsType(deserialized.Content[0]); + Assert.Equal("call_123", toolUse.Id); + Assert.Equal("get_weather", toolUse.Name); + Assert.Equal("Paris", toolUse.Input.GetProperty("city").GetString()); + } + + [Fact] + public void CreateMessageResult_WithParallelToolUses_SerializationRoundtrips() + { + CreateMessageResult result = new() + { + Role = Role.Assistant, + Model = "test-model", + Content = + [ + new ToolUseContentBlock + { + Id = "call_abc123", + Name = "get_weather", + Input = JsonElement.Parse("""{"city":"Paris"}""") + }, + new ToolUseContentBlock + { + Id = "call_def456", + Name = "get_weather", + Input = JsonElement.Parse("""{"city":"London"}""") + } + ], + StopReason = "toolUse" + }; + + var json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(Role.Assistant, deserialized.Role); + Assert.Equal("test-model", deserialized.Model); + Assert.Equal("toolUse", deserialized.StopReason); + Assert.Equal(2, deserialized.Content.Count); + + var toolUse1 = Assert.IsType(deserialized.Content[0]); + Assert.Equal("call_abc123", toolUse1.Id); + Assert.Equal("get_weather", toolUse1.Name); + Assert.Equal("Paris", toolUse1.Input.GetProperty("city").GetString()); + + var toolUse2 = Assert.IsType(deserialized.Content[1]); + Assert.Equal("call_def456", toolUse2.Id); + Assert.Equal("get_weather", toolUse2.Name); + Assert.Equal("London", toolUse2.Input.GetProperty("city").GetString()); + } +} diff --git a/tests/ModelContextProtocol.Tests/Protocol/SamplingMessageTests.cs b/tests/ModelContextProtocol.Tests/Protocol/SamplingMessageTests.cs new file mode 100644 index 000000000..9765d1be3 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/SamplingMessageTests.cs @@ -0,0 +1,111 @@ +using ModelContextProtocol.Protocol; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Protocol; + +public class SamplingMessageTests +{ + [Fact] + public void WithToolResults_SerializationRoundtrips() + { + SamplingMessage message = new() + { + Role = Role.User, + Content = + [ + new ToolResultContentBlock + { + ToolUseId = "call_123", + Content = + [ + new TextContentBlock { Text = "Weather in Paris: 18°C, partly cloudy" } + ] + } + ] + }; + + var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(Role.User, deserialized.Role); + Assert.Single(deserialized.Content); + + var toolResult = Assert.IsType(deserialized.Content[0]); + Assert.Equal("call_123", toolResult.ToolUseId); + Assert.Single(toolResult.Content); + + var textBlock = Assert.IsType(toolResult.Content[0]); + Assert.Equal("Weather in Paris: 18°C, partly cloudy", textBlock.Text); + } + + [Fact] + public void WithMultipleToolResults_SerializationRoundtrips() + { + SamplingMessage message = new() + { + Role = Role.User, + Content = + [ + new ToolResultContentBlock + { + ToolUseId = "call_abc123", + Content = [new TextContentBlock { Text = "Weather in Paris: 18°C, partly cloudy" }] + }, + new ToolResultContentBlock + { + ToolUseId = "call_def456", + Content = [new TextContentBlock { Text = "Weather in London: 15°C, rainy" }] + } + ] + }; + + var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(Role.User, deserialized.Role); + Assert.Equal(2, deserialized.Content.Count); + + var toolResult1 = Assert.IsType(deserialized.Content[0]); + Assert.Equal("call_abc123", toolResult1.ToolUseId); + Assert.Single(toolResult1.Content); + var textBlock1 = Assert.IsType(toolResult1.Content[0]); + Assert.Equal("Weather in Paris: 18°C, partly cloudy", textBlock1.Text); + + var toolResult2 = Assert.IsType(deserialized.Content[1]); + Assert.Equal("call_def456", toolResult2.ToolUseId); + Assert.Single(toolResult2.Content); + var textBlock2 = Assert.IsType(toolResult2.Content[0]); + Assert.Equal("Weather in London: 15°C, rainy", textBlock2.Text); + } + + [Fact] + public void WithToolResultOnly_SerializationRoundtrips() + { + SamplingMessage message = new() + { + Role = Role.User, + Content = + [ + new ToolResultContentBlock + { + ToolUseId = "call_123", + Content = [new TextContentBlock { Text = "Result" }] + } + ] + }; + + var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(Role.User, deserialized.Role); + Assert.Single(deserialized.Content); + var toolResult = Assert.IsType(deserialized.Content[0]); + Assert.Equal("call_123", toolResult.ToolUseId); + Assert.Single(toolResult.Content); + var textBlock = Assert.IsType(toolResult.Content[0]); + Assert.Equal("Result", textBlock.Text); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Protocol/ToolChoiceTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ToolChoiceTests.cs new file mode 100644 index 000000000..548579d50 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/ToolChoiceTests.cs @@ -0,0 +1,30 @@ +using ModelContextProtocol.Protocol; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Protocol; + +public class ToolChoiceTests +{ + [Fact] + public void DefaultModeIsNull() + { + Assert.Null(new ToolChoice().Mode); + } + + [Theory] + [InlineData(null)] + [InlineData("none")] + [InlineData("required")] + [InlineData("auto")] + [InlineData("something_custom")] + public void SerializesWithMode(string? mode) + { + ToolChoice toolChoice = new() { Mode = mode }; + + var json = JsonSerializer.Serialize(toolChoice, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(mode, deserialized.Mode); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerExtensionsTests.cs index bf90b218a..ed069c820 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerExtensionsTests.cs @@ -18,7 +18,7 @@ public async Task SampleAsync_Request_Throws_When_Not_McpServer() var ex = await Assert.ThrowsAsync(async () => await server.SampleAsync( new CreateMessageRequestParams { - Messages = [new SamplingMessage { Role = Role.User, Content = new TextContentBlock { Text = "hi" } }], + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "hi" }] }], MaxTokens = 1000 }, TestContext.Current.CancellationToken)); @@ -80,7 +80,7 @@ public async Task SampleAsync_Request_Forwards_To_McpServer_SendRequestAsync() var resultPayload = new CreateMessageResult { - Content = new TextContentBlock { Text = "resp" }, + Content = [new TextContentBlock { Text = "resp" }], Model = "test-model", Role = Role.Assistant, StopReason = "endTurn", @@ -102,13 +102,13 @@ public async Task SampleAsync_Request_Forwards_To_McpServer_SendRequestAsync() var result = await server.SampleAsync(new CreateMessageRequestParams { - Messages = [new SamplingMessage { Role = Role.User, Content = new TextContentBlock { Text = "hi" } }], + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "hi" }] }], MaxTokens = 1000 }, TestContext.Current.CancellationToken); Assert.Equal("test-model", result.Model); Assert.Equal(Role.Assistant, result.Role); - Assert.Equal("resp", Assert.IsType(result.Content).Text); + Assert.Equal("resp", Assert.IsType(result.Content[0]).Text); mockServer.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); } @@ -119,7 +119,7 @@ public async Task SampleAsync_Messages_Forwards_To_McpServer_SendRequestAsync() var resultPayload = new CreateMessageResult { - Content = new TextContentBlock { Text = "resp" }, + Content = [new TextContentBlock { Text = "resp" }], Model = "test-model", Role = Role.Assistant, StopReason = "endTurn", diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 810bcef48..ab2537b6b 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; @@ -802,12 +802,12 @@ public override Task SendRequestAsync(JsonRpcRequest request, C Assert.Equal($"You are a helpful assistant.{Environment.NewLine}More system stuff.", rp.SystemPrompt); Assert.Equal(2, rp.Messages.Count); - Assert.Equal("I am going to France.", Assert.IsType(rp.Messages[0].Content).Text); - Assert.Equal("What is the most famous tower in Paris?", Assert.IsType(rp.Messages[1].Content).Text); + Assert.Equal("I am going to France.", Assert.IsType(Assert.Single(rp.Messages[0].Content)).Text); + Assert.Equal("What is the most famous tower in Paris?", Assert.IsType(Assert.Single(rp.Messages[1].Content)).Text); CreateMessageResult result = new() { - Content = new TextContentBlock { Text = "The Eiffel Tower." }, + Content = [new TextContentBlock { Text = "The Eiffel Tower." }], Model = "amazingmodel", Role = Role.Assistant, StopReason = "endTurn",