Skip to content

Commit

Permalink
Allow passing custom HttpClient (#228)
Browse files Browse the repository at this point in the history
## Motivation and Context

Some businesses and teams need more detailed, customized HttpClient
configuration when making network requests. This PR brings the
corresponding features on par with the design of Semantic Kernel.
Reference:
microsoft/semantic-kernel#743

## Description

Introduce a new optional parameter named `httpClient`. Defaulting it to
`null` keeps existing users' code unaffected and preserves backward
compatibility.
  • Loading branch information
qihangnet committed Jan 4, 2024
1 parent 957b9f6 commit 939b2fd
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 38 deletions.
15 changes: 10 additions & 5 deletions extensions/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Azure.Identity;
Expand All @@ -21,15 +22,17 @@ public class AzureOpenAITextEmbeddingGenerator : ITextEmbeddingGenerator
public AzureOpenAITextEmbeddingGenerator(
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null,
ILoggerFactory? loggerFactory = null)
: this(config, textTokenizer, loggerFactory?.CreateLogger<AzureOpenAITextEmbeddingGenerator>())
ILoggerFactory? loggerFactory = null,
HttpClient? httpClient = null)
: this(config, textTokenizer, loggerFactory?.CreateLogger<AzureOpenAITextEmbeddingGenerator>(), httpClient)
{
}

public AzureOpenAITextEmbeddingGenerator(
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null,
ILogger<AzureOpenAITextEmbeddingGenerator>? log = null)
ILogger<AzureOpenAITextEmbeddingGenerator>? log = null,
HttpClient? httpClient = null)
{
this._log = log ?? DefaultLogger<AzureOpenAITextEmbeddingGenerator>.Instance;

Expand All @@ -52,15 +55,17 @@ public class AzureOpenAITextEmbeddingGenerator : ITextEmbeddingGenerator
deploymentName: config.Deployment,
modelId: config.Deployment,
endpoint: config.Endpoint,
credential: new DefaultAzureCredential());
credential: new DefaultAzureCredential(),
httpClient: httpClient);
break;

case AzureOpenAIConfig.AuthTypes.APIKey:
this._client = new AzureOpenAITextEmbeddingGenerationService(
deploymentName: config.Deployment,
modelId: config.Deployment,
endpoint: config.Endpoint,
apiKey: config.APIKey);
apiKey: config.APIKey,
httpClient: httpClient);
break;

default:
Expand Down
14 changes: 11 additions & 3 deletions extensions/AzureOpenAI/AzureOpenAITextGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -27,15 +28,17 @@ public class AzureOpenAITextGenerator : ITextGenerator
public AzureOpenAITextGenerator(
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null,
ILoggerFactory? loggerFactory = null)
: this(config, textTokenizer, loggerFactory?.CreateLogger<AzureOpenAITextGenerator>())
ILoggerFactory? loggerFactory = null,
HttpClient? httpClient = null)
: this(config, textTokenizer, loggerFactory?.CreateLogger<AzureOpenAITextGenerator>(), httpClient)
{
}

public AzureOpenAITextGenerator(
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null,
ILogger<AzureOpenAITextGenerator>? log = null)
ILogger<AzureOpenAITextGenerator>? log = null,
HttpClient? httpClient = null)
{
this._log = log ?? DefaultLogger<AzureOpenAITextGenerator>.Instance;

Expand Down Expand Up @@ -73,6 +76,11 @@ public class AzureOpenAITextGenerator : ITextGenerator
}
};

if (httpClient is not null)
{
options.Transport = new HttpClientTransport(httpClient);
}

switch (config.Auth)
{
case AzureOpenAIConfig.AuthTypes.AzureIdentity:
Expand Down
26 changes: 18 additions & 8 deletions extensions/AzureOpenAI/DependencyInjection.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI;
Expand All @@ -20,13 +21,15 @@ public static partial class KernelMemoryBuilderExtensions
/// <param name="textTokenizer">Tokenizer used to count tokens sent to the embedding generator</param>
/// <param name="loggerFactory">.NET Logger factory</param>
/// <param name="onlyForRetrieval">Whether to use this embedding generator only during data ingestion, and not for retrieval (search and ask API)</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <returns>KM builder instance</returns>
public static IKernelMemoryBuilder WithAzureOpenAITextEmbeddingGeneration(
this IKernelMemoryBuilder builder,
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null,
ILoggerFactory? loggerFactory = null,
bool onlyForRetrieval = false)
bool onlyForRetrieval = false,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();
Expand All @@ -38,7 +41,8 @@ public static partial class KernelMemoryBuilderExtensions
new AzureOpenAITextEmbeddingGenerator(
config: config,
textTokenizer: textTokenizer,
loggerFactory: loggerFactory));
loggerFactory: loggerFactory,
httpClient));
}

return builder;
Expand All @@ -50,15 +54,17 @@ public static partial class KernelMemoryBuilderExtensions
/// <param name="builder">Kernel Memory builder</param>
/// <param name="config">Azure OpenAI settings</param>
/// <param name="textTokenizer">Tokenizer used to count tokens used by prompts</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <returns>KM builder instance</returns>
public static IKernelMemoryBuilder WithAzureOpenAITextGeneration(
this IKernelMemoryBuilder builder,
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null)
ITextTokenizer? textTokenizer = null,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();
builder.Services.AddAzureOpenAITextGeneration(config, textTokenizer);
builder.Services.AddAzureOpenAITextGeneration(config, textTokenizer, httpClient);
return builder;
}
}
Expand All @@ -68,28 +74,32 @@ public static partial class DependencyInjection
public static IServiceCollection AddAzureOpenAIEmbeddingGeneration(
this IServiceCollection services,
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null)
ITextTokenizer? textTokenizer = null,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();
return services
.AddSingleton<ITextEmbeddingGenerator>(serviceProvider => new AzureOpenAITextEmbeddingGenerator(
config,
textTokenizer: textTokenizer,
loggerFactory: serviceProvider.GetService<ILoggerFactory>()));
loggerFactory: serviceProvider.GetService<ILoggerFactory>(),
httpClient));
}

public static IServiceCollection AddAzureOpenAITextGeneration(
this IServiceCollection services,
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null)
ITextTokenizer? textTokenizer = null,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();
return services
.AddSingleton<ITextGenerator>(serviceProvider => new AzureOpenAITextGenerator(
config: config,
textTokenizer: textTokenizer,
log: serviceProvider.GetService<ILogger<AzureOpenAITextGenerator>>()));
log: serviceProvider.GetService<ILogger<AzureOpenAITextGenerator>>(),
httpClient: httpClient));
}
}
44 changes: 29 additions & 15 deletions extensions/OpenAI/DependencyInjection.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI;
Expand All @@ -26,6 +27,7 @@ public static partial class KernelMemoryBuilderExtensions
/// <param name="textEmbeddingTokenizer">Tokenizer used to count tokens sent to the embedding generator</param>
/// <param name="loggerFactory">.NET Logger factory</param>
/// <param name="onlyForRetrieval">Whether to use OpenAI defaults only for ingestion, and not for retrieval (search and ask API)</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <returns>KM builder instance</returns>
public static IKernelMemoryBuilder WithOpenAIDefaults(
this IKernelMemoryBuilder builder,
Expand All @@ -34,7 +36,8 @@ public static partial class KernelMemoryBuilderExtensions
ITextTokenizer? textGenerationTokenizer = null,
ITextTokenizer? textEmbeddingTokenizer = null,
ILoggerFactory? loggerFactory = null,
bool onlyForRetrieval = false)
bool onlyForRetrieval = false,
HttpClient? httpClient = null)
{
textGenerationTokenizer ??= new DefaultGPTTokenizer();
textEmbeddingTokenizer ??= new DefaultGPTTokenizer();
Expand All @@ -50,15 +53,16 @@ public static partial class KernelMemoryBuilderExtensions
};
openAIConfig.Validate();

builder.Services.AddOpenAITextEmbeddingGeneration(openAIConfig, textEmbeddingTokenizer);
builder.Services.AddOpenAITextGeneration(openAIConfig, textGenerationTokenizer);
builder.Services.AddOpenAITextEmbeddingGeneration(openAIConfig, textEmbeddingTokenizer, httpClient);
builder.Services.AddOpenAITextGeneration(openAIConfig, textGenerationTokenizer, httpClient);

if (!onlyForRetrieval)
{
builder.AddIngestionEmbeddingGenerator(new OpenAITextEmbeddingGenerator(
config: openAIConfig,
textTokenizer: textEmbeddingTokenizer,
loggerFactory: loggerFactory));
loggerFactory: loggerFactory,
httpClient: httpClient));
}

return builder;
Expand All @@ -72,19 +76,21 @@ public static partial class KernelMemoryBuilderExtensions
/// <param name="textGenerationTokenizer">Tokenizer used to count tokens used by prompts</param>
/// <param name="textEmbeddingTokenizer">Tokenizer used to count tokens sent to the embedding generator</param>
/// <param name="onlyForRetrieval">Whether to use OpenAI only for ingestion, not for retrieval (search and ask API)</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <returns>KM builder instance</returns>
public static IKernelMemoryBuilder WithOpenAI(
this IKernelMemoryBuilder builder,
OpenAIConfig config,
ITextTokenizer? textGenerationTokenizer = null,
ITextTokenizer? textEmbeddingTokenizer = null,
bool onlyForRetrieval = false)
bool onlyForRetrieval = false,
HttpClient? httpClient = null)
{
config.Validate();
textGenerationTokenizer ??= new DefaultGPTTokenizer();
textEmbeddingTokenizer ??= new DefaultGPTTokenizer();

builder.WithOpenAITextEmbeddingGeneration(config, textEmbeddingTokenizer, onlyForRetrieval);
builder.WithOpenAITextEmbeddingGeneration(config, textEmbeddingTokenizer, onlyForRetrieval, httpClient);
builder.WithOpenAITextGeneration(config, textGenerationTokenizer);
return builder;
}
Expand All @@ -96,21 +102,23 @@ public static partial class KernelMemoryBuilderExtensions
/// <param name="config">OpenAI settings</param>
/// <param name="textTokenizer">Tokenizer used to count tokens sent to the embedding generator</param>
/// <param name="onlyForRetrieval">Whether to use OpenAI only for ingestion, not for retrieval (search and ask API)</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <returns>KM builder instance</returns>
public static IKernelMemoryBuilder WithOpenAITextEmbeddingGeneration(
this IKernelMemoryBuilder builder,
OpenAIConfig config,
ITextTokenizer? textTokenizer = null,
bool onlyForRetrieval = false)
bool onlyForRetrieval = false,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();

builder.Services.AddOpenAITextEmbeddingGeneration(config);
builder.Services.AddOpenAITextEmbeddingGeneration(config, httpClient: httpClient);
if (!onlyForRetrieval)
{
builder.AddIngestionEmbeddingGenerator(
new OpenAITextEmbeddingGenerator(config, textTokenizer, loggerFactory: null));
new OpenAITextEmbeddingGenerator(config, textTokenizer, loggerFactory: null, httpClient));
}

return builder;
Expand All @@ -122,16 +130,18 @@ public static partial class KernelMemoryBuilderExtensions
/// <param name="builder">Kernel Memory builder</param>
/// <param name="config">OpenAI settings</param>
/// <param name="textTokenizer">Tokenizer used to count tokens used by prompts</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <returns>KM builder instance</returns>
public static IKernelMemoryBuilder WithOpenAITextGeneration(
this IKernelMemoryBuilder builder,
OpenAIConfig config,
ITextTokenizer? textTokenizer = null)
ITextTokenizer? textTokenizer = null,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();

builder.Services.AddOpenAITextGeneration(config, textTokenizer);
builder.Services.AddOpenAITextGeneration(config, textTokenizer, httpClient);
return builder;
}
}
Expand All @@ -141,7 +151,8 @@ public static partial class DependencyInjection
public static IServiceCollection AddOpenAITextEmbeddingGeneration(
this IServiceCollection services,
OpenAIConfig config,
ITextTokenizer? textTokenizer = null)
ITextTokenizer? textTokenizer = null,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();
Expand All @@ -151,13 +162,15 @@ public static partial class DependencyInjection
serviceProvider => new OpenAITextEmbeddingGenerator(
config: config,
textTokenizer: textTokenizer,
loggerFactory: serviceProvider.GetService<ILoggerFactory>()));
loggerFactory: serviceProvider.GetService<ILoggerFactory>(),
httpClient));
}

public static IServiceCollection AddOpenAITextGeneration(
this IServiceCollection services,
OpenAIConfig config,
ITextTokenizer? textTokenizer = null)
ITextTokenizer? textTokenizer = null,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();
Expand All @@ -166,6 +179,7 @@ public static partial class DependencyInjection
.AddSingleton<ITextGenerator, OpenAITextGenerator>(serviceProvider => new OpenAITextGenerator(
config: config,
textTokenizer: textTokenizer,
loggerFactory: serviceProvider.GetService<ILoggerFactory>()));
loggerFactory: serviceProvider.GetService<ILoggerFactory>(),
httpClient));
}
}
12 changes: 8 additions & 4 deletions extensions/OpenAI/OpenAITextEmbeddingGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
Expand All @@ -18,15 +19,17 @@ public class OpenAITextEmbeddingGenerator : ITextEmbeddingGenerator
public OpenAITextEmbeddingGenerator(
OpenAIConfig config,
ITextTokenizer? textTokenizer = null,
ILoggerFactory? loggerFactory = null)
: this(config, textTokenizer, loggerFactory?.CreateLogger<OpenAITextEmbeddingGenerator>())
ILoggerFactory? loggerFactory = null,
HttpClient? httpClient = null)
: this(config, textTokenizer, loggerFactory?.CreateLogger<OpenAITextEmbeddingGenerator>(), httpClient)
{
}

public OpenAITextEmbeddingGenerator(
OpenAIConfig config,
ITextTokenizer? textTokenizer = null,
ILogger<OpenAITextEmbeddingGenerator>? log = null)
ILogger<OpenAITextEmbeddingGenerator>? log = null,
HttpClient? httpClient = null)
{
this._log = log ?? DefaultLogger<OpenAITextEmbeddingGenerator>.Instance;

Expand All @@ -45,7 +48,8 @@ public class OpenAITextEmbeddingGenerator : ITextEmbeddingGenerator
this._client = new OpenAITextEmbeddingGenerationService(
modelId: config.EmbeddingModel,
apiKey: config.APIKey,
organization: config.OrgId);
organization: config.OrgId,
httpClient: httpClient);
}

/// <inheritdoc/>
Expand Down

0 comments on commit 939b2fd

Please sign in to comment.