Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/AI.Tests/AI.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
<PropertyGroup>
<TargetFramework>net10.0</TargetFramework>
<NoWarn>OPENAI001;$(NoWarn)</NoWarn>
<LangVersion>Preview</LangVersion>
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
</PropertyGroup>

Expand Down
93 changes: 61 additions & 32 deletions src/AI.Tests/Extensions/PipelineTestOutput.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,57 +12,86 @@ public static class PipelineTestOutput
/// </summary>
/// <typeparam name="TOptions">The options type to configure for HTTP logging.</typeparam>
/// <param name="pipelineOptions">The options instance to configure.</param>
/// <param name="output">The test output helper to write to.</param>
/// <param name="onRequest">A callback to process the <see cref="JsonNode"/> that was sent.</param>
/// <param name="onResponse">A callback to process the <see cref="JsonNode"/> that was received.</param>
/// <remarks>
/// NOTE: this is the lowest-level logging, after all chat pipeline processing has been done.
/// <para>
/// If the options already provide a transport, it will be wrapped with the console
/// logging transport to minimize the impact on existing configurations.
/// </para>
/// </remarks>
public static TOptions UseTestOutput<TOptions>(this TOptions pipelineOptions, ITestOutputHelper output)
public static TOptions WriteTo<TOptions>(this TOptions pipelineOptions, ITestOutputHelper? output = default, Action<JsonNode>? onRequest = default, Action<JsonNode>? onResponse = default)
where TOptions : ClientPipelineOptions
{
pipelineOptions.Transport = new TestPipelineTransport(pipelineOptions.Transport ?? HttpClientPipelineTransport.Shared, output);

pipelineOptions.AddPolicy(new TestOutputPolicy(output ?? NullTestOutputHelper.Default, onRequest, onResponse), PipelinePosition.BeforeTransport);
return pipelineOptions;
}
}

public class TestPipelineTransport(PipelineTransport inner, ITestOutputHelper? output = null) : PipelineTransport
{
static readonly JsonSerializerOptions options = new JsonSerializerOptions(JsonSerializerDefaults.General)
class NullTestOutputHelper : ITestOutputHelper
{
WriteIndented = true,
};

public List<JsonNode> Requests { get; } = [];
public List<JsonNode> Responses { get; } = [];
public static ITestOutputHelper Default { get; } = new NullTestOutputHelper();
NullTestOutputHelper() { }
public void WriteLine(string message) { }
public void WriteLine(string format, params object[] args) { }
}

protected override async ValueTask ProcessCoreAsync(PipelineMessage message)
class TestOutputPolicy(ITestOutputHelper output, Action<JsonNode>? onRequest = default, Action<JsonNode>? onResponse = default) : PipelinePolicy
{
message.BufferResponse = true;
await inner.ProcessAsync(message);
static readonly JsonSerializerOptions options = new JsonSerializerOptions(JsonSerializerDefaults.General)
{
WriteIndented = true,
};

if (message.Request.Content is not null)
public override void Process(PipelineMessage message, IReadOnlyList<PipelinePolicy> pipeline, int currentIndex)
{
using var memory = new MemoryStream();
message.Request.Content.WriteTo(memory);
memory.Position = 0;
using var reader = new StreamReader(memory);
var content = await reader.ReadToEndAsync();
var node = JsonNode.Parse(content);
Requests.Add(node!);
output?.WriteLine(node!.ToJsonString(options));
message.BufferResponse = true;
ProcessNext(message, pipeline, currentIndex);

if (message.Request.Content is not null)
{
using var memory = new MemoryStream();
message.Request.Content.WriteTo(memory);
memory.Position = 0;
using var reader = new StreamReader(memory);
var content = reader.ReadToEnd();
var node = JsonNode.Parse(content);
onRequest?.Invoke(node!);
output?.WriteLine(node!.ToJsonString(options));
}

if (message.Response != null)
{
var node = JsonNode.Parse(message.Response.Content.ToString());
onResponse?.Invoke(node!);
output?.WriteLine(node!.ToJsonString(options));
}
}

if (message.Response != null)
public override async ValueTask ProcessAsync(PipelineMessage message, IReadOnlyList<PipelinePolicy> pipeline, int currentIndex)
{
var node = JsonNode.Parse(message.Response.Content.ToString());
Responses.Add(node!);
output?.WriteLine(node!.ToJsonString(options));
message.BufferResponse = true;
await ProcessNextAsync(message, pipeline, currentIndex);

if (message.Request.Content is not null)
{
using var memory = new MemoryStream();
message.Request.Content.WriteTo(memory);
memory.Position = 0;
using var reader = new StreamReader(memory);
var content = await reader.ReadToEndAsync();
var node = JsonNode.Parse(content);
onRequest?.Invoke(node!);
output?.WriteLine(node!.ToJsonString(options));
}

if (message.Response != null)
{
var node = JsonNode.Parse(message.Response.Content.ToString());
onResponse?.Invoke(node!);
output?.WriteLine(node!.ToJsonString(options));
}
}
}

protected override PipelineMessage CreateMessageCore() => inner.CreateMessage();
protected override void ProcessCore(PipelineMessage message) => inner.Process(message);
}
}
27 changes: 15 additions & 12 deletions src/AI.Tests/GrokTests.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System.ClientModel.Primitives;
using System.Text.Json.Nodes;
using System.Text.Json.Nodes;
using Microsoft.Extensions.AI;
using static ConfigurationExtensions;

Expand Down Expand Up @@ -49,9 +48,11 @@ public async Task GrokInvokesToolAndSearch()
{ "user", "What's Tesla stock worth today?" },
};

var transport = new TestPipelineTransport(HttpClientPipelineTransport.Shared, output);
var requests = new List<JsonNode>();
var responses = new List<JsonNode>();

var grok = new GrokChatClient(Configuration["XAI_API_KEY"]!, "grok-3", new OpenAI.OpenAIClientOptions() { Transport = transport })
var grok = new GrokChatClient(Configuration["XAI_API_KEY"]!, "grok-3",
new OpenAI.OpenAIClientOptions().WriteTo(output, requests.Add, responses.Add))
.AsBuilder()
.UseFunctionInvocation()
.Build();
Expand All @@ -69,7 +70,7 @@ public async Task GrokInvokesToolAndSearch()
// "search_parameters": {
// "mode": "on"
//}
Assert.All(transport.Requests, x =>
Assert.All(requests, x =>
{
var search = Assert.IsType<JsonObject>(x["search_parameters"]);
Assert.Equal("on", search["mode"]?.GetValue<string>());
Expand All @@ -79,7 +80,7 @@ public async Task GrokInvokesToolAndSearch()
Assert.Contains(response.Messages, x => x.Role == ChatRole.Tool);

// Citations include nasdaq.com at least as a web search source
var node = transport.Responses.LastOrDefault();
var node = responses.LastOrDefault();
Assert.NotNull(node);
var citations = Assert.IsType<JsonArray>(node["citations"], false);
var yahoo = citations.Where(x => x != null).Any(x => x!.ToString().Contains("https://finance.yahoo.com/quote/TSLA/", StringComparison.Ordinal));
Expand All @@ -100,16 +101,18 @@ public async Task GrokInvokesHostedSearchTool()
{ "user", "What's Tesla stock worth today? Search X and the news for latest info." },
};

var transport = new TestPipelineTransport(HttpClientPipelineTransport.Shared, output);
var requests = new List<JsonNode>();
var responses = new List<JsonNode>();

var chat = new GrokChatClient(Configuration["XAI_API_KEY"]!, "grok-3", new OpenAI.OpenAIClientOptions() { Transport = transport });
var grok = new GrokChatClient(Configuration["XAI_API_KEY"]!, "grok-3",
new OpenAI.OpenAIClientOptions().WriteTo(output, requests.Add, responses.Add));

var options = new ChatOptions
{
Tools = [new HostedWebSearchTool()]
};

var response = await chat.GetResponseAsync(messages, options);
var response = await grok.GetResponseAsync(messages, options);
var text = response.Text;

Assert.Contains("TSLA", text);
Expand All @@ -118,15 +121,15 @@ public async Task GrokInvokesHostedSearchTool()
// "search_parameters": {
// "mode": "auto"
//}
Assert.All(transport.Requests, x =>
Assert.All(requests, x =>
{
var search = Assert.IsType<JsonObject>(x["search_parameters"]);
Assert.Equal("auto", search["mode"]?.GetValue<string>());
});

// Citations include nasdaq.com at least as a web search source
Assert.Single(transport.Responses);
var node = transport.Responses[0];
Assert.Single(responses);
var node = responses[0];
Assert.NotNull(node);
var citations = Assert.IsType<JsonArray>(node["citations"], false);
var yahoo = citations.Where(x => x != null).Any(x => x!.ToString().Contains("https://finance.yahoo.com/quote/TSLA/", StringComparison.Ordinal));
Expand Down
67 changes: 67 additions & 0 deletions src/AI.Tests/OpenAITests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using System.Text.Json.Nodes;
using Microsoft.Extensions.AI;
using static ConfigurationExtensions;

namespace Devlooped.Extensions.AI;

public class OpenAITests(ITestOutputHelper output)
{
    /// <summary>
    /// Verifies that a per-call <see cref="ChatOptions.ModelId"/> overrides the model
    /// the client was constructed with.
    /// </summary>
    [SecretsFact("OPENAI_API_KEY")]
    public async Task OpenAISwitchesModel()
    {
        var messages = new Chat()
        {
            { "user", "What products does Tesla make?" },
        };

        var chat = new OpenAIChatClient(Configuration["OPENAI_API_KEY"]!, "gpt-4.1-nano", new OpenAI.OpenAIClientOptions().WriteTo(output));

        var options = new ChatOptions
        {
            ModelId = "gpt-4.1-mini",
        };

        var response = await chat.GetResponseAsync(messages, options);

        // NOTE: the chat client was created for gpt-4.1-nano, but the chat options requested a
        // different model and the OpenAI client honors that per-call choice.
        Assert.StartsWith("gpt-4.1-mini", response.ModelId);
    }

    /// <summary>
    /// Verifies that <see cref="ChatOptions.ReasoningEffort"/> is translated into the
    /// OpenAI "reasoning.effort" request payload, and that the per-call model override applies.
    /// </summary>
    [SecretsFact("OPENAI_API_KEY")]
    public async Task OpenAIThinks()
    {
        var messages = new Chat()
        {
            { "system", "You are an intelligent AI assistant that's an expert on financial matters." },
            { "user", "If you have a debt of 100k and accumulate a compounding 5% debt on top of it every year, how long before you are a negative millionaire? (round up to full integer value)" },
        };

        // Captures the raw JSON request bodies so we can assert on the wire format.
        var requests = new List<JsonNode>();

        var chat = new OpenAIChatClient(Configuration["OPENAI_API_KEY"]!, "o3-mini", new OpenAI.OpenAIClientOptions()
            .WriteTo(output, requests.Add));

        var options = new ChatOptions
        {
            ModelId = "o4-mini",
            ReasoningEffort = ReasoningEffort.Medium
        };

        var response = await chat.GetResponseAsync(messages, options);

        var text = response.Text;

        // 100k compounding at 5% crosses 1M after 48 years (log(10)/log(1.05) ≈ 47.2, rounded up).
        Assert.Contains("48 years", text);
        // NOTE: the chat client was created for o3-mini, but the chat options requested a
        // different model and the OpenAI client honors that per-call choice.
        Assert.StartsWith("o4-mini", response.ModelId);

        // Reasoning effort should have been set to medium on every outgoing request:
        // "reasoning": { "effort": "medium" }
        Assert.All(requests, x =>
        {
            var reasoning = Assert.IsType<JsonObject>(x["reasoning"]);
            Assert.Equal("medium", reasoning["effort"]?.GetValue<string>());
        });
    }
}
71 changes: 71 additions & 0 deletions src/AI/OpenAI/OpenAIChatClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Concurrent;
using Microsoft.Extensions.AI;
using OpenAI;
using OpenAI.Responses;

namespace Devlooped.Extensions.AI;

/// <summary>
/// An <see cref="IChatClient"/> implementation for OpenAI.
/// </summary>
public class OpenAIChatClient : IChatClient
{
    // Cache of per-model chat clients; all share the same underlying pipeline.
    readonly ConcurrentDictionary<string, IChatClient> clients = new();
    readonly string modelId;
    readonly ClientPipeline pipeline;
    readonly OpenAIClientOptions? options;

    /// <summary>
    /// Initializes the client with the specified API key, model ID, and optional OpenAI client options.
    /// </summary>
    /// <param name="apiKey">The OpenAI API key used to authenticate requests.</param>
    /// <param name="modelId">The default model to use when <see cref="ChatOptions.ModelId"/> is not set.</param>
    /// <param name="options">Optional client options (endpoint, transport, policies, etc.).</param>
    public OpenAIChatClient(string apiKey, string modelId, OpenAIClientOptions? options = default)
    {
        this.modelId = modelId;
        this.options = options;

        // NOTE: by caching the pipeline, we speed up creation of new chat clients per model,
        // since the pipeline will be the same for all of them.
        pipeline = new OpenAIClient(new ApiKeyCredential(apiKey), options).Pipeline;
    }

    /// <inheritdoc/>
    public Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellation = default)
        => GetChatClient(options?.ModelId ?? modelId).GetResponseAsync(messages, SetOptions(options), cancellation);

    /// <inheritdoc/>
    public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellation = default)
        => GetChatClient(options?.ModelId ?? modelId).GetStreamingResponseAsync(messages, SetOptions(options), cancellation);

    // Gets (or lazily creates) the IChatClient for the given model, reusing the shared pipeline.
    // NOTE: use the lambda parameter rather than closing over the outer argument, so the
    // created client always matches the cache key even if the two ever diverge.
    IChatClient GetChatClient(string modelId) => clients.GetOrAdd(modelId, model
        => new PipelineClient(pipeline, options).GetOpenAIResponseClient(model).AsIChatClient());

    // Maps ChatOptions.ReasoningEffort onto the OpenAI response creation options.
    // NOTE: this mutates the caller's options, replacing any previously assigned
    // RawRepresentationFactory when a reasoning effort is specified.
    static ChatOptions? SetOptions(ChatOptions? options)
    {
        if (options is null)
            return null;

        if (options.ReasoningEffort is ReasoningEffort effort)
        {
            options.RawRepresentationFactory = _ => new ResponseCreationOptions
            {
                ReasoningOptions = new ResponseReasoningOptions(effort switch
                {
                    ReasoningEffort.High => ResponseReasoningEffortLevel.High,
                    ReasoningEffort.Medium => ResponseReasoningEffortLevel.Medium,
                    // Any other value (including Low) maps to the low effort level.
                    _ => ResponseReasoningEffortLevel.Low
                })
            };
        }

        return options;
    }

    void IDisposable.Dispose() { }

    /// <inheritdoc/>
    public object? GetService(Type serviceType, object? serviceKey = null) => null;

    // Allows creating the base OpenAIClient with a pre-created pipeline.
    class PipelineClient(ClientPipeline pipeline, OpenAIClientOptions? options) : OpenAIClient(pipeline, options) { }
}
Loading