Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing custom HttpClient #743

Merged
merged 10 commits into from
May 1, 2023
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Net.Http;
using Azure;
using Azure.AI.OpenAI;
using Azure.Core;
using Azure.Core.Pipeline;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Reliability;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk;

Expand All @@ -23,27 +24,26 @@ public abstract class AzureOpenAIClientBase : ClientBase
/// <param name="modelId">Azure OpenAI model ID or deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="endpoint">Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="apiKey">Azure OpenAI API key, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="handlerFactory">Retry handler factory for HTTP requests.</param>
/// <param name="log">Application logger</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="logger">Application logger</param>
protected AzureOpenAIClientBase(
string modelId,
string endpoint,
string apiKey,
IDelegatingHandlerFactory? handlerFactory = null,
ILogger? log = null)
HttpClient? httpClient = null,
ILogger? logger = null)
{
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNullOrWhiteSpace(endpoint);
Verify.StartsWith(endpoint, "https://", "The Azure OpenAI endpoint must start with 'https://'");
Verify.NotNullOrWhiteSpace(apiKey);

var options = new OpenAIClientOptions();
// TODO: reimplement
// Doesn't work
// if (handlerFactory != null)
// {
// options.Transport = new HttpClientTransport(handlerFactory.Create(log));
// }

if (httpClient != null)
{
options.Transport = new HttpClientTransport(httpClient);
};

this.ModelId = modelId;
this.Client = new OpenAIClient(new Uri(endpoint), new AzureKeyCredential(apiKey), options);
Expand All @@ -54,29 +54,27 @@ public abstract class AzureOpenAIClientBase : ClientBase
/// </summary>
/// <param name="modelId">Azure OpenAI model ID or deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="endpoint">Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="credentials">Token credentials, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc.</param>
/// <param name="handlerFactory">Retry handler factory for HTTP requests.</param>
/// <param name="credential">Token credential, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc.</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="log">Application logger</param>
protected AzureOpenAIClientBase(
string modelId,
string endpoint,
TokenCredential credentials,
IDelegatingHandlerFactory? handlerFactory = null,
TokenCredential credential,
HttpClient? httpClient = null,
ILogger? log = null)
{
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNullOrWhiteSpace(endpoint);
Verify.StartsWith(endpoint, "https://", "The Azure OpenAI endpoint must start with 'https://'");

var options = new OpenAIClientOptions();
// TODO: reimplement
// Doesn't work
// if (handlerFactory != null)
// {
// options.Transport = new HttpClientTransport(handlerFactory.Create(log));
// }
if (httpClient != null)
{
options.Transport = new HttpClientTransport(httpClient);
}

this.ModelId = modelId;
this.Client = new OpenAIClient(new Uri(endpoint), credentials, options);
this.Client = new OpenAIClient(new Uri(endpoint), credential, options);
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net.Http;
using Azure.AI.OpenAI;
using Azure.Core;
using Azure.Core.Pipeline;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Reliability;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk;

Expand All @@ -21,28 +22,26 @@ public abstract class OpenAIClientBase : ClientBase
/// <param name="modelId">Model name</param>
/// <param name="apiKey">OpenAI API Key</param>
/// <param name="organization">OpenAI Organization Id (usually optional)</param>
/// <param name="handlerFactory">Retry handler factory for HTTP requests.</param>
/// <param name="log">Application logger</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="logger">Application logger</param>
protected OpenAIClientBase(
string modelId,
string apiKey,
string? organization = null,
IDelegatingHandlerFactory? handlerFactory = null,
ILogger? log = null)
HttpClient? httpClient = null,
ILogger? logger = null
)
{
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNullOrWhiteSpace(apiKey);

this.ModelId = modelId;

var options = new OpenAIClientOptions();

// TODO: reimplement
// Doesn't work
// if (handlerFactory != null)
// {
// options.Transport = new HttpClientTransport(handlerFactory.Create(log));
// }
if (httpClient != null)
{
options.Transport = new HttpClientTransport(httpClient);
};

if (!string.IsNullOrWhiteSpace(organization))
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.AI.ChatCompletion;
using Microsoft.SemanticKernel.AI.TextCompletion;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk;
using Microsoft.SemanticKernel.Reliability;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.ChatCompletion;

Expand All @@ -23,14 +23,14 @@ public sealed class AzureChatCompletion : AzureOpenAIClientBase, IChatCompletion
/// <param name="modelId">Azure OpenAI model ID or deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="endpoint">Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="apiKey">Azure OpenAI API key, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="handlerFactory">Retry handler factory for HTTP requests.</param>
/// <param name="log">Application logger</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="logger">Application logger</param>
public AzureChatCompletion(
string modelId,
string endpoint,
string apiKey,
IDelegatingHandlerFactory? handlerFactory = null,
ILogger? log = null) : base(modelId, endpoint, apiKey, handlerFactory, log)
HttpClient? httpClient = null,
ILogger? logger = null) : base(modelId, endpoint, apiKey, httpClient, logger)
{
}

Expand All @@ -40,14 +40,14 @@ public sealed class AzureChatCompletion : AzureOpenAIClientBase, IChatCompletion
/// <param name="modelId">Azure OpenAI model ID or deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="endpoint">Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="credentials">Token credentials, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc.</param>
/// <param name="handlerFactory">Retry handler factory for HTTP requests.</param>
/// <param name="log">Application logger</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="logger">Application logger</param>
public AzureChatCompletion(
string modelId,
string endpoint,
TokenCredential credentials,
IDelegatingHandlerFactory? handlerFactory = null,
ILogger? log = null) : base(modelId, endpoint, credentials, handlerFactory, log)
HttpClient? httpClient = null,
ILogger? logger = null) : base(modelId, endpoint, credentials, httpClient, logger)
{
}

Expand All @@ -57,9 +57,7 @@ public sealed class AzureChatCompletion : AzureOpenAIClientBase, IChatCompletion
ChatRequestSettings? requestSettings = null,
CancellationToken cancellationToken = default)
{
if (requestSettings == null) { requestSettings = new ChatRequestSettings(); }

return this.InternalGenerateChatMessageAsync(chat, requestSettings, cancellationToken);
return this.InternalGenerateChatMessageAsync(chat, requestSettings ?? new(), cancellationToken);
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.AI.ChatCompletion;
using Microsoft.SemanticKernel.AI.TextCompletion;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk;
using Microsoft.SemanticKernel.Reliability;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.ChatCompletion;

Expand All @@ -22,15 +22,15 @@ public sealed class OpenAIChatCompletion : OpenAIClientBase, IChatCompletion, IT
/// <param name="modelId">Model name</param>
/// <param name="apiKey">OpenAI API Key</param>
/// <param name="organization">OpenAI Organization Id (usually optional)</param>
/// <param name="handlerFactory">Retry handler factory for HTTP requests.</param>
/// <param name="log">Application logger</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="logger">Application logger</param>
public OpenAIChatCompletion(
string modelId,
string apiKey,
string? organization = null,
IDelegatingHandlerFactory? handlerFactory = null,
ILogger? log = null
) : base(modelId, apiKey, organization, handlerFactory, log)
HttpClient? httpClient = null,
ILogger? logger = null
) : base(modelId, apiKey, organization, httpClient, logger)
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.ImageGeneration;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextEmbedding;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Reliability;
using Microsoft.SemanticKernel.Text;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.CustomClient;
Expand All @@ -28,6 +27,8 @@ namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.CustomClient;
[SuppressMessage("Design", "CA1054:URI-like parameters should not be strings", Justification = "OpenAI users use strings")]
public abstract class OpenAIClientBase : IDisposable
{
protected readonly static HttpClientHandler DefaultHttpClientHandler = new() { CheckCertificateRevocationList = true };

/// <summary>
/// Logger
/// </summary>
Expand All @@ -38,20 +39,20 @@ public abstract class OpenAIClientBase : IDisposable
/// </summary>
protected HttpClient HTTPClient { get; }

private readonly HttpClientHandler _httpClientHandler;
private readonly IDelegatingHandlerFactory _handlerFactory;
private readonly DelegatingHandler _retryHandler;

internal OpenAIClientBase(ILogger? log = null, IDelegatingHandlerFactory? handlerFactory = null)
internal OpenAIClientBase(HttpClient? httpClient = null, ILogger? logger = null)
{
this.Log = log ?? this.Log;
this._handlerFactory = handlerFactory ?? new DefaultHttpRetryHandlerFactory();
this.Log = logger ?? this.Log;

this._httpClientHandler = new() { CheckCertificateRevocationList = true };
this._retryHandler = this._handlerFactory.Create(this.Log);
this._retryHandler.InnerHandler = this._httpClientHandler;
if (httpClient == null)
{
this.HTTPClient = new HttpClient(DefaultHttpClientHandler, disposeHandler: false);
this._disposeHttpClient = true; // If client is created internally, dispose it when done
}
else
{
this.HTTPClient = httpClient;
}

this.HTTPClient = new HttpClient(this._retryHandler);
this.HTTPClient.DefaultRequestHeaders.Add("User-Agent", HTTPUseragent);
}

Expand Down Expand Up @@ -156,11 +157,9 @@ public void Dispose()
/// <param name="disposing"></param>
protected virtual void Dispose(bool disposing)
{
if (disposing)
if (disposing & this._disposeHttpClient)
{
this.HTTPClient.Dispose();
this._httpClientHandler.Dispose();
this._retryHandler.Dispose();
}
}

Expand Down Expand Up @@ -189,6 +188,9 @@ protected virtual void Dispose(bool disposing)
// HTTP user agent sent to remote endpoints
private const string HTTPUseragent = "Microsoft Semantic Kernel";

// Set to true to dispose of HttpClient when disposing. If HttpClient was passed in, then the caller can manage.
private readonly bool _disposeHttpClient = false;

private async Task<T> ExecutePostRequestAsync<T>(string url, string requestBody, CancellationToken cancellationToken = default)
{
string responseJson;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -9,7 +10,6 @@
using Microsoft.SemanticKernel.AI.ImageGeneration;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.CustomClient;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Reliability;
using Microsoft.SemanticKernel.Text;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.ImageGeneration;
Expand All @@ -24,14 +24,14 @@ public class OpenAIImageGeneration : OpenAIClientBase, IImageGeneration
/// </summary>
/// <param name="apiKey">OpenAI API key, see https://platform.openai.com/account/api-keys</param>
/// <param name="organization">OpenAI organization id. This is usually optional unless your account belongs to multiple organizations.</param>
/// <param name="handlerFactory">Retry handler</param>
/// <param name="log">Logger</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="logger">Application logger</param>
public OpenAIImageGeneration(
string apiKey,
string? organization = null,
IDelegatingHandlerFactory? handlerFactory = null,
ILogger? log = null
)
HttpClient? httpClient = null,
ILogger? logger = null
) : base(httpClient, logger)
{
Verify.NotNullOrWhiteSpace(apiKey);
this.HTTPClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey);
Expand Down