Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions src/Extensions/ConfigurableChatClient.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using Devlooped.Extensions.AI.Grok;
using Azure;
using Azure.AI.Inference;
using Azure.AI.OpenAI;
using Devlooped.Extensions.AI.Grok;
using Devlooped.Extensions.AI.OpenAI;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Configuration;
Expand Down Expand Up @@ -46,7 +49,7 @@ public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerabl

IChatClient Configure(IConfigurationSection configSection)
{
var options = configSection.Get<ConfigurableChatClientOptions>();
var options = configSection.Get<ConfigurableClientOptions>();
Throw.IfNullOrEmpty(options?.ModelId, $"{configSection}:modelid");

// If there was a custom id, we must validate it didn't change since that's not supported.
Expand Down Expand Up @@ -74,6 +77,10 @@ IChatClient Configure(IConfigurationSection configSection)

IChatClient client = options.Endpoint?.Host == "api.x.ai"
? new GrokChatClient(apikey, options.ModelId, options)
: options.Endpoint?.Host == "ai.azure.com"
? new ChatCompletionsClient(options.Endpoint, new AzureKeyCredential(apikey), configSection.Get<ConfigurableInferenceOptions>()).AsIChatClient(options.ModelId)
: options.Endpoint?.Host.EndsWith("openai.azure.com") == true
? new AzureOpenAIChatClient(options.Endpoint, new AzureKeyCredential(apikey), options.ModelId, configSection.Get<ConfigurableAzureOptions>())
: new OpenAIChatClient(apikey, options.ModelId, options);

configure?.Invoke(id, client);
Expand All @@ -98,7 +105,19 @@ void OnReload(object? state)
[LoggerMessage(LogLevel.Information, "ChatClient {Id} configured.")]
private partial void LogConfigured(string id);

class ConfigurableChatClientOptions : OpenAIClientOptions
class ConfigurableClientOptions : OpenAIClientOptions
{
public string? ApiKey { get; set; }
public string? ModelId { get; set; }
}

class ConfigurableInferenceOptions : AzureAIInferenceClientOptions
{
public string? ApiKey { get; set; }
public string? ModelId { get; set; }
}

class ConfigurableAzureOptions : AzureOpenAIClientOptions
{
public string? ApiKey { get; set; }
public string? ModelId { get; set; }
Expand Down
7 changes: 5 additions & 2 deletions src/Extensions/Extensions.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
<LangVersion>Preview</LangVersion>
<NoWarn>$(NoWarn);OPENAI001</NoWarn>
<AssemblyName>Devlooped.Extensions.AI</AssemblyName>
<PackageId>Devlooped.Extensions.AI</PackageId>
<RootNamespace>$(AssemblyName)</RootNamespace>
<PackageId>$(AssemblyName)</PackageId>
<Description>Extensions for Microsoft.Extensions.AI</Description>
<PackageLicenseExpression></PackageLicenseExpression>
<PackageLicenseFile>OSMFEULA.txt</PackageLicenseFile>
Expand All @@ -14,11 +15,13 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Configuration.Binder" Version="9.0.9" />
<PackageReference Include="Azure.AI.OpenAI" Version="2.5.0-beta.1" />
<PackageReference Include="Microsoft.Extensions.AI.AzureAIInference" Version="9.9.1-preview.1.25474.6" />
<PackageReference Include="NuGetizer" Version="1.3.1" PrivateAssets="all" />
<PackageReference Include="Microsoft.Extensions.AI" Version="9.9.1" />
<PackageReference Include="Microsoft.Extensions.AI.OpenAI" Version="9.9.1-preview.1.25474.6" />
<PackageReference Include="Microsoft.Extensions.Configuration.Abstractions" Version="9.0.9" />
<PackageReference Include="Microsoft.Extensions.Configuration.Binder" Version="9.0.9" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="9.0.9" />
<PackageReference Include="Spectre.Console" Version="0.51.1" />
<PackageReference Include="Spectre.Console.Json" Version="0.51.1" />
Expand Down
51 changes: 51 additions & 0 deletions src/Extensions/OpenAI/AzureInferenceChatClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
using System.Collections.Concurrent;
using Azure;
using Azure.AI.Inference;
using Microsoft.Extensions.AI;

namespace Devlooped.Extensions.AI.OpenAI;

/// <summary>
/// An <see cref="IChatClient"/> implementation for Azure AI Inference that supports per-request model selection.
/// </summary>
public class AzureInferenceChatClient : IChatClient
{
    // Per-model chat clients, created lazily and reused across requests.
    readonly ConcurrentDictionary<string, IChatClient> clients = new();

    // Default model used when ChatOptions.ModelId is not set on a request.
    readonly string modelId;
    // Single underlying inference client shared by all per-model IChatClients.
    readonly ChatCompletionsClient client;
    readonly ChatClientMetadata? metadata;

    /// <summary>
    /// Initializes the client with the specified endpoint, key credential, default model ID,
    /// and optional Azure AI Inference client options.
    /// </summary>
    public AzureInferenceChatClient(Uri endpoint, AzureKeyCredential credential, string modelId, AzureAIInferenceClientOptions? options = default)
    {
        this.modelId = modelId;

        // NOTE: a single ChatCompletionsClient is created once and shared, so producing a
        // new IChatClient for another model (see GetChatClient) is cheap.
        client = new ChatCompletionsClient(endpoint, credential, options);
        // Capture the metadata the adapter reports for the default model so GetService can expose it.
        metadata = client.AsIChatClient(modelId)
            .GetService(typeof(ChatClientMetadata)) as ChatClientMetadata;
    }

    /// <inheritdoc/>
    public Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellation = default)
        => GetChatClient(options?.ModelId ?? modelId).GetResponseAsync(messages, options, cancellation);

    /// <inheritdoc/>
    public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellation = default)
        => GetChatClient(options?.ModelId ?? modelId).GetStreamingResponseAsync(messages, options, cancellation);

    // Returns the cached IChatClient for the given model, creating it on first use
    // from the shared ChatCompletionsClient.
    IChatClient GetChatClient(string modelId) => clients.GetOrAdd(modelId, client.AsIChatClient);

    // No owned disposable state here; cached clients wrap the shared ChatCompletionsClient.
    void IDisposable.Dispose() => GC.SuppressFinalize(this);

    /// <inheritdoc />
    // Only ChatClientMetadata is resolvable; any other service type yields null.
    public object? GetService(Type serviceType, object? serviceKey = null) => serviceType switch
    {
        Type t when t == typeof(ChatClientMetadata) => metadata,
        _ => null
    };
}
68 changes: 68 additions & 0 deletions src/Extensions/OpenAI/AzureOpenAIChatClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Concurrent;
using Azure.AI.OpenAI;
using Microsoft.Extensions.AI;

namespace Devlooped.Extensions.AI.OpenAI;

/// <summary>
/// An <see cref="IChatClient"/> implementation for Azure OpenAI that supports per-request model selection.
/// </summary>
public class AzureOpenAIChatClient : IChatClient
{
    // Per-model chat clients, created lazily; all share the same underlying pipeline.
    readonly ConcurrentDictionary<string, IChatClient> clients = new();

    readonly Uri endpoint;
    // Default model used when ChatOptions.ModelId is not set on a request.
    readonly string modelId;
    // Cached pipeline reused for every per-model client (see GetChatClient).
    readonly ClientPipeline pipeline;
    readonly AzureOpenAIClientOptions options;
    readonly ChatClientMetadata? metadata;

    /// <summary>
    /// Initializes the client with the given endpoint, API key, model ID, and optional Azure OpenAI client options.
    /// </summary>
    /// <param name="endpoint">The Azure OpenAI resource endpoint.</param>
    /// <param name="credential">The API key credential used to authenticate requests.</param>
    /// <param name="modelId">The default model (deployment) ID used when a request does not specify one.</param>
    /// <param name="options">Optional client options; defaults are used when null.</param>
    public AzureOpenAIChatClient(Uri endpoint, ApiKeyCredential credential, string modelId, AzureOpenAIClientOptions? options = default)
    {
        this.endpoint = endpoint;
        this.modelId = modelId;
        this.options = options ?? new();

        // NOTE: by caching the pipeline, we speed up creation of new chat clients per model,
        // since the pipeline will be the same for all of them.
        // Use the normalized (never null) options so the cached pipeline and the
        // per-model PipelineClient instances are configured identically.
        var client = new AzureOpenAIClient(endpoint, credential, this.options);
        metadata = client.GetChatClient(modelId)
            .AsIChatClient()
            .GetService(typeof(ChatClientMetadata)) as ChatClientMetadata;

        // Normalize the provider name, preserving whatever URI/model the inner client reported.
        metadata = new ChatClientMetadata(
            providerName: "azure.ai.openai",
            providerUri: metadata?.ProviderUri ?? endpoint,
            defaultModelId: metadata?.DefaultModelId ?? modelId);

        pipeline = client.Pipeline;
    }

    /// <inheritdoc/>
    public Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellation = default)
        => GetChatClient(options?.ModelId ?? modelId).GetResponseAsync(messages, options.SetResponseOptions(), cancellation);

    /// <inheritdoc/>
    public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellation = default)
        => GetChatClient(options?.ModelId ?? modelId).GetStreamingResponseAsync(messages, options.SetResponseOptions(), cancellation);

    // Returns the cached IChatClient for the given model, creating it on first use over
    // the shared pipeline. Uses the lambda's key parameter (was previously ignored).
    IChatClient GetChatClient(string modelId) => clients.GetOrAdd(modelId, model
        => new PipelineClient(pipeline, endpoint, options).GetOpenAIResponseClient(model).AsIChatClient());

    // No owned disposable state; cached clients share the pipeline created in the constructor.
    void IDisposable.Dispose() => GC.SuppressFinalize(this);

    /// <inheritdoc />
    // Only ChatClientMetadata is resolvable; any other service type yields null.
    public object? GetService(Type serviceType, object? serviceKey = null) => serviceType switch
    {
        Type t when t == typeof(ChatClientMetadata) => metadata,
        _ => null
    };

    // Allows creating the base OpenAIClient with a pre-created pipeline.
    class PipelineClient(ClientPipeline pipeline, Uri endpoint, AzureOpenAIClientOptions options) : AzureOpenAIClient(pipeline, endpoint, options) { }
}
51 changes: 3 additions & 48 deletions src/Extensions/OpenAI/OpenAIChatClient.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Concurrent;
using System.Text.Json;
using Microsoft.Extensions.AI;
using OpenAI;
using OpenAI.Responses;

namespace Devlooped.Extensions.AI.OpenAI;

/// <summary>
/// An <see cref="IChatClient"/> implementation for OpenAI.
/// An <see cref="IChatClient"/> implementation for OpenAI that supports per-request model selection.
/// </summary>
public class OpenAIChatClient : IChatClient
{
Expand Down Expand Up @@ -39,38 +37,15 @@ public OpenAIChatClient(string apiKey, string modelId, OpenAIClientOptions? opti

/// <inheritdoc/>
public Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellation = default)
=> GetChatClient(options?.ModelId ?? modelId).GetResponseAsync(messages, SetOptions(options), cancellation);
=> GetChatClient(options?.ModelId ?? modelId).GetResponseAsync(messages, options.SetResponseOptions(), cancellation);

/// <inheritdoc/>
public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellation = default)
=> GetChatClient(options?.ModelId ?? modelId).GetStreamingResponseAsync(messages, SetOptions(options), cancellation);
=> GetChatClient(options?.ModelId ?? modelId).GetStreamingResponseAsync(messages, options.SetResponseOptions(), cancellation);

IChatClient GetChatClient(string modelId) => clients.GetOrAdd(modelId, model
=> new PipelineClient(pipeline, options).GetOpenAIResponseClient(modelId).AsIChatClient());

static ChatOptions? SetOptions(ChatOptions? options)
{
if (options is null)
return null;

if (options.ReasoningEffort.HasValue || options.Verbosity.HasValue)
{
options.RawRepresentationFactory = _ =>
{
var creation = new ResponseCreationOptions();
if (options.ReasoningEffort.HasValue)
creation.ReasoningOptions = new ReasoningEffortOptions(options.ReasoningEffort!.Value);

if (options.Verbosity.HasValue)
creation.TextOptions = new VerbosityOptions(options.Verbosity!.Value);

return creation;
};
}

return options;
}

void IDisposable.Dispose() => GC.SuppressFinalize(this);

/// <inheritdoc />
Expand All @@ -82,24 +57,4 @@ IChatClient GetChatClient(string modelId) => clients.GetOrAdd(modelId, model

// Allows creating the base OpenAIClient with a pre-created pipeline.
class PipelineClient(ClientPipeline pipeline, OpenAIClientOptions? options) : OpenAIClient(pipeline, options) { }

class ReasoningEffortOptions(ReasoningEffort effort) : ResponseReasoningOptions
{
protected override void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWriterOptions options)
{
writer.WritePropertyName("effort"u8);
writer.WriteStringValue(effort.ToString().ToLowerInvariant());
base.JsonModelWriteCore(writer, options);
}
}

class VerbosityOptions(Verbosity verbosity) : ResponseTextOptions
{
protected override void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWriterOptions options)
{
writer.WritePropertyName("verbosity"u8);
writer.WriteStringValue(verbosity.ToString().ToLowerInvariant());
base.JsonModelWriteCore(writer, options);
}
}
}
52 changes: 52 additions & 0 deletions src/Extensions/OpenAI/OpenAIExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using System.ClientModel.Primitives;
using System.Text.Json;
using Microsoft.Extensions.AI;
using OpenAI.Responses;

namespace Devlooped.Extensions.AI.OpenAI;

/// <summary>
/// Helpers for projecting <see cref="ChatOptions"/> extras (reasoning effort, verbosity)
/// onto the OpenAI Responses wire format.
/// </summary>
static class OpenAIExtensions
{
    /// <summary>
    /// When <see cref="ChatOptions"/> carries a reasoning effort and/or verbosity value, installs a
    /// <c>RawRepresentationFactory</c> that emits the corresponding Responses API properties.
    /// Returns the same (possibly mutated) instance, or null when <paramref name="options"/> is null.
    /// </summary>
    public static ChatOptions? SetResponseOptions(this ChatOptions? options)
    {
        if (options is null)
            return null;

        if (options.ReasoningEffort.HasValue || options.Verbosity.HasValue)
        {
            // NOTE: this intentionally replaces any previously assigned RawRepresentationFactory.
            options.RawRepresentationFactory = _ =>
            {
                var creation = new ResponseCreationOptions();
                // Nullable pattern matching avoids the redundant null-forgiving '!' the
                // HasValue guard already made unnecessary.
                if (options.ReasoningEffort is { } effort)
                    creation.ReasoningOptions = new ReasoningEffortOptions(effort);

                if (options.Verbosity is { } verbosity)
                    creation.TextOptions = new VerbosityOptions(verbosity);

                return creation;
            };
        }

        return options;
    }

    // Serializes the reasoning effort as the lowercase "effort" JSON property expected by the API.
    class ReasoningEffortOptions(ReasoningEffort effort) : ResponseReasoningOptions
    {
        protected override void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWriterOptions options)
        {
            writer.WritePropertyName("effort"u8);
            writer.WriteStringValue(effort.ToString().ToLowerInvariant());
            base.JsonModelWriteCore(writer, options);
        }
    }

    // Serializes the verbosity as the lowercase "verbosity" JSON property expected by the API.
    class VerbosityOptions(Verbosity verbosity) : ResponseTextOptions
    {
        protected override void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWriterOptions options)
        {
            writer.WritePropertyName("verbosity"u8);
            writer.WriteStringValue(verbosity.ToString().ToLowerInvariant());
            base.JsonModelWriteCore(writer, options);
        }
    }
}
1 change: 1 addition & 0 deletions src/Extensions/UseChatClientsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public static IServiceCollection UseChatClients(this IServiceCollection services
var id = configuration[$"{section}:id"] ?? section[(prefix.Length + 1)..];

var options = configuration.GetRequiredSection(section).Get<ChatClientOptions>();
// We need logging set up for the configurable client to log changes
services.AddLogging();

var builder = services.AddKeyedChatClient(id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,53 @@ public void CanChangeAndSwapProvider()
Assert.Equal("xai", client.GetRequiredService<ChatClientMetadata>().ProviderName);
Assert.Equal("grok-4", client.GetRequiredService<ChatClientMetadata>().DefaultModelId);
}

[Fact]
public void CanConfigureAzureInference()
{
    // An ai.azure.com endpoint should resolve to the Azure AI Inference provider.
    var config = new ConfigurationBuilder()
        .AddInMemoryCollection(new Dictionary<string, string?>
        {
            ["ai:clients:chat:modelid"] = "gpt-5",
            ["ai:clients:chat:apikey"] = "asdfasdf",
            ["ai:clients:chat:endpoint"] = "https://ai.azure.com/.default"
        })
        .Build();

    var provider = new ServiceCollection()
        .AddSingleton<IConfiguration>(config)
        .AddLogging(builder => builder.AddTestOutput(output))
        .UseChatClients(config)
        .BuildServiceProvider();

    var chat = provider.GetRequiredKeyedService<IChatClient>("chat");
    var metadata = chat.GetRequiredService<ChatClientMetadata>();

    Assert.Equal("azure.ai.inference", metadata.ProviderName);
    Assert.Equal("gpt-5", metadata.DefaultModelId);
}

[Fact]
public void CanConfigureAzureOpenAI()
{
    // An *.openai.azure.com endpoint should resolve to the Azure OpenAI provider.
    var config = new ConfigurationBuilder()
        .AddInMemoryCollection(new Dictionary<string, string?>
        {
            ["ai:clients:chat:modelid"] = "gpt-5",
            ["ai:clients:chat:apikey"] = "asdfasdf",
            ["ai:clients:chat:endpoint"] = "https://chat.openai.azure.com/",
            ["ai:clients:chat:UserAgentApplicationId"] = "myapp/1.0"
        })
        .Build();

    var provider = new ServiceCollection()
        .AddSingleton<IConfiguration>(config)
        .AddLogging(builder => builder.AddTestOutput(output))
        .UseChatClients(config)
        .BuildServiceProvider();

    var chat = provider.GetRequiredKeyedService<IChatClient>("chat");
    var metadata = chat.GetRequiredService<ChatClientMetadata>();

    Assert.Equal("azure.ai.openai", metadata.ProviderName);
    Assert.Equal("gpt-5", metadata.DefaultModelId);
}
}