Skip to content

Commit

Permalink
Allow passing custom HttpClient (microsoft#743)
Browse files Browse the repository at this point in the history
### Motivation and Context
My company network needs more custom set up in HttpClient to make calls,
but SK OpenAI code only allows me to pass a hander delegate factory. I
cannot use SK without more settings in HttpClient and I want to pass in
my own instances anyway so I can manage them from my application.

I think I am not the only one to need this:
 - microsoft#527
 - microsoft#133

See
https://learn.microsoft.com/en-us/dotnet/fundamentals/networking/http/httpclient-guidelines.

### Description
Changes included:
- The OpenAI and AzureOpenAI connectors take optional HttpClient? in the
constructor methods instead of IDelegatingHandlerFactory
- Note first approach was adding two new constructor methods to each
type so you could choose. But two constructors could not be there
together with optional HttpClient? in one and optional
IDelegatingHandlerFactory? in the other because it made calls ambiguous.
Then I tried making HttpClient not optional, but that made the parameter
order different because in some there are also other optional parameters
like openai `string? organization = null`. So the not optional
HttpClient would have to come BEFORE optional openai organization in one
constructor, but the optional IDelegatingHandlerFactory comes AFTER
optional openai organization in another constructor. In the end it
seemed better to just take HttpClient and do the custom client creation
in the extension methods. I hope this is okay I know it is a bigger
change than I planned.
- IDelegatingHandlerFactory is not being used as a factory in these Open
AI connector classes anyway. They use it one to create an HttpClient
instance in the constructor method anyway and never call it again so
it's just as easy to pass in the HttpClient you want.
- Updated the extension methods in `KernelConfigOpenAIExtensions` so
that they can optionally pass in a custom HttpClient and ILogger. If
null then it uses the same defaults as before.
- There was some mix use of `Logger log` and `ILogger logger` in
different places so on the parts I changed I set these all to "logger"
because it makes it more clear that it is a object and not a "log"
method. I can undo if you don't like this.
- Changed "credentials" to "credential" to match this is one
`TokenCredenial` not a credential collection. I can undo if you don't
like this.
- Made the default HttpClientHandler static and shared for new
HttpClients with `disposeHandler: false`
dotnet/runtime#16255
 - Updated a sample Program
 - I can add more tests if required

### Contribution Checklist
<!-- Before submitting this PR, please make sure: -->
- [X] The code builds clean without any errors or warnings
- [X] The PR follows SK Contribution Guidelines
(https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
- [X] The code follows the .NET coding conventions
(https://learn.microsoft.com/dotnet/csharp/fundamentals/coding-style/coding-conventions)
verified with `dotnet format`
- [X] All unit tests pass, and I have added new tests where possible
- Most tests pass but it skips some because I don't have keys for all
these model types so I hope your github build will run this and double
check.
- [ ] I didn't break anyone 😄
- I changed the constructor method args in the OpenAI code, and the will
break your users that are creating these open ai classes using their
constructors directly. If the users are using the extension methods
(AddAzureOpenAI) then they should not have problems.
- If you want I can put back all of the constructor methods so they can
be created with HttpClient or IDelegatingHandlerFactory but with
parameters orders different as explained
- I don't think those classes need to be aware of
IDelegatingHandlerFactory or have it included in the constructor
methods, since the extension method or calling code can always convert
to HttpClient before calling.
  • Loading branch information
carlos-the-ai committed May 1, 2023
1 parent 84d43ad commit b44707d
Show file tree
Hide file tree
Showing 13 changed files with 274 additions and 139 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Net.Http;
using Azure;
using Azure.AI.OpenAI;
using Azure.Core;
using Azure.Core.Pipeline;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Reliability;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk;

Expand All @@ -23,27 +24,26 @@ public abstract class AzureOpenAIClientBase : ClientBase
/// <param name="modelId">Azure OpenAI model ID or deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="endpoint">Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="apiKey">Azure OpenAI API key, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="handlerFactory">Retry handler factory for HTTP requests.</param>
/// <param name="log">Application logger</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="logger">Application logger</param>
protected AzureOpenAIClientBase(
string modelId,
string endpoint,
string apiKey,
IDelegatingHandlerFactory? handlerFactory = null,
ILogger? log = null)
HttpClient? httpClient = null,
ILogger? logger = null)
{
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNullOrWhiteSpace(endpoint);
Verify.StartsWith(endpoint, "https://", "The Azure OpenAI endpoint must start with 'https://'");
Verify.NotNullOrWhiteSpace(apiKey);

var options = new OpenAIClientOptions();
// TODO: reimplement
// Doesn't work
// if (handlerFactory != null)
// {
// options.Transport = new HttpClientTransport(handlerFactory.Create(log));
// }

if (httpClient != null)
{
options.Transport = new HttpClientTransport(httpClient);
};

this.ModelId = modelId;
this.Client = new OpenAIClient(new Uri(endpoint), new AzureKeyCredential(apiKey), options);
Expand All @@ -54,29 +54,27 @@ public abstract class AzureOpenAIClientBase : ClientBase
/// </summary>
/// <param name="modelId">Azure OpenAI model ID or deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="endpoint">Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="credentials">Token credentials, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc.</param>
/// <param name="handlerFactory">Retry handler factory for HTTP requests.</param>
/// <param name="credential">Token credential, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc.</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="log">Application logger</param>
protected AzureOpenAIClientBase(
string modelId,
string endpoint,
TokenCredential credentials,
IDelegatingHandlerFactory? handlerFactory = null,
TokenCredential credential,
HttpClient? httpClient = null,
ILogger? log = null)
{
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNullOrWhiteSpace(endpoint);
Verify.StartsWith(endpoint, "https://", "The Azure OpenAI endpoint must start with 'https://'");

var options = new OpenAIClientOptions();
// TODO: reimplement
// Doesn't work
// if (handlerFactory != null)
// {
// options.Transport = new HttpClientTransport(handlerFactory.Create(log));
// }
if (httpClient != null)
{
options.Transport = new HttpClientTransport(httpClient);
}

this.ModelId = modelId;
this.Client = new OpenAIClient(new Uri(endpoint), credentials, options);
this.Client = new OpenAIClient(new Uri(endpoint), credential, options);
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net.Http;
using Azure.AI.OpenAI;
using Azure.Core;
using Azure.Core.Pipeline;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Reliability;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk;

Expand All @@ -21,28 +22,26 @@ public abstract class OpenAIClientBase : ClientBase
/// <param name="modelId">Model name</param>
/// <param name="apiKey">OpenAI API Key</param>
/// <param name="organization">OpenAI Organization Id (usually optional)</param>
/// <param name="handlerFactory">Retry handler factory for HTTP requests.</param>
/// <param name="log">Application logger</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="logger">Application logger</param>
protected OpenAIClientBase(
string modelId,
string apiKey,
string? organization = null,
IDelegatingHandlerFactory? handlerFactory = null,
ILogger? log = null)
HttpClient? httpClient = null,
ILogger? logger = null
)
{
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNullOrWhiteSpace(apiKey);

this.ModelId = modelId;

var options = new OpenAIClientOptions();

// TODO: reimplement
// Doesn't work
// if (handlerFactory != null)
// {
// options.Transport = new HttpClientTransport(handlerFactory.Create(log));
// }
if (httpClient != null)
{
options.Transport = new HttpClientTransport(httpClient);
};

if (!string.IsNullOrWhiteSpace(organization))
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.AI.ChatCompletion;
using Microsoft.SemanticKernel.AI.TextCompletion;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk;
using Microsoft.SemanticKernel.Reliability;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.ChatCompletion;

Expand All @@ -23,14 +23,14 @@ public sealed class AzureChatCompletion : AzureOpenAIClientBase, IChatCompletion
/// <param name="modelId">Azure OpenAI model ID or deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="endpoint">Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="apiKey">Azure OpenAI API key, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="handlerFactory">Retry handler factory for HTTP requests.</param>
/// <param name="log">Application logger</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="logger">Application logger</param>
public AzureChatCompletion(
string modelId,
string endpoint,
string apiKey,
IDelegatingHandlerFactory? handlerFactory = null,
ILogger? log = null) : base(modelId, endpoint, apiKey, handlerFactory, log)
HttpClient? httpClient = null,
ILogger? logger = null) : base(modelId, endpoint, apiKey, httpClient, logger)
{
}

Expand All @@ -40,14 +40,14 @@ public sealed class AzureChatCompletion : AzureOpenAIClientBase, IChatCompletion
/// <param name="modelId">Azure OpenAI model ID or deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="endpoint">Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="credentials">Token credentials, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc.</param>
/// <param name="handlerFactory">Retry handler factory for HTTP requests.</param>
/// <param name="log">Application logger</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="logger">Application logger</param>
public AzureChatCompletion(
string modelId,
string endpoint,
TokenCredential credentials,
IDelegatingHandlerFactory? handlerFactory = null,
ILogger? log = null) : base(modelId, endpoint, credentials, handlerFactory, log)
HttpClient? httpClient = null,
ILogger? logger = null) : base(modelId, endpoint, credentials, httpClient, logger)
{
}

Expand All @@ -57,9 +57,7 @@ public sealed class AzureChatCompletion : AzureOpenAIClientBase, IChatCompletion
ChatRequestSettings? requestSettings = null,
CancellationToken cancellationToken = default)
{
if (requestSettings == null) { requestSettings = new ChatRequestSettings(); }

return this.InternalGenerateChatMessageAsync(chat, requestSettings, cancellationToken);
return this.InternalGenerateChatMessageAsync(chat, requestSettings ?? new(), cancellationToken);
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.AI.ChatCompletion;
using Microsoft.SemanticKernel.AI.TextCompletion;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk;
using Microsoft.SemanticKernel.Reliability;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.ChatCompletion;

Expand All @@ -22,15 +22,15 @@ public sealed class OpenAIChatCompletion : OpenAIClientBase, IChatCompletion, IT
/// <param name="modelId">Model name</param>
/// <param name="apiKey">OpenAI API Key</param>
/// <param name="organization">OpenAI Organization Id (usually optional)</param>
/// <param name="handlerFactory">Retry handler factory for HTTP requests.</param>
/// <param name="log">Application logger</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="logger">Application logger</param>
public OpenAIChatCompletion(
string modelId,
string apiKey,
string? organization = null,
IDelegatingHandlerFactory? handlerFactory = null,
ILogger? log = null
) : base(modelId, apiKey, organization, handlerFactory, log)
HttpClient? httpClient = null,
ILogger? logger = null
) : base(modelId, apiKey, organization, httpClient, logger)
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.ImageGeneration;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextEmbedding;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Reliability;
using Microsoft.SemanticKernel.Text;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.CustomClient;
Expand All @@ -28,6 +27,8 @@ namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.CustomClient;
[SuppressMessage("Design", "CA1054:URI-like parameters should not be strings", Justification = "OpenAI users use strings")]
public abstract class OpenAIClientBase : IDisposable
{
protected readonly static HttpClientHandler DefaultHttpClientHandler = new() { CheckCertificateRevocationList = true };

/// <summary>
/// Logger
/// </summary>
Expand All @@ -38,20 +39,20 @@ public abstract class OpenAIClientBase : IDisposable
/// </summary>
protected HttpClient HTTPClient { get; }

private readonly HttpClientHandler _httpClientHandler;
private readonly IDelegatingHandlerFactory _handlerFactory;
private readonly DelegatingHandler _retryHandler;

internal OpenAIClientBase(ILogger? log = null, IDelegatingHandlerFactory? handlerFactory = null)
internal OpenAIClientBase(HttpClient? httpClient = null, ILogger? logger = null)
{
this.Log = log ?? this.Log;
this._handlerFactory = handlerFactory ?? new DefaultHttpRetryHandlerFactory();
this.Log = logger ?? this.Log;

this._httpClientHandler = new() { CheckCertificateRevocationList = true };
this._retryHandler = this._handlerFactory.Create(this.Log);
this._retryHandler.InnerHandler = this._httpClientHandler;
if (httpClient == null)
{
this.HTTPClient = new HttpClient(DefaultHttpClientHandler, disposeHandler: false);
this._disposeHttpClient = true; // If client is created internally, dispose it when done
}
else
{
this.HTTPClient = httpClient;
}

this.HTTPClient = new HttpClient(this._retryHandler);
this.HTTPClient.DefaultRequestHeaders.Add("User-Agent", HTTPUseragent);
}

Expand Down Expand Up @@ -156,11 +157,9 @@ public void Dispose()
/// <param name="disposing"></param>
protected virtual void Dispose(bool disposing)
{
if (disposing)
if (disposing & this._disposeHttpClient)
{
this.HTTPClient.Dispose();
this._httpClientHandler.Dispose();
this._retryHandler.Dispose();
}
}

Expand Down Expand Up @@ -189,6 +188,9 @@ protected virtual void Dispose(bool disposing)
// HTTP user agent sent to remote endpoints
private const string HTTPUseragent = "Microsoft Semantic Kernel";

// Set to true to dispose of HttpClient when disposing. If HttpClient was passed in, then the caller can manage.
private readonly bool _disposeHttpClient = false;

private async Task<T> ExecutePostRequestAsync<T>(string url, string requestBody, CancellationToken cancellationToken = default)
{
string responseJson;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -9,7 +10,6 @@
using Microsoft.SemanticKernel.AI.ImageGeneration;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.CustomClient;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Reliability;
using Microsoft.SemanticKernel.Text;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.ImageGeneration;
Expand All @@ -24,14 +24,14 @@ public class OpenAIImageGeneration : OpenAIClientBase, IImageGeneration
/// </summary>
/// <param name="apiKey">OpenAI API key, see https://platform.openai.com/account/api-keys</param>
/// <param name="organization">OpenAI organization id. This is usually optional unless your account belongs to multiple organizations.</param>
/// <param name="handlerFactory">Retry handler</param>
/// <param name="log">Logger</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="logger">Application logger</param>
public OpenAIImageGeneration(
string apiKey,
string? organization = null,
IDelegatingHandlerFactory? handlerFactory = null,
ILogger? log = null
)
HttpClient? httpClient = null,
ILogger? logger = null
) : base(httpClient, logger)
{
Verify.NotNullOrWhiteSpace(apiKey);
this.HTTPClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey);
Expand Down

0 comments on commit b44707d

Please sign in to comment.