Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Allow Add{Azure}OpenAI methods to resolve OpenAIClient from DI #4555

Merged
merged 1 commit into from
Jan 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -160,25 +160,24 @@ public static class OpenAIServiceCollectionExtensions
/// </summary>
/// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="openAIClient">Custom <see cref="OpenAIClient"/>.</param>
/// <param name="openAIClient"><see cref="OpenAIClient"/> to use for the service. If null, one must be available in the service provider when this service is resolved.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <param name="modelId">Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <returns>The same instance as <paramref name="builder"/>.</returns>
public static IKernelBuilder AddAzureOpenAITextGeneration(
this IKernelBuilder builder,
string deploymentName,
OpenAIClient openAIClient,
OpenAIClient? openAIClient = null,
string? serviceId = null,
string? modelId = null)
{
Verify.NotNull(builder);
Verify.NotNullOrWhiteSpace(deploymentName);
Verify.NotNull(openAIClient);

builder.Services.AddKeyedSingleton<ITextGenerationService>(serviceId, (serviceProvider, _) =>
new AzureOpenAITextGenerationService(
deploymentName,
openAIClient,
openAIClient ?? serviceProvider.GetRequiredService<OpenAIClient>(),
modelId,
serviceProvider.GetService<ILoggerFactory>()));

Expand All @@ -190,25 +189,24 @@ public static class OpenAIServiceCollectionExtensions
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> instance to augment.</param>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="openAIClient">Custom <see cref="OpenAIClient"/>.</param>
/// <param name="openAIClient"><see cref="OpenAIClient"/> to use for the service. If null, one must be available in the service provider when this service is resolved.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <param name="modelId">Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <returns>The same instance as <paramref name="services"/>.</returns>
public static IServiceCollection AddAzureOpenAITextGeneration(
this IServiceCollection services,
string deploymentName,
OpenAIClient openAIClient,
OpenAIClient? openAIClient = null,
string? serviceId = null,
string? modelId = null)
{
Verify.NotNull(services);
Verify.NotNullOrWhiteSpace(deploymentName);
Verify.NotNull(openAIClient);

return services.AddKeyedSingleton<ITextGenerationService>(serviceId, (serviceProvider, _) =>
new AzureOpenAITextGenerationService(
deploymentName,
openAIClient,
openAIClient ?? serviceProvider.GetRequiredService<OpenAIClient>(),
modelId,
serviceProvider.GetService<ILoggerFactory>()));
}
Expand Down Expand Up @@ -280,23 +278,22 @@ public static class OpenAIServiceCollectionExtensions
/// </summary>
/// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
/// <param name="modelId">OpenAI model name, see https://platform.openai.com/docs/models</param>
/// <param name="openAIClient">Custom <see cref="OpenAIClient"/>.</param>
/// <param name="openAIClient"><see cref="OpenAIClient"/> to use for the service. If null, one must be available in the service provider when this service is resolved.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <returns>The same instance as <paramref name="builder"/>.</returns>
public static IKernelBuilder AddOpenAITextGeneration(
this IKernelBuilder builder,
string modelId,
OpenAIClient openAIClient,
OpenAIClient? openAIClient = null,
string? serviceId = null)
{
Verify.NotNull(builder);
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNull(openAIClient);

builder.Services.AddKeyedSingleton<ITextGenerationService>(serviceId, (serviceProvider, _) =>
new OpenAITextGenerationService(
modelId,
openAIClient,
openAIClient ?? serviceProvider.GetRequiredService<OpenAIClient>(),
serviceProvider.GetService<ILoggerFactory>()));

return builder;
Expand All @@ -307,22 +304,21 @@ public static class OpenAIServiceCollectionExtensions
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> instance to augment.</param>
/// <param name="modelId">OpenAI model name, see https://platform.openai.com/docs/models</param>
/// <param name="openAIClient">Custom <see cref="OpenAIClient"/>.</param>
/// <param name="openAIClient"><see cref="OpenAIClient"/> to use for the service. If null, one must be available in the service provider when this service is resolved.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <returns>The same instance as <paramref name="services"/>.</returns>
public static IServiceCollection AddOpenAITextGeneration(this IServiceCollection services,
string modelId,
OpenAIClient openAIClient,
OpenAIClient? openAIClient = null,
string? serviceId = null)
{
Verify.NotNull(services);
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNull(openAIClient);

return services.AddKeyedSingleton<ITextGenerationService>(serviceId, (serviceProvider, _) =>
new OpenAITextGenerationService(
modelId,
openAIClient,
openAIClient ?? serviceProvider.GetRequiredService<OpenAIClient>(),
serviceProvider.GetService<ILoggerFactory>()));
}

Expand Down Expand Up @@ -479,26 +475,25 @@ public static class OpenAIServiceCollectionExtensions
/// </summary>
/// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="openAIClient">Custom <see cref="OpenAIClient"/>.</param>
/// <param name="openAIClient"><see cref="OpenAIClient"/> to use for the service. If null, one must be available in the service provider when this service is resolved.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <param name="modelId">Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <returns>The same instance as <paramref name="builder"/>.</returns>
[Experimental("SKEXP0011")]
public static IKernelBuilder AddAzureOpenAITextEmbeddingGeneration(
this IKernelBuilder builder,
string deploymentName,
OpenAIClient openAIClient,
OpenAIClient? openAIClient = null,
string? serviceId = null,
string? modelId = null)
{
Verify.NotNull(builder);
Verify.NotNullOrWhiteSpace(deploymentName);
Verify.NotNull(openAIClient);

builder.Services.AddKeyedSingleton<ITextEmbeddingGenerationService>(serviceId, (serviceProvider, _) =>
new AzureOpenAITextEmbeddingGenerationService(
deploymentName,
openAIClient,
openAIClient ?? serviceProvider.GetRequiredService<OpenAIClient>(),
modelId,
serviceProvider.GetService<ILoggerFactory>()));

Expand All @@ -510,26 +505,25 @@ public static class OpenAIServiceCollectionExtensions
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> instance to augment.</param>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="openAIClient">Custom <see cref="OpenAIClient"/>.</param>
/// <param name="openAIClient"><see cref="OpenAIClient"/> to use for the service. If null, one must be available in the service provider when this service is resolved.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <param name="modelId">Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <returns>The same instance as <paramref name="services"/>.</returns>
[Experimental("SKEXP0011")]
public static IServiceCollection AddAzureOpenAITextEmbeddingGeneration(
this IServiceCollection services,
string deploymentName,
OpenAIClient openAIClient,
OpenAIClient? openAIClient = null,
string? serviceId = null,
string? modelId = null)
{
Verify.NotNull(services);
Verify.NotNullOrWhiteSpace(deploymentName);
Verify.NotNull(openAIClient);

return services.AddKeyedSingleton<ITextEmbeddingGenerationService>(serviceId, (serviceProvider, _) =>
new AzureOpenAITextEmbeddingGenerationService(
deploymentName,
openAIClient,
openAIClient ?? serviceProvider.GetRequiredService<OpenAIClient>(),
modelId,
serviceProvider.GetService<ILoggerFactory>()));
}
Expand Down Expand Up @@ -603,24 +597,23 @@ public static class OpenAIServiceCollectionExtensions
/// </summary>
/// <param name="builder">The <see cref="IServiceCollection"/> instance to augment.</param>
/// <param name="modelId">OpenAI model name, see https://platform.openai.com/docs/models</param>
/// <param name="openAIClient">Custom <see cref="OpenAIClient"/>.</param>
/// <param name="openAIClient"><see cref="OpenAIClient"/> to use for the service. If null, one must be available in the service provider when this service is resolved.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <returns>The same instance as <paramref name="builder"/>.</returns>
[Experimental("SKEXP0011")]
public static IKernelBuilder AddOpenAITextEmbeddingGeneration(
this IKernelBuilder builder,
string modelId,
OpenAIClient openAIClient,
OpenAIClient? openAIClient = null,
string? serviceId = null)
{
Verify.NotNull(builder);
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNull(openAIClient);

builder.Services.AddKeyedSingleton<ITextEmbeddingGenerationService>(serviceId, (serviceProvider, _) =>
new OpenAITextEmbeddingGenerationService(
modelId,
openAIClient,
openAIClient ?? serviceProvider.GetRequiredService<OpenAIClient>(),
serviceProvider.GetService<ILoggerFactory>()));

return builder;
Expand All @@ -631,23 +624,22 @@ public static class OpenAIServiceCollectionExtensions
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> instance to augment.</param>
/// <param name="modelId">The OpenAI model id.</param>
/// <param name="openAIClient">Custom <see cref="OpenAIClient"/>.</param>
/// <param name="openAIClient"><see cref="OpenAIClient"/> to use for the service. If null, one must be available in the service provider when this service is resolved.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <returns>The same instance as <paramref name="services"/>.</returns>
[Experimental("SKEXP0011")]
public static IServiceCollection AddOpenAITextEmbeddingGeneration(this IServiceCollection services,
string modelId,
OpenAIClient openAIClient,
OpenAIClient? openAIClient = null,
string? serviceId = null)
{
Verify.NotNull(services);
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNull(openAIClient);

return services.AddKeyedSingleton<ITextEmbeddingGenerationService>(serviceId, (serviceProvider, _) =>
new OpenAITextEmbeddingGenerationService(
modelId,
openAIClient,
openAIClient ?? serviceProvider.GetRequiredService<OpenAIClient>(),
serviceProvider.GetService<ILoggerFactory>()));
}

Expand Down Expand Up @@ -820,23 +812,22 @@ public static class OpenAIServiceCollectionExtensions
/// </summary>
/// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="openAIClient">Custom <see cref="OpenAIClient"/>.</param>
/// <param name="openAIClient"><see cref="OpenAIClient"/> to use for the service. If null, one must be available in the service provider when this service is resolved.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <param name="modelId">Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <returns>The same instance as <paramref name="builder"/>.</returns>
public static IKernelBuilder AddAzureOpenAIChatCompletion(
this IKernelBuilder builder,
string deploymentName,
OpenAIClient openAIClient,
OpenAIClient? openAIClient = null,
string? serviceId = null,
string? modelId = null)
{
Verify.NotNull(builder);
Verify.NotNullOrWhiteSpace(deploymentName);
Verify.NotNull(openAIClient);

Func<IServiceProvider, object?, AzureOpenAIChatCompletionService> factory = (serviceProvider, _) =>
new(deploymentName, openAIClient, modelId, serviceProvider.GetService<ILoggerFactory>());
new(deploymentName, openAIClient ?? serviceProvider.GetRequiredService<OpenAIClient>(), modelId, serviceProvider.GetService<ILoggerFactory>());

builder.Services.AddKeyedSingleton<IChatCompletionService>(serviceId, factory);
builder.Services.AddKeyedSingleton<ITextGenerationService>(serviceId, factory);
Expand All @@ -849,23 +840,22 @@ public static class OpenAIServiceCollectionExtensions
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> instance to augment.</param>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="openAIClient">Custom <see cref="OpenAIClient"/>.</param>
/// <param name="openAIClient"><see cref="OpenAIClient"/> to use for the service. If null, one must be available in the service provider when this service is resolved.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <param name="modelId">Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <returns>The same instance as <paramref name="services"/>.</returns>
public static IServiceCollection AddAzureOpenAIChatCompletion(
this IServiceCollection services,
string deploymentName,
OpenAIClient openAIClient,
OpenAIClient? openAIClient = null,
string? serviceId = null,
string? modelId = null)
{
Verify.NotNull(services);
Verify.NotNullOrWhiteSpace(deploymentName);
Verify.NotNull(openAIClient);

Func<IServiceProvider, object?, AzureOpenAIChatCompletionService> factory = (serviceProvider, _) =>
new(deploymentName, openAIClient, modelId, serviceProvider.GetService<ILoggerFactory>());
new(deploymentName, openAIClient ?? serviceProvider.GetRequiredService<OpenAIClient>(), modelId, serviceProvider.GetService<ILoggerFactory>());

services.AddKeyedSingleton<IChatCompletionService>(serviceId, factory);
services.AddKeyedSingleton<ITextGenerationService>(serviceId, factory);
Expand Down Expand Up @@ -1006,21 +996,20 @@ public static class OpenAIServiceCollectionExtensions
/// </summary>
/// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
/// <param name="modelId">OpenAI model id</param>
/// <param name="openAIClient">Custom <see cref="OpenAIClient"/> for HTTP requests.</param>
/// <param name="openAIClient"><see cref="OpenAIClient"/> to use for the service. If null, one must be available in the service provider when this service is resolved.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <returns>The same instance as <paramref name="builder"/>.</returns>
public static IKernelBuilder AddOpenAIChatCompletion(
this IKernelBuilder builder,
string modelId,
OpenAIClient openAIClient,
OpenAIClient? openAIClient = null,
string? serviceId = null)
{
Verify.NotNull(builder);
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNull(openAIClient);

Func<IServiceProvider, object?, OpenAIChatCompletionService> factory = (serviceProvider, _) =>
new(modelId, openAIClient, serviceProvider.GetService<ILoggerFactory>());
new(modelId, openAIClient ?? serviceProvider.GetRequiredService<OpenAIClient>(), serviceProvider.GetService<ILoggerFactory>());

builder.Services.AddKeyedSingleton<IChatCompletionService>(serviceId, factory);
builder.Services.AddKeyedSingleton<ITextGenerationService>(serviceId, factory);
Expand All @@ -1033,20 +1022,19 @@ public static class OpenAIServiceCollectionExtensions
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> instance to augment.</param>
/// <param name="modelId">OpenAI model id</param>
/// <param name="openAIClient">Custom <see cref="OpenAIClient"/> for HTTP requests.</param>
/// <param name="openAIClient"><see cref="OpenAIClient"/> to use for the service. If null, one must be available in the service provider when this service is resolved.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <returns>The same instance as <paramref name="services"/>.</returns>
public static IServiceCollection AddOpenAIChatCompletion(this IServiceCollection services,
string modelId,
OpenAIClient openAIClient,
OpenAIClient? openAIClient = null,
string? serviceId = null)
{
Verify.NotNull(services);
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNull(openAIClient);

Func<IServiceProvider, object?, OpenAIChatCompletionService> factory = (serviceProvider, _) =>
new(modelId, openAIClient, serviceProvider.GetService<ILoggerFactory>());
new(modelId, openAIClient ?? serviceProvider.GetRequiredService<OpenAIClient>(), serviceProvider.GetService<ILoggerFactory>());

services.AddKeyedSingleton<IChatCompletionService>(serviceId, factory);
services.AddKeyedSingleton<ITextGenerationService>(serviceId, factory);
Expand Down
Loading