Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some clean up in OpenAIClientBase #788

Merged
merged 4 commits into from
May 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Net.Http;
using System.Text;
Expand All @@ -21,39 +20,34 @@

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.CustomClient;

/// <summary>
/// An abstract OpenAI Client.
/// </summary>
[SuppressMessage("Design", "CA1054:URI-like parameters should not be strings", Justification = "OpenAI users use strings")]
#pragma warning disable CA1063 // Class isn't publicly extensible and thus doesn't implement the full IDisposable pattern
#pragma warning disable CA1816 // No derived types implement a finalizer

/// <summary>Base type for OpenAI clients.</summary>
public abstract class OpenAIClientBase : IDisposable
{
protected static readonly HttpClientHandler DefaultHttpClientHandler = new() { CheckCertificateRevocationList = true };

/// <summary>
/// Logger
/// </summary>
protected ILogger Log { get; } = NullLogger.Instance;
/// <summary>Initialize the client.</summary>
private protected OpenAIClientBase(HttpClient? httpClient = null, ILogger? logger = null)
{
this._httpClient = httpClient ?? new HttpClient(s_defaultHttpClientHandler, disposeHandler: false);
this._disposeHttpClient = this._httpClient != httpClient; // dispose a non-shared client when this is disposed

/// <summary>
/// HTTP client
/// </summary>
protected HttpClient HTTPClient { get; }
this._log = logger ?? NullLogger.Instance;
}

internal OpenAIClientBase(HttpClient? httpClient = null, ILogger? logger = null)
/// <summary>Clean up resources used by this instance.</summary>
public void Dispose()
{
this.Log = logger ?? this.Log;

if (httpClient == null)
{
this.HTTPClient = new HttpClient(DefaultHttpClientHandler, disposeHandler: false);
this._disposeHttpClient = true; // If client is created internally, dispose it when done
}
else
if (this._disposeHttpClient)
{
this.HTTPClient = httpClient;
this._httpClient.Dispose();
}
}

this.HTTPClient.DefaultRequestHeaders.Add("User-Agent", HTTPUserAgent);
/// <summary>Adds headers to use for OpenAI HTTP requests.</summary>
private protected virtual void AddRequestHeaders(HttpRequestMessage request)
{
request.Headers.Add("User-Agent", HttpUserAgent);
}

/// <summary>
Expand All @@ -64,112 +58,43 @@ internal OpenAIClientBase(HttpClient? httpClient = null, ILogger? logger = null)
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of text embeddings</returns>
/// <exception cref="AIException">AIException thrown during the request.</exception>
protected async Task<IList<Embedding<float>>> ExecuteTextEmbeddingRequestAsync(
private protected async Task<IList<Embedding<float>>> ExecuteTextEmbeddingRequestAsync(
string url,
string requestBody,
CancellationToken cancellationToken = default)
{
try
{
var result = await this.ExecutePostRequestAsync<TextEmbeddingResponse>(url, requestBody, cancellationToken).ConfigureAwait(false);
if (result.Embeddings.Count < 1)
{
throw new AIException(
AIException.ErrorCodes.InvalidResponseContent,
"Embeddings not found");
}

return result.Embeddings.Select(e => new Embedding<float>(e.Values)).ToList();
}
catch (Exception e) when (e is not AIException)
var result = await this.ExecutePostRequestAsync<TextEmbeddingResponse>(url, requestBody, cancellationToken).ConfigureAwait(false);
if (result.Embeddings is not { Count: >= 1 })
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);
AIException.ErrorCodes.InvalidResponseContent,
"Embeddings not found");
}
}

/// <summary>
/// Run the HTTP request to generate a list of images
/// </summary>
/// <param name="url">URL for the image generation request API</param>
/// <param name="requestBody">Request payload</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of image URLs</returns>
/// <exception cref="AIException">AIException thrown during the request.</exception>
protected async Task<IList<string>> ExecuteImageUrlGenerationRequestAsync(
string url,
string requestBody,
CancellationToken cancellationToken = default)
{
try
{
var result = await this.ExecutePostRequestAsync<ImageGenerationResponse>(url, requestBody, cancellationToken).ConfigureAwait(false);
return result.Images.Select(x => x.Url).ToList();
}
catch (Exception e) when (e is not AIException)
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);
}
return result.Embeddings.Select(e => new Embedding<float>(e.Values)).ToList();
}

/// <summary>
/// Run the HTTP request to generate a list of images
/// </summary>
/// <param name="url">URL for the image generation request API</param>
/// <param name="requestBody">Request payload</param>
/// <param name="extractResponseFunc">Function to invoke to extract the desired portion of the image generation response.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of images serialized in base64</returns>
/// <returns>List of image URLs</returns>
/// <exception cref="AIException">AIException thrown during the request.</exception>
protected async Task<IList<string>> ExecuteImageBase64GenerationRequestAsync(
private protected async Task<IList<string>> ExecuteImageGenerationRequestAsync(
string url,
string requestBody,
Func<ImageGenerationResponse.Image, string> extractResponseFunc,
CancellationToken cancellationToken = default)
{
try
{
var result = await this.ExecutePostRequestAsync<ImageGenerationResponse>(url, requestBody, cancellationToken).ConfigureAwait(false);
return result.Images.Select(x => x.AsBase64).ToList();
}
catch (Exception e) when (e is not AIException)
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);
}
}

/// <summary>
/// Explicit finalizer called by IDisposable
/// </summary>
public void Dispose()
{
this.Dispose(true);
// Request CL runtime not to call the finalizer - reduce cost of GC
GC.SuppressFinalize(this);
}

/// <summary>
/// Overridable finalizer for concrete classes
/// </summary>
/// <param name="disposing"></param>
protected virtual void Dispose(bool disposing)
{
if (disposing & this._disposeHttpClient)
{
this.HTTPClient.Dispose();
}
var result = await this.ExecutePostRequestAsync<ImageGenerationResponse>(url, requestBody, cancellationToken).ConfigureAwait(false);
return result.Images.Select(extractResponseFunc).ToList();
}

protected virtual string? GetErrorMessageFromResponse(string? jsonResponsePayload)
private protected virtual string? GetErrorMessageFromResponse(string jsonResponsePayload)
{
if (jsonResponsePayload is null)
{
return null;
}

try
{
JsonNode? root = JsonSerializer.Deserialize<JsonNode>(jsonResponsePayload);
Expand All @@ -178,33 +103,48 @@ protected virtual void Dispose(bool disposing)
}
catch (Exception ex) when (ex is NotSupportedException or JsonException)
{
this.Log.LogTrace("Unable to extract error from response body content. Exception: {0}:{1}", ex.GetType(), ex.Message);
return null;
this._log.LogTrace("Unable to extract error from response body content. Exception: {0}:{1}", ex.GetType(), ex.Message);
}

return null;
}

#region private ================================================================================

// Shared singleton HttpClientHandler used when an existing HttpClient isn't provided
private static readonly HttpClientHandler s_defaultHttpClientHandler = new() { CheckCertificateRevocationList = true };

// HTTP user agent sent to remote endpoints
private const string HTTPUserAgent = "Microsoft-Semantic-Kernel";
private const string HttpUserAgent = "Microsoft-Semantic-Kernel";

// Set to true to dispose of HttpClient when disposing. If HttpClient was passed in, then the caller can manage.
private readonly bool _disposeHttpClient = false;
private readonly bool _disposeHttpClient;

/// <summary>
/// Logger
/// </summary>
private readonly ILogger _log;

/// <summary>
/// The <see cref="_httpClient"/> to use for issuing requests.
/// </summary>
private readonly HttpClient _httpClient;

private async Task<T> ExecutePostRequestAsync<T>(string url, string requestBody, CancellationToken cancellationToken = default)
{
string responseJson;

HttpResponseMessage? response = null;
try
{
using HttpContent content = new StringContent(requestBody, Encoding.UTF8, "application/json");

HttpResponseMessage response = await this.HTTPClient.PostAsync(url, content, cancellationToken).ConfigureAwait(false)
?? throw new AIException(AIException.ErrorCodes.NoResponse);
using (var request = new HttpRequestMessage(HttpMethod.Post, url))
{
this.AddRequestHeaders(request);
request.Content = new StringContent(requestBody, Encoding.UTF8, "application/json");
response = await this._httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false);
}

this.Log.LogTrace("HTTP response: {0} {1}", (int)response.StatusCode, response.StatusCode.ToString("G"));
this._log.LogTrace("HTTP response: {0} {1}", (int)response.StatusCode, response.StatusCode.ToString("G"));

responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
string responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
string? errorDetail = this.GetErrorMessageFromResponse(responseJson);

if (!response.IsSuccessStatusCode)
Expand Down Expand Up @@ -276,29 +216,27 @@ private async Task<T> ExecutePostRequestAsync<T>(string url, string requestBody,
errorDetail);
}
}
}
catch (Exception e) when (e is not AIException)
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);
}

try
{
var result = Json.Deserialize<T>(responseJson);
if (result != null) { return result; }

throw new AIException(
if (result is null)
{
throw new AIException(
AIException.ErrorCodes.InvalidResponseContent,
"Response JSON parse error");
}

return result;
}
catch (Exception e) when (e is not AIException)
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);
}
finally
{
response?.Dispose();
}
}

#endregion
Expand Down