diff --git a/src/Extensions/ConfigurableChatClient.cs b/src/Extensions/ConfigurableChatClient.cs index 7a2c7a2..5abd376 100644 --- a/src/Extensions/ConfigurableChatClient.cs +++ b/src/Extensions/ConfigurableChatClient.cs @@ -1,4 +1,7 @@ -using Devlooped.Extensions.AI.Grok; +using Azure; +using Azure.AI.Inference; +using Azure.AI.OpenAI; +using Devlooped.Extensions.AI.Grok; using Devlooped.Extensions.AI.OpenAI; using Microsoft.Extensions.AI; using Microsoft.Extensions.Configuration; @@ -46,7 +49,7 @@ public IAsyncEnumerable GetStreamingResponseAsync(IEnumerabl IChatClient Configure(IConfigurationSection configSection) { - var options = configSection.Get(); + var options = configSection.Get(); Throw.IfNullOrEmpty(options?.ModelId, $"{configSection}:modelid"); // If there was a custom id, we must validate it didn't change since that's not supported. @@ -74,6 +77,10 @@ IChatClient Configure(IConfigurationSection configSection) IChatClient client = options.Endpoint?.Host == "api.x.ai" ? new GrokChatClient(apikey, options.ModelId, options) + : options.Endpoint?.Host == "ai.azure.com" + ? new ChatCompletionsClient(options.Endpoint, new AzureKeyCredential(apikey), configSection.Get()).AsIChatClient(options.ModelId) + : options.Endpoint?.Host.EndsWith("openai.azure.com") == true + ? new AzureOpenAIChatClient(options.Endpoint, new AzureKeyCredential(apikey), options.ModelId, configSection.Get()) : new OpenAIChatClient(apikey, options.ModelId, options); configure?.Invoke(id, client); @@ -98,7 +105,19 @@ void OnReload(object? state) [LoggerMessage(LogLevel.Information, "ChatClient {Id} configured.")] private partial void LogConfigured(string id); - class ConfigurableChatClientOptions : OpenAIClientOptions + class ConfigurableClientOptions : OpenAIClientOptions + { + public string? ApiKey { get; set; } + public string? ModelId { get; set; } + } + + class ConfigurableInferenceOptions : AzureAIInferenceClientOptions + { + public string? ApiKey { get; set; } + public string? 
ModelId { get; set; } + } + + class ConfigurableAzureOptions : AzureOpenAIClientOptions { public string? ApiKey { get; set; } public string? ModelId { get; set; } diff --git a/src/Extensions/Extensions.csproj b/src/Extensions/Extensions.csproj index 2242996..48c7b0d 100644 --- a/src/Extensions/Extensions.csproj +++ b/src/Extensions/Extensions.csproj @@ -5,7 +5,8 @@ Preview $(NoWarn);OPENAI001 Devlooped.Extensions.AI - Devlooped.Extensions.AI + $(AssemblyName) + $(AssemblyName) Extensions for Microsoft.Extensions.AI OSMFEULA.txt @@ -14,11 +15,13 @@ - + + + diff --git a/src/Extensions/OpenAI/AzureInferenceChatClient.cs b/src/Extensions/OpenAI/AzureInferenceChatClient.cs new file mode 100644 index 0000000..ad85740 --- /dev/null +++ b/src/Extensions/OpenAI/AzureInferenceChatClient.cs @@ -0,0 +1,51 @@ +using System.Collections.Concurrent; +using Azure; +using Azure.AI.Inference; +using Microsoft.Extensions.AI; + +namespace Devlooped.Extensions.AI.OpenAI; + +/// +/// An implementation for Azure AI Inference that supports per-request model selection. +/// +public class AzureInferenceChatClient : IChatClient +{ + readonly ConcurrentDictionary clients = new(); + + readonly string modelId; + readonly ChatCompletionsClient client; + readonly ChatClientMetadata? metadata; + + /// + /// Initializes the client with the specified endpoint, credential, model ID, and optional Azure AI Inference client options. + /// + public AzureInferenceChatClient(Uri endpoint, AzureKeyCredential credential, string modelId, AzureAIInferenceClientOptions? options = default) + { + this.modelId = modelId; + + // NOTE: by caching the pipeline, we speed up creation of new chat clients per model, + // since the pipeline will be the same for all of them. + client = new ChatCompletionsClient(endpoint, credential, options); + metadata = client.AsIChatClient(modelId) + .GetService(typeof(ChatClientMetadata)) as ChatClientMetadata; + } + + /// + public Task GetResponseAsync(IEnumerable messages, ChatOptions? 
options = null, CancellationToken cancellation = default) + => GetChatClient(options?.ModelId ?? modelId).GetResponseAsync(messages, options, cancellation); + + /// + public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellation = default) + => GetChatClient(options?.ModelId ?? modelId).GetStreamingResponseAsync(messages, options, cancellation); + + IChatClient GetChatClient(string modelId) => clients.GetOrAdd(modelId, client.AsIChatClient); + + void IDisposable.Dispose() => GC.SuppressFinalize(this); + + /// + public object? GetService(Type serviceType, object? serviceKey = null) => serviceType switch + { + Type t when t == typeof(ChatClientMetadata) => metadata, + _ => null + }; +} diff --git a/src/Extensions/OpenAI/AzureOpenAIChatClient.cs b/src/Extensions/OpenAI/AzureOpenAIChatClient.cs new file mode 100644 index 0000000..8f215b8 --- /dev/null +++ b/src/Extensions/OpenAI/AzureOpenAIChatClient.cs @@ -0,0 +1,68 @@ +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Concurrent; +using Azure.AI.OpenAI; +using Microsoft.Extensions.AI; + +namespace Devlooped.Extensions.AI.OpenAI; + +/// +/// An implementation for Azure OpenAI that supports per-request model selection. +/// +public class AzureOpenAIChatClient : IChatClient +{ + readonly ConcurrentDictionary clients = new(); + + readonly Uri endpoint; + readonly string modelId; + readonly ClientPipeline pipeline; + readonly AzureOpenAIClientOptions options; + readonly ChatClientMetadata? metadata; + + /// + /// Initializes the client with the given endpoint, API key, model ID, and optional Azure OpenAI client options. + /// + public AzureOpenAIChatClient(Uri endpoint, ApiKeyCredential credential, string modelId, AzureOpenAIClientOptions? options = default) + { + this.endpoint = endpoint; + this.modelId = modelId; + this.options = options ?? 
new(); + + // NOTE: by caching the pipeline, we speed up creation of new chat clients per model, + // since the pipeline will be the same for all of them. + var client = new AzureOpenAIClient(endpoint, credential, options); + metadata = client.GetChatClient(modelId) + .AsIChatClient() + .GetService(typeof(ChatClientMetadata)) as ChatClientMetadata; + + metadata = new ChatClientMetadata( + providerName: "azure.ai.openai", + providerUri: metadata?.ProviderUri ?? endpoint, + defaultModelId: metadata?.DefaultModelId ?? modelId); + + pipeline = client.Pipeline; + } + + /// + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellation = default) + => GetChatClient(options?.ModelId ?? modelId).GetResponseAsync(messages, options.SetResponseOptions(), cancellation); + + /// + public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellation = default) + => GetChatClient(options?.ModelId ?? modelId).GetStreamingResponseAsync(messages, options.SetResponseOptions(), cancellation); + + IChatClient GetChatClient(string modelId) => clients.GetOrAdd(modelId, model + => new PipelineClient(pipeline, endpoint, options).GetOpenAIResponseClient(modelId).AsIChatClient()); + + void IDisposable.Dispose() => GC.SuppressFinalize(this); + + /// + public object? GetService(Type serviceType, object? serviceKey = null) => serviceType switch + { + Type t when t == typeof(ChatClientMetadata) => metadata, + _ => null + }; + + // Allows creating the base OpenAIClient with a pre-created pipeline. 
+ class PipelineClient(ClientPipeline pipeline, Uri endpoint, AzureOpenAIClientOptions options) : AzureOpenAIClient(pipeline, endpoint, options) { } +} diff --git a/src/Extensions/OpenAI/OpenAIChatClient.cs b/src/Extensions/OpenAI/OpenAIChatClient.cs index 59ab415..cb43665 100644 --- a/src/Extensions/OpenAI/OpenAIChatClient.cs +++ b/src/Extensions/OpenAI/OpenAIChatClient.cs @@ -1,15 +1,13 @@ using System.ClientModel; using System.ClientModel.Primitives; using System.Collections.Concurrent; -using System.Text.Json; using Microsoft.Extensions.AI; using OpenAI; -using OpenAI.Responses; namespace Devlooped.Extensions.AI.OpenAI; /// -/// An implementation for OpenAI. +/// An implementation for OpenAI that supports per-request model selection. /// public class OpenAIChatClient : IChatClient { @@ -39,38 +37,15 @@ public OpenAIChatClient(string apiKey, string modelId, OpenAIClientOptions? opti /// public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellation = default) - => GetChatClient(options?.ModelId ?? modelId).GetResponseAsync(messages, SetOptions(options), cancellation); + => GetChatClient(options?.ModelId ?? modelId).GetResponseAsync(messages, options.SetResponseOptions(), cancellation); /// public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellation = default) - => GetChatClient(options?.ModelId ?? modelId).GetStreamingResponseAsync(messages, SetOptions(options), cancellation); + => GetChatClient(options?.ModelId ?? modelId).GetStreamingResponseAsync(messages, options.SetResponseOptions(), cancellation); IChatClient GetChatClient(string modelId) => clients.GetOrAdd(modelId, model => new PipelineClient(pipeline, options).GetOpenAIResponseClient(modelId).AsIChatClient()); - static ChatOptions? SetOptions(ChatOptions? 
options) - { - if (options is null) - return null; - - if (options.ReasoningEffort.HasValue || options.Verbosity.HasValue) - { - options.RawRepresentationFactory = _ => - { - var creation = new ResponseCreationOptions(); - if (options.ReasoningEffort.HasValue) - creation.ReasoningOptions = new ReasoningEffortOptions(options.ReasoningEffort!.Value); - - if (options.Verbosity.HasValue) - creation.TextOptions = new VerbosityOptions(options.Verbosity!.Value); - - return creation; - }; - } - - return options; - } - void IDisposable.Dispose() => GC.SuppressFinalize(this); /// @@ -82,24 +57,4 @@ IChatClient GetChatClient(string modelId) => clients.GetOrAdd(modelId, model // Allows creating the base OpenAIClient with a pre-created pipeline. class PipelineClient(ClientPipeline pipeline, OpenAIClientOptions? options) : OpenAIClient(pipeline, options) { } - - class ReasoningEffortOptions(ReasoningEffort effort) : ResponseReasoningOptions - { - protected override void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWriterOptions options) - { - writer.WritePropertyName("effort"u8); - writer.WriteStringValue(effort.ToString().ToLowerInvariant()); - base.JsonModelWriteCore(writer, options); - } - } - - class VerbosityOptions(Verbosity verbosity) : ResponseTextOptions - { - protected override void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWriterOptions options) - { - writer.WritePropertyName("verbosity"u8); - writer.WriteStringValue(verbosity.ToString().ToLowerInvariant()); - base.JsonModelWriteCore(writer, options); - } - } } diff --git a/src/Extensions/OpenAI/OpenAIExtensions.cs b/src/Extensions/OpenAI/OpenAIExtensions.cs new file mode 100644 index 0000000..f4e4313 --- /dev/null +++ b/src/Extensions/OpenAI/OpenAIExtensions.cs @@ -0,0 +1,52 @@ +using System.ClientModel.Primitives; +using System.Text.Json; +using Microsoft.Extensions.AI; +using OpenAI.Responses; + +namespace Devlooped.Extensions.AI.OpenAI; + +static class OpenAIExtensions +{ + public static 
ChatOptions? SetResponseOptions(this ChatOptions? options) + { + if (options is null) + return null; + + if (options.ReasoningEffort.HasValue || options.Verbosity.HasValue) + { + options.RawRepresentationFactory = _ => + { + var creation = new ResponseCreationOptions(); + if (options.ReasoningEffort.HasValue) + creation.ReasoningOptions = new ReasoningEffortOptions(options.ReasoningEffort!.Value); + + if (options.Verbosity.HasValue) + creation.TextOptions = new VerbosityOptions(options.Verbosity!.Value); + + return creation; + }; + } + + return options; + } + + class ReasoningEffortOptions(ReasoningEffort effort) : ResponseReasoningOptions + { + protected override void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + writer.WritePropertyName("effort"u8); + writer.WriteStringValue(effort.ToString().ToLowerInvariant()); + base.JsonModelWriteCore(writer, options); + } + } + + class VerbosityOptions(Verbosity verbosity) : ResponseTextOptions + { + protected override void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + writer.WritePropertyName("verbosity"u8); + writer.WriteStringValue(verbosity.ToString().ToLowerInvariant()); + base.JsonModelWriteCore(writer, options); + } + } +} diff --git a/src/Extensions/UseChatClientsExtensions.cs b/src/Extensions/UseChatClientsExtensions.cs index 2f299e4..e9a3b9b 100644 --- a/src/Extensions/UseChatClientsExtensions.cs +++ b/src/Extensions/UseChatClientsExtensions.cs @@ -21,6 +21,7 @@ public static IServiceCollection UseChatClients(this IServiceCollection services var id = configuration[$"{section}:id"] ?? 
section[(prefix.Length + 1)..]; var options = configuration.GetRequiredSection(section).Get(); + // We need logging set up for the configurable client to log changes services.AddLogging(); var builder = services.AddKeyedChatClient(id, diff --git a/src/Tests/ConfigurableTests.cs b/src/Tests/ConfigurableClientTests.cs similarity index 76% rename from src/Tests/ConfigurableTests.cs rename to src/Tests/ConfigurableClientTests.cs index 47f0d7f..d3529bc 100644 --- a/src/Tests/ConfigurableTests.cs +++ b/src/Tests/ConfigurableClientTests.cs @@ -163,4 +163,53 @@ public void CanChangeAndSwapProvider() Assert.Equal("xai", client.GetRequiredService().ProviderName); Assert.Equal("grok-4", client.GetRequiredService().DefaultModelId); } + + [Fact] + public void CanConfigureAzureInference() + { + var configuration = new ConfigurationBuilder() + .AddInMemoryCollection(new Dictionary + { + ["ai:clients:chat:modelid"] = "gpt-5", + ["ai:clients:chat:apikey"] = "asdfasdf", + ["ai:clients:chat:endpoint"] = "https://ai.azure.com/.default" + }) + .Build(); + + var services = new ServiceCollection() + .AddSingleton(configuration) + .AddLogging(builder => builder.AddTestOutput(output)) + .UseChatClients(configuration) + .BuildServiceProvider(); + + var client = services.GetRequiredKeyedService("chat"); + + Assert.Equal("azure.ai.inference", client.GetRequiredService().ProviderName); + Assert.Equal("gpt-5", client.GetRequiredService().DefaultModelId); + } + + [Fact] + public void CanConfigureAzureOpenAI() + { + var configuration = new ConfigurationBuilder() + .AddInMemoryCollection(new Dictionary + { + ["ai:clients:chat:modelid"] = "gpt-5", + ["ai:clients:chat:apikey"] = "asdfasdf", + ["ai:clients:chat:endpoint"] = "https://chat.openai.azure.com/", + ["ai:clients:chat:UserAgentApplicationId"] = "myapp/1.0" + }) + .Build(); + + var services = new ServiceCollection() + .AddSingleton(configuration) + .AddLogging(builder => builder.AddTestOutput(output)) + .UseChatClients(configuration) + 
.BuildServiceProvider(); + + var client = services.GetRequiredKeyedService("chat"); + + Assert.Equal("azure.ai.openai", client.GetRequiredService().ProviderName); + Assert.Equal("gpt-5", client.GetRequiredService().DefaultModelId); + } }