From cb539340997f43f427f7c742ee9e47375c9144df Mon Sep 17 00:00:00 2001 From: Daniel Cazzulino Date: Fri, 4 Jul 2025 01:46:39 -0300 Subject: [PATCH] Add tool-focused extension to inspect calls and results It's sometimes useful to be able to inspect (wether in tests or in production) the invocations that were performed by the model, including typed results. The introduced `ToolJsonOptions` provide a mechanism to automatically inject a `$type` for result types so they can be later inspected by the `FindCall` extension for `ChatResponse`. To simplify this scenario, we also provide a `ToolFactory` that automatically sets things up for this scenario, additionally making the tool name default to the naming convention for tools, rather than the .NET method name. --- readme.md | 48 +++++++++++ src/AI.Tests/ToolsTests.cs | 137 ++++++++++++++++++++++++++++++++ src/AI/ToolExtensions.cs | 105 ++++++++++++++++++++++++ src/AI/ToolFactory.cs | 19 +++++ src/AI/ToolJsonOptions.cs | 34 ++++++++ src/AI/TypeInjectingResolver.cs | 48 +++++++++++ 6 files changed, 391 insertions(+) create mode 100644 src/AI.Tests/ToolsTests.cs create mode 100644 src/AI/ToolExtensions.cs create mode 100644 src/AI/ToolFactory.cs create mode 100644 src/AI/ToolJsonOptions.cs create mode 100644 src/AI/TypeInjectingResolver.cs diff --git a/readme.md b/readme.md index c6d467d..adc943c 100644 --- a/readme.md +++ b/readme.md @@ -136,6 +136,54 @@ var openai = new OpenAIClient( OpenAIClientOptions.Observable(requests.Add, responses.Add)); ``` +## Tool Results + +Given the following tool: + +```csharp +MyResult RunTool(string name, string description, string content) { ... } +``` + +You can use the `ToolFactory` and `FindCall` extension method to +locate the function invocation, its outcome and the typed result for inspection: + +```csharp +AIFunction tool = ToolFactory.Create(RunTool); +var options = new ChatOptions +{ + ToolMode = ChatToolMode.RequireSpecific(tool.Name), // 👈 forces the tool to be used + Tools = [tool] +}; + +var response = await client.GetResponseAsync(chat, options); +var result = response.FindCalls(tool).FirstOrDefault(); + +if (result != null) +{ + // Successful tool call + Console.WriteLine($"Args: '{result.Call.Arguments.Count}'"); + MyResult typed = result.Result; +} +else +{ + Console.WriteLine("Tool call not found in response."); +} +``` + +If the typed result is not found, you can also inspect the raw outcomes by finding +untyped calls to the tool and checking their `Outcome.Exception` property: + +```csharp +var result = response.FindCalls(tool).FirstOrDefault(); +if (result.Outcome.Exception is not null) +{ + Console.WriteLine($"Tool call failed: {result.Outcome.Exception.Message}"); +} +else +{ + Console.WriteLine($"Tool call succeeded: {result.Outcome.Result}"); +} +``` ## Console Logging diff --git a/src/AI.Tests/ToolsTests.cs b/src/AI.Tests/ToolsTests.cs new file mode 100644 index 0000000..135e6f3 --- /dev/null +++ b/src/AI.Tests/ToolsTests.cs @@ -0,0 +1,137 @@ +using System.ComponentModel; +using Microsoft.Extensions.AI; +using static ConfigurationExtensions; + +namespace Devlooped.Extensions.AI; + +public class ToolsTests(ITestOutputHelper output) +{ + public record ToolResult(string Name, string Description, string Content); + + [SecretsFact("OPENAI_API_KEY")] + public async Task RunToolResult() + { + var chat = new Chat() + { + { "system", "You make up a tool run by making up a name, description and content based on whatever the user says." }, + { "user", "I want to create an order for a dozen eggs" }, + }; + + var client = new OpenAIChatClient(Configuration["OPENAI_API_KEY"]!, "gpt-4.1", + OpenAI.OpenAIClientOptions.WriteTo(output)) + .AsBuilder() + .UseFunctionInvocation() + .Build(); + + var tool = ToolFactory.Create(RunTool); + var options = new ChatOptions + { + ToolMode = ChatToolMode.RequireSpecific(tool.Name), + Tools = [tool] + }; + + var response = await client.GetResponseAsync(chat, options); + var result = response.FindCalls(tool).FirstOrDefault(); + + Assert.NotNull(result); + Assert.NotNull(result.Call); + Assert.Equal(tool.Name, result.Call.Name); + Assert.NotNull(result.Outcome); + Assert.Null(result.Outcome.Exception); + } + + [SecretsFact("OPENAI_API_KEY")] + public async Task RunToolTerminateResult() + { + var chat = new Chat() + { + { "system", "You make up a tool run by making up a name, description and content based on whatever the user says." }, + { "user", "I want to create an order for a dozen eggs" }, + }; + + var client = new OpenAIChatClient(Configuration["OPENAI_API_KEY"]!, "gpt-4.1", + OpenAI.OpenAIClientOptions.WriteTo(output)) + .AsBuilder() + .UseFunctionInvocation() + .Build(); + + var tool = ToolFactory.Create(RunToolTerminate); + var options = new ChatOptions + { + ToolMode = ChatToolMode.RequireSpecific(tool.Name), + Tools = [tool] + }; + + var response = await client.GetResponseAsync(chat, options); + var result = response.FindCalls(tool).FirstOrDefault(); + + Assert.NotNull(result); + Assert.NotNull(result.Call); + Assert.Equal(tool.Name, result.Call.Name); + Assert.NotNull(result.Outcome); + Assert.Null(result.Outcome.Exception); + } + + [SecretsFact("OPENAI_API_KEY")] + public async Task RunToolExceptionOutcome() + { + var chat = new Chat() + { + { "system", "You make up a tool run by making up a name, description and content based on whatever the user says." }, + { "user", "I want to create an order for a dozen eggs" }, + }; + + var client = new OpenAIChatClient(Configuration["OPENAI_API_KEY"]!, "gpt-4.1", + OpenAI.OpenAIClientOptions.WriteTo(output)) + .AsBuilder() + .UseFunctionInvocation() + .Build(); + + var tool = ToolFactory.Create(RunToolThrows); + var options = new ChatOptions + { + ToolMode = ChatToolMode.RequireSpecific(tool.Name), + Tools = [tool] + }; + + var response = await client.GetResponseAsync(chat, options); + var result = response.FindCalls(tool).FirstOrDefault(); + + Assert.NotNull(result); + Assert.NotNull(result.Call); + Assert.Equal(tool.Name, result.Call.Name); + Assert.NotNull(result.Outcome); + Assert.NotNull(result.Outcome.Exception); + } + + [Description("Runs a tool to provide a result based on user input.")] + ToolResult RunTool( + [Description("The name")] string name, + [Description("The description")] string description, + [Description("The content")] string content) + { + // Simulate running a tool and returning a result + return new ToolResult(name, description, content); + } + + [Description("Runs a tool to provide a result based on user input.")] + ToolResult RunToolTerminate( + [Description("The name")] string name, + [Description("The description")] string description, + [Description("The content")] string content) + { + FunctionInvokingChatClient.CurrentContext?.Terminate = true; + // Simulate running a tool and returning a result + return new ToolResult(name, description, content); + } + + [Description("Runs a tool to provide a result based on user input.")] + ToolResult RunToolThrows( + [Description("The name")] string name, + [Description("The description")] string description, + [Description("The content")] string content) + { + FunctionInvokingChatClient.CurrentContext?.Terminate = true; + throw new ArgumentException("BOOM"); + } +} diff --git a/src/AI/ToolExtensions.cs b/src/AI/ToolExtensions.cs new file mode 100644 index 0000000..8bd6ee5 --- /dev/null +++ b/src/AI/ToolExtensions.cs @@ -0,0 +1,105 @@ +using System.Text.Json; +using Microsoft.Extensions.AI; + +namespace Devlooped.Extensions.AI; + +/// +/// Represents a tool call made by the AI, including the function call content and the result of the function execution. +/// +public record ToolCall(FunctionCallContent Call, FunctionResultContent Outcome); + +/// +/// Represents a tool call made by the AI, including the function call content, the result of the function execution, +/// and the deserialized result of type . +/// +public record ToolCall(FunctionCallContent Call, FunctionResultContent Outcome, TResult Result); + +/// +/// Extensions for inspecting chat messages and responses for tool +/// usage and processing responses. +/// +public static class ToolExtensions +{ + /// + /// Looks for calls to a tool and their outcome. + /// + public static IEnumerable FindCalls(this ChatResponse response, AIFunction tool) + => FindCalls(response.Messages, tool.Name); + + /// + /// Looks for calls to a tool and their outcome. + /// + public static IEnumerable FindCalls(this IEnumerable messages, AIFunction tool) + => FindCalls(messages, tool.Name); + + /// + /// Looks for calls to a tool and their outcome. + /// + public static IEnumerable FindCalls(this IEnumerable messages, string tool) + { + var calls = messages + .Where(x => x.Role == ChatRole.Assistant) + .SelectMany(x => x.Contents) + .OfType() + .Where(x => x.Name == tool) + .ToDictionary(x => x.CallId); + + var results = messages + .Where(x => x.Role == ChatRole.Tool) + .SelectMany(x => x.Contents) + .OfType() + .Where(x => calls.TryGetValue(x.CallId, out var call) && call.Name == tool) + .Select(x => new ToolCall(calls[x.CallId], x)); + + return results; + } + + /// + /// Looks for a user prompt in the chat response messages. + /// + /// + /// In order for this to work, the must have been invoked using + /// the or with a configured + /// with so + /// that the tool result type can be properly inspected. + /// + public static IEnumerable> FindCalls(this ChatResponse response, AIFunction tool) + => FindCalls(response.Messages, tool.Name); + + /// + /// Looks for a user prompt in the chat response messages. + /// + /// + /// In order for this to work, the must have been invoked using + /// the or with a configured + /// with so + /// that the tool result type can be properly inspected. + /// + public static IEnumerable> FindCalls(this IEnumerable messages, AIFunction tool) + => FindCalls(messages, tool.Name); + + /// + /// Looks for a user prompt in the chat response messages. + /// + /// + /// In order for this to work, the must have been invoked using + /// the or with a configured + /// with so + /// that the tool result type can be properly inspected. + /// + public static IEnumerable> FindCalls(this IEnumerable messages, string tool) + { + var calls = FindCalls(messages, tool) + .Where(x => x.Outcome.Result is JsonElement element && + element.ValueKind == JsonValueKind.Object && + element.TryGetProperty("$type", out var type) && + type.GetString() == typeof(TResult).FullName) + .Select(x => new ToolCall( + Call: x.Call, + Outcome: x.Outcome, + Result: JsonSerializer.Deserialize((JsonElement)x.Outcome.Result!, ToolJsonOptions.Default) ?? + throw new InvalidOperationException($"Failed to deserialize result for tool '{tool}' to {typeof(TResult).FullName}."))); + + return calls; + } +} diff --git a/src/AI/ToolFactory.cs b/src/AI/ToolFactory.cs new file mode 100644 index 0000000..994a411 --- /dev/null +++ b/src/AI/ToolFactory.cs @@ -0,0 +1,19 @@ +using Microsoft.Extensions.AI; + +namespace Devlooped.Extensions.AI; + +/// +/// Creates tools for function calling that can leverage the +/// extension methods for locating invocations and their results. +/// +public static class ToolFactory +{ + /// + /// Invokes + /// using the method name following the naming convention and serialization options from . + /// + public static AIFunction Create(Delegate method) + => AIFunctionFactory.Create(method, + ToolJsonOptions.Default.PropertyNamingPolicy!.ConvertName(method.Method.Name), + serializerOptions: ToolJsonOptions.Default); +} diff --git a/src/AI/ToolJsonOptions.cs b/src/AI/ToolJsonOptions.cs new file mode 100644 index 0000000..a4625e3 --- /dev/null +++ b/src/AI/ToolJsonOptions.cs @@ -0,0 +1,34 @@ +using System.Diagnostics; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; + +namespace Devlooped.Extensions.AI; + +/// +/// Provides a optimized for use with +/// function calling and tools. +/// +public static class ToolJsonOptions +{ + static ToolJsonOptions() => Default.MakeReadOnly(); + + /// + /// Default for function calling and tools. + /// + public static JsonSerializerOptions Default { get; } = new(JsonSerializerDefaults.Web) + { + Converters = + { + new AdditionalPropertiesDictionaryConverter(), + new JsonStringEnumConverter(), + }, + DefaultIgnoreCondition = + JsonIgnoreCondition.WhenWritingDefault | + JsonIgnoreCondition.WhenWritingNull, + Encoder = System.Text.Encodings.Web.JavaScriptEncoder.UnsafeRelaxedJsonEscaping, + PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, + WriteIndented = Debugger.IsAttached, + TypeInfoResolver = new TypeInjectingResolver(new DefaultJsonTypeInfoResolver()) + }; +} diff --git a/src/AI/TypeInjectingResolver.cs b/src/AI/TypeInjectingResolver.cs new file mode 100644 index 0000000..570d129 --- /dev/null +++ b/src/AI/TypeInjectingResolver.cs @@ -0,0 +1,48 @@ +using System.ComponentModel; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; + +namespace Devlooped.Extensions.AI; + +/// +/// Extensions for to enable type injection for object types. +/// +[EditorBrowsable(EditorBrowsableState.Never)] +public static class TypeInjectingResolverExtensions +{ + /// + /// Creates a new that injects a $type property into object types. + /// + public static JsonSerializerOptions WithTypeInjection(this JsonSerializerOptions options) + { + if (options.IsReadOnly) + options = new(options); + + options.TypeInfoResolver = new TypeInjectingResolver( + JsonTypeInfoResolver.Combine([.. options.TypeInfoResolverChain])); + + return options; + } +} + +/// +/// A custom that injects a $type property into object types +/// so they can be automatically distinguished during deserialization or inspection. +/// +public class TypeInjectingResolver(IJsonTypeInfoResolver inner) : IJsonTypeInfoResolver +{ + /// + public JsonTypeInfo? GetTypeInfo(Type type, JsonSerializerOptions options) + { + var info = inner.GetTypeInfo(type, options); + // The $type would already be present for polymorphic serialization. + if (info?.Kind == JsonTypeInfoKind.Object && !info.Properties.Any(x => x.Name == "$type")) + { + var prop = info.CreateJsonPropertyInfo(typeof(string), "$type"); + prop.Get = obj => obj.GetType().FullName; + prop.Order = -1000; // Ensure it is serialized first + info.Properties.Add(prop); + } + return info; + } +}