Skip to content

Commit

Permalink
Some clean up in OpenAIClientBase (microsoft#788)
Browse files Browse the repository at this point in the history
### Motivation and Context
Fix some possible functional issues as well as remove unnecessary code.

### Description
- The type accepts an HttpClient instance that it'll use. It shouldn't
then mutate that instance, as it might be shared with other consumers.
But it's adding to its DefaultRequestHeaders collection, which will
impact all other users of that instance (DefaultRequestHeaders also
isn't thread-safe and should only ever be added to while it's not being
used for issuing requests). I've removed the interaction with
DefaultRequestHeaders and instead added the relevant headers to each
message being sent.
- The base is exposing a lot of protected surface area, but it's not
actually extensible outside of the assembly. I've changed all of its
internal and protected members to instead be "private protected" or just
"private", reflecting their actual usage and ensuring the surface area
doesn't show up in public documentation (which protected members would).
- As it's not actually extensible publicly, and as no derived types are
finalizable, it needn't implement the full IDisposable pattern nor use
GC.SuppressFinalize.
- Both the base and the derived image generation client had duplicated
APIs for url vs base64 image generation. I've consolidated that via
arguments to just have one copy of the code. Note that the base64 image
generation was private and unused, so I've deleted that unused member,
but the functionality can be exposed easily if desired based on the
shared routines now in place.
- Removed some unnecessary null checking for things that'll never be
null.
- Ensured the HttpClient response message is disposed.
  • Loading branch information
stephentoub committed May 4, 2023
1 parent 4b4659a commit 68cfd5a
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Net.Http;
using System.Text;
Expand All @@ -21,39 +20,34 @@

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.CustomClient;

/// <summary>
/// An abstract OpenAI Client.
/// </summary>
[SuppressMessage("Design", "CA1054:URI-like parameters should not be strings", Justification = "OpenAI users use strings")]
#pragma warning disable CA1063 // Class isn't publicly extensible and thus doesn't implement the full IDisposable pattern
#pragma warning disable CA1816 // No derived types implement a finalizer

/// <summary>Base type for OpenAI clients.</summary>
public abstract class OpenAIClientBase : IDisposable
{
protected static readonly HttpClientHandler DefaultHttpClientHandler = new() { CheckCertificateRevocationList = true };

/// <summary>
/// Logger
/// </summary>
protected ILogger Log { get; } = NullLogger.Instance;
/// <summary>Initialize the client.</summary>
private protected OpenAIClientBase(HttpClient? httpClient = null, ILogger? logger = null)
{
this._httpClient = httpClient ?? new HttpClient(s_defaultHttpClientHandler, disposeHandler: false);
this._disposeHttpClient = this._httpClient != httpClient; // dispose a non-shared client when this is disposed

/// <summary>
/// HTTP client
/// </summary>
protected HttpClient HTTPClient { get; }
this._log = logger ?? NullLogger.Instance;
}

internal OpenAIClientBase(HttpClient? httpClient = null, ILogger? logger = null)
/// <summary>Clean up resources used by this instance.</summary>
public void Dispose()
{
this.Log = logger ?? this.Log;

if (httpClient == null)
{
this.HTTPClient = new HttpClient(DefaultHttpClientHandler, disposeHandler: false);
this._disposeHttpClient = true; // If client is created internally, dispose it when done
}
else
if (this._disposeHttpClient)
{
this.HTTPClient = httpClient;
this._httpClient.Dispose();
}
}

this.HTTPClient.DefaultRequestHeaders.Add("User-Agent", HTTPUserAgent);
/// <summary>Adds headers to use for OpenAI HTTP requests.</summary>
private protected virtual void AddRequestHeaders(HttpRequestMessage request)
{
request.Headers.Add("User-Agent", HttpUserAgent);
}

/// <summary>
Expand All @@ -64,112 +58,43 @@ internal OpenAIClientBase(HttpClient? httpClient = null, ILogger? logger = null)
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of text embeddings</returns>
/// <exception cref="AIException">AIException thrown during the request.</exception>
protected async Task<IList<Embedding<float>>> ExecuteTextEmbeddingRequestAsync(
private protected async Task<IList<Embedding<float>>> ExecuteTextEmbeddingRequestAsync(
string url,
string requestBody,
CancellationToken cancellationToken = default)
{
try
{
var result = await this.ExecutePostRequestAsync<TextEmbeddingResponse>(url, requestBody, cancellationToken).ConfigureAwait(false);
if (result.Embeddings.Count < 1)
{
throw new AIException(
AIException.ErrorCodes.InvalidResponseContent,
"Embeddings not found");
}

return result.Embeddings.Select(e => new Embedding<float>(e.Values)).ToList();
}
catch (Exception e) when (e is not AIException)
var result = await this.ExecutePostRequestAsync<TextEmbeddingResponse>(url, requestBody, cancellationToken).ConfigureAwait(false);
if (result.Embeddings is not { Count: >= 1 })
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);
AIException.ErrorCodes.InvalidResponseContent,
"Embeddings not found");
}
}

/// <summary>
/// Run the HTTP request to generate a list of images
/// </summary>
/// <param name="url">URL for the image generation request API</param>
/// <param name="requestBody">Request payload</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of image URLs</returns>
/// <exception cref="AIException">AIException thrown during the request.</exception>
protected async Task<IList<string>> ExecuteImageUrlGenerationRequestAsync(
string url,
string requestBody,
CancellationToken cancellationToken = default)
{
try
{
var result = await this.ExecutePostRequestAsync<ImageGenerationResponse>(url, requestBody, cancellationToken).ConfigureAwait(false);
return result.Images.Select(x => x.Url).ToList();
}
catch (Exception e) when (e is not AIException)
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);
}
return result.Embeddings.Select(e => new Embedding<float>(e.Values)).ToList();
}

/// <summary>
/// Run the HTTP request to generate a list of images
/// </summary>
/// <param name="url">URL for the image generation request API</param>
/// <param name="requestBody">Request payload</param>
/// <param name="extractResponseFunc">Function to invoke to extract the desired portion of the image generation response.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of images serialized in base64</returns>
/// <returns>List of image URLs</returns>
/// <exception cref="AIException">AIException thrown during the request.</exception>
protected async Task<IList<string>> ExecuteImageBase64GenerationRequestAsync(
private protected async Task<IList<string>> ExecuteImageGenerationRequestAsync(
string url,
string requestBody,
Func<ImageGenerationResponse.Image, string> extractResponseFunc,
CancellationToken cancellationToken = default)
{
try
{
var result = await this.ExecutePostRequestAsync<ImageGenerationResponse>(url, requestBody, cancellationToken).ConfigureAwait(false);
return result.Images.Select(x => x.AsBase64).ToList();
}
catch (Exception e) when (e is not AIException)
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);
}
}

/// <summary>
/// Explicit finalizer called by IDisposable
/// </summary>
public void Dispose()
{
this.Dispose(true);
// Request CL runtime not to call the finalizer - reduce cost of GC
GC.SuppressFinalize(this);
}

/// <summary>
/// Overridable finalizer for concrete classes
/// </summary>
/// <param name="disposing"></param>
protected virtual void Dispose(bool disposing)
{
if (disposing & this._disposeHttpClient)
{
this.HTTPClient.Dispose();
}
var result = await this.ExecutePostRequestAsync<ImageGenerationResponse>(url, requestBody, cancellationToken).ConfigureAwait(false);
return result.Images.Select(extractResponseFunc).ToList();
}

protected virtual string? GetErrorMessageFromResponse(string? jsonResponsePayload)
private protected virtual string? GetErrorMessageFromResponse(string jsonResponsePayload)
{
if (jsonResponsePayload is null)
{
return null;
}

try
{
JsonNode? root = JsonSerializer.Deserialize<JsonNode>(jsonResponsePayload);
Expand All @@ -178,33 +103,48 @@ protected virtual void Dispose(bool disposing)
}
catch (Exception ex) when (ex is NotSupportedException or JsonException)
{
this.Log.LogTrace("Unable to extract error from response body content. Exception: {0}:{1}", ex.GetType(), ex.Message);
return null;
this._log.LogTrace("Unable to extract error from response body content. Exception: {0}:{1}", ex.GetType(), ex.Message);
}

return null;
}

#region private ================================================================================

// Shared singleton HttpClientHandler used when an existing HttpClient isn't provided
private static readonly HttpClientHandler s_defaultHttpClientHandler = new() { CheckCertificateRevocationList = true };

// HTTP user agent sent to remote endpoints
private const string HTTPUserAgent = "Microsoft-Semantic-Kernel";
private const string HttpUserAgent = "Microsoft-Semantic-Kernel";

// Set to true to dispose of HttpClient when disposing. If HttpClient was passed in, then the caller can manage.
private readonly bool _disposeHttpClient = false;
private readonly bool _disposeHttpClient;

/// <summary>
/// Logger
/// </summary>
private readonly ILogger _log;

/// <summary>
/// The <see cref="_httpClient"/> to use for issuing requests.
/// </summary>
private readonly HttpClient _httpClient;

private async Task<T> ExecutePostRequestAsync<T>(string url, string requestBody, CancellationToken cancellationToken = default)
{
string responseJson;

HttpResponseMessage? response = null;
try
{
using HttpContent content = new StringContent(requestBody, Encoding.UTF8, "application/json");

HttpResponseMessage response = await this.HTTPClient.PostAsync(url, content, cancellationToken).ConfigureAwait(false)
?? throw new AIException(AIException.ErrorCodes.NoResponse);
using (var request = new HttpRequestMessage(HttpMethod.Post, url))
{
this.AddRequestHeaders(request);
request.Content = new StringContent(requestBody, Encoding.UTF8, "application/json");
response = await this._httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false);
}

this.Log.LogTrace("HTTP response: {0} {1}", (int)response.StatusCode, response.StatusCode.ToString("G"));
this._log.LogTrace("HTTP response: {0} {1}", (int)response.StatusCode, response.StatusCode.ToString("G"));

responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
string responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
string? errorDetail = this.GetErrorMessageFromResponse(responseJson);

if (!response.IsSuccessStatusCode)
Expand Down Expand Up @@ -276,29 +216,27 @@ private async Task<T> ExecutePostRequestAsync<T>(string url, string requestBody,
errorDetail);
}
}
}
catch (Exception e) when (e is not AIException)
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);
}

try
{
var result = Json.Deserialize<T>(responseJson);
if (result != null) { return result; }

throw new AIException(
if (result is null)
{
throw new AIException(
AIException.ErrorCodes.InvalidResponseContent,
"Response JSON parse error");
}

return result;
}
catch (Exception e) when (e is not AIException)
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);
}
finally
{
response?.Dispose();
}
}

#endregion
Expand Down

0 comments on commit 68cfd5a

Please sign in to comment.