Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions samples/ChatWithTools/Program.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using ModelContextProtocol;
using ModelContextProtocol.Client;
using OpenAI;
using OpenTelemetry;
Expand Down
4 changes: 2 additions & 2 deletions samples/EverythingServer/Tools/SampleLlmTool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public static async Task<string> SampleLLM(
var samplingParams = CreateRequestSamplingParams(prompt ?? string.Empty, "sampleLLM", maxTokens);
var sampleResult = await server.SampleAsync(samplingParams, cancellationToken);

return $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}";
return $"LLM sampling result: {sampleResult.Content.OfType<TextContentBlock>().FirstOrDefault()?.Text}";
}

private static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100)
Expand All @@ -27,7 +27,7 @@ private static CreateMessageRequestParams CreateRequestSamplingParams(string con
Messages = [new SamplingMessage
{
Role = Role.User,
Content = new TextContentBlock { Text = $"Resource {uri} context: {context}" },
Content = [new TextContentBlock { Text = $"Resource {uri} context: {context}" }],
}],
SystemPrompt = "You are a helpful test server.",
MaxTokens = maxTokens,
Expand Down
4 changes: 2 additions & 2 deletions samples/TestServerWithHosting/Tools/SampleLlmTool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public static async Task<string> SampleLLM(
var samplingParams = CreateRequestSamplingParams(prompt ?? string.Empty, "sampleLLM", maxTokens);
var sampleResult = await thisServer.SampleAsync(samplingParams, cancellationToken);

return $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}";
return $"LLM sampling result: {sampleResult.Content.OfType<TextContentBlock>().FirstOrDefault()?.Text}";
}

private static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100)
Expand All @@ -30,7 +30,7 @@ private static CreateMessageRequestParams CreateRequestSamplingParams(string con
Messages = [new SamplingMessage
{
Role = Role.User,
Content = new TextContentBlock { Text = $"Resource {uri} context: {context}" },
Content = [new TextContentBlock { Text = $"Resource {uri} context: {context}" }],
}],
SystemPrompt = "You are a helpful test server.",
MaxTokens = maxTokens,
Expand Down
212 changes: 208 additions & 4 deletions src/ModelContextProtocol.Core/AIContentExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using Microsoft.Extensions.AI;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
#if !NET
using System.Runtime.InteropServices;
#endif
using System.Text.Json;
using System.Text.Json.Nodes;

namespace ModelContextProtocol;

Expand All @@ -16,6 +18,140 @@ namespace ModelContextProtocol;
/// </remarks>
public static class AIContentExtensions
{
/// <summary>
/// Creates a sampling handler for use with <see cref="McpClientHandlers.SamplingHandler"/> that will
/// satisfy sampling requests using the specified <see cref="IChatClient"/>.
/// </summary>
/// <param name="chatClient">The <see cref="IChatClient"/> with which to satisfy sampling requests.</param>
/// <returns>The created handler delegate that can be assigned to <see cref="McpClientHandlers.SamplingHandler"/>.</returns>
/// <remarks>
/// <para>
/// This method creates a function that converts MCP message requests into chat client calls, enabling
/// an MCP client to generate text or other content using an actual AI model via the provided chat client.
/// </para>
/// <para>
/// The handler can process text messages, image messages, resource messages, and tool use/results as defined in the
/// Model Context Protocol.
/// </para>
/// </remarks>
/// <exception cref="ArgumentNullException"><paramref name="chatClient"/> is <see langword="null"/>.</exception>
public static Func<CreateMessageRequestParams?, IProgress<ProgressNotificationValue>, CancellationToken, ValueTask<CreateMessageResult>> CreateSamplingHandler(
this IChatClient chatClient)
{
Throw.IfNull(chatClient);

return async (requestParams, progress, cancellationToken) =>
{
Throw.IfNull(requestParams);

var (messages, options) = ToChatClientArguments(requestParams);
var progressToken = requestParams.ProgressToken;

List<ChatResponseUpdate> updates = [];
await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
{
updates.Add(update);

if (progressToken is not null)
{
progress.Report(new() { Progress = updates.Count });
}
}

ChatResponse? chatResponse = updates.ToChatResponse();
ChatMessage? lastMessage = chatResponse.Messages.LastOrDefault();

IList<ContentBlock>? contents = lastMessage?.Contents.Select(c => c.ToContentBlock()).ToList();
if (contents is not { Count: > 0 })
{
(contents ??= []).Add(new TextContentBlock() { Text = "" });
}

return new()
{
Model = chatResponse.ModelId ?? "",
StopReason =
chatResponse.FinishReason == ChatFinishReason.Stop ? CreateMessageResult.StopReasonEndTurn :
chatResponse.FinishReason == ChatFinishReason.Length ? CreateMessageResult.StopReasonMaxTokens :
chatResponse.FinishReason == ChatFinishReason.ToolCalls ? CreateMessageResult.StopReasonToolUse :
chatResponse.FinishReason.ToString(),
Meta = chatResponse.AdditionalProperties?.ToJsonObject(),
Role = lastMessage?.Role == ChatRole.User ? Role.User : Role.Assistant,
Content = contents,
};

static (IList<ChatMessage> Messages, ChatOptions? Options) ToChatClientArguments(CreateMessageRequestParams requestParams)
{
ChatOptions? options = null;

if (requestParams.MaxTokens is int maxTokens)
{
(options ??= new()).MaxOutputTokens = maxTokens;
}

if (requestParams.Temperature is float temperature)
{
(options ??= new()).Temperature = temperature;
}

if (requestParams.StopSequences is { } stopSequences)
{
(options ??= new()).StopSequences = stopSequences.ToArray();
}

if (requestParams.SystemPrompt is { } systemPrompt)
{
(options ??= new()).Instructions = systemPrompt;
}

if (requestParams.Tools is { } tools)
{
foreach (var tool in tools)
{
((options ??= new()).Tools ??= []).Add(new ToolAIFunctionDeclaration(tool));
}

if (options.Tools is { Count: > 0 } && requestParams.ToolChoice is { } toolChoice)
{
options.ToolMode = toolChoice.Mode switch
{
ToolChoice.ModeAuto => ChatToolMode.Auto,
ToolChoice.ModeRequired => ChatToolMode.RequireAny,
ToolChoice.ModeNone => ChatToolMode.None,
_ => null,
};
}
}

List<ChatMessage> messages = [];
foreach (var sm in requestParams.Messages)
{
if (sm.Content?.Select(b => b.ToAIContent()).OfType<AIContent>().ToList() is { Count: > 0 } aiContents)
{
messages.Add(new ChatMessage(sm.Role is Role.Assistant ? ChatRole.Assistant : ChatRole.User, aiContents));
}
}

return (messages, options);
}
};
}

/// <summary>Converts the specified dictionary to a <see cref="JsonObject"/>.</summary>
internal static JsonObject? ToJsonObject(this IReadOnlyDictionary<string, object?> properties) =>
JsonSerializer.SerializeToNode(properties, McpJsonUtilities.JsonContext.Default.IReadOnlyDictionaryStringObject) as JsonObject;

internal static AdditionalPropertiesDictionary ToAdditionalProperties(this JsonObject obj)
{
AdditionalPropertiesDictionary d = [];
foreach (var kvp in obj)
{
d.Add(kvp.Key, kvp.Value);
}

return d;
}

/// <summary>
/// Converts a <see cref="PromptMessage"/> to a <see cref="ChatMessage"/> object.
/// </summary>
Expand Down Expand Up @@ -99,7 +235,7 @@ public static IList<PromptMessage> ToPromptMessages(this ChatMessage chatMessage
{
if (content is TextContent or DataContent)
{
messages.Add(new PromptMessage { Role = r, Content = content.ToContent() });
messages.Add(new PromptMessage { Role = r, Content = content.ToContentBlock() });
}
}

Expand All @@ -122,13 +258,31 @@ public static IList<PromptMessage> ToPromptMessages(this ChatMessage chatMessage
AIContent? ac = content switch
{
TextContentBlock textContent => new TextContent(textContent.Text),

ImageContentBlock imageContent => new DataContent(Convert.FromBase64String(imageContent.Data), imageContent.MimeType),

AudioContentBlock audioContent => new DataContent(Convert.FromBase64String(audioContent.Data), audioContent.MimeType),

EmbeddedResourceBlock resourceContent => resourceContent.Resource.ToAIContent(),

ToolUseContentBlock toolUse => FunctionCallContent.CreateFromParsedArguments(toolUse.Input, toolUse.Id, toolUse.Name,
static json => JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.IDictionaryStringObject)),

ToolResultContentBlock toolResult => new FunctionResultContent(
toolResult.ToolUseId,
toolResult.Content.Count == 1 ? toolResult.Content[0].ToAIContent() : toolResult.Content.Select(c => c.ToAIContent()).OfType<AIContent>().ToList())
{
Exception = toolResult.IsError is true ? new() : null,
},

_ => null,
};

ac?.RawRepresentation = content;
if (ac is not null)
{
ac.RawRepresentation = content;
ac.AdditionalProperties = content.Meta?.ToAdditionalProperties();
}

return ac;
}
Expand Down Expand Up @@ -200,8 +354,12 @@ public static IList<AIContent> ToAIContents(this IEnumerable<ResourceContents> c
return [.. contents.Select(ToAIContent)];
}

internal static ContentBlock ToContent(this AIContent content) =>
content switch
/// <summary>Creates a new <see cref="ContentBlock"/> from the content of an <see cref="AIContent"/>.</summary>
/// <param name="content">The <see cref="AIContent"/> to convert.</param>
/// <returns>The created <see cref="ContentBlock"/>.</returns>
public static ContentBlock ToContentBlock(this AIContent content)
{
ContentBlock contentBlock = content switch
{
TextContent textContent => new TextContentBlock
{
Expand Down Expand Up @@ -230,9 +388,55 @@ internal static ContentBlock ToContent(this AIContent content) =>
}
},

FunctionCallContent callContent => new ToolUseContentBlock()
{
Id = callContent.CallId,
Name = callContent.Name,
Input = JsonSerializer.SerializeToElement(callContent.Arguments, McpJsonUtilities.DefaultOptions.GetTypeInfo<IDictionary<string, object?>>()!),
},

FunctionResultContent resultContent => new ToolResultContentBlock()
{
ToolUseId = resultContent.CallId,
IsError = resultContent.Exception is not null,
Content =
resultContent.Result is AIContent c ? [c.ToContentBlock()] :
resultContent.Result is IEnumerable<AIContent> ec ? [.. ec.Select(c => c.ToContentBlock())] :
[new TextContentBlock { Text = JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions.GetTypeInfo<object>()) }],
StructuredContent = resultContent.Result is JsonElement je ? je : null,
},

_ => new TextContentBlock
{
Text = JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object))),
}
};

contentBlock.Meta = content.AdditionalProperties?.ToJsonObject();

return contentBlock;
}

private sealed class ToolAIFunctionDeclaration(Tool tool) : AIFunctionDeclaration
{
public override string Name => tool.Name;

public override string Description => tool.Description ?? "";

public override IReadOnlyDictionary<string, object?> AdditionalProperties =>
field ??= tool.Meta is { } meta ? meta.ToDictionary(p => p.Key, p => (object?)p.Value) : [];

public override JsonElement JsonSchema => tool.InputSchema;

public override JsonElement? ReturnJsonSchema => tool.OutputSchema;

public override object? GetService(Type serviceType, object? serviceKey = null)
{
Throw.IfNull(serviceType);

return
serviceKey is null && serviceType.IsInstanceOfType(tool) ? tool :
base.GetService(serviceType, serviceKey);
}
}
}
Loading