From 471ec9c19af3ce2f7c314b498e52f24f4fe38c9a Mon Sep 17 00:00:00 2001 From: xbotter Date: Fri, 9 Jun 2023 23:10:28 +0800 Subject: [PATCH] Add Azure Dall-E (#1209) ### Motivation and Context Implement Azure OpenAI Dall E ### Description Implement Azure OpenAI Image Generation according to the official [documentation](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#image-generation) and the examples provided by Azure OpenAI Studio, so that it can be used in SK. Due to the differences between the API of Azure OpenAI Dall-E and OpenAI Dall-E, some modifications have been made to OpenAIClientBase under CustomClient to ensure compatibility. Co-authored-by: Lee Miller Co-authored-by: Shawn Callegari <36091529+shawncal@users.noreply.github.com> Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> --- .../CustomClient/OpenAIClientBase.cs | 186 +++++++------- .../AzureImageGenerationResponse.cs | 41 ++++ .../AzureImageOperationStatus.cs | 40 +++ .../AzureOpenAIImageGeneration.cs | 232 ++++++++++++++++++ .../OpenAIKernelBuilderExtensions.cs | 32 ++- .../Connectors.UnitTests.csproj | 6 + .../AzureOpenAIImageGenerationTests.cs | 78 ++++++ .../OpenAI/OpenAITestHelper.cs | 20 ++ .../image_generation_test_response.json | 4 + .../TestData/image_result_test_response.json | 12 + .../kernel-syntax-examples/Example18_DallE.cs | 79 +++++- 11 files changed, 645 insertions(+), 85 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.AI.OpenAI/ImageGeneration/AzureImageGenerationResponse.cs create mode 100644 dotnet/src/Connectors/Connectors.AI.OpenAI/ImageGeneration/AzureImageOperationStatus.cs create mode 100644 dotnet/src/Connectors/Connectors.AI.OpenAI/ImageGeneration/AzureOpenAIImageGeneration.cs create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ImageGeneration/AzureOpenAIImageGenerationTests.cs create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/OpenAI/OpenAITestHelper.cs create mode 100644 
dotnet/src/Connectors/Connectors.UnitTests/OpenAI/TestData/image_generation_test_response.json create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/OpenAI/TestData/image_result_test_response.json diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/CustomClient/OpenAIClientBase.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/CustomClient/OpenAIClientBase.cs index 040857a333db..f36bb6d55745 100644 --- a/dotnet/src/Connectors/Connectors.AI.OpenAI/CustomClient/OpenAIClientBase.cs +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/CustomClient/OpenAIClientBase.cs @@ -114,100 +114,126 @@ private protected async Task> ExecuteImageGenerationRequestAsync( /// private readonly HttpClient _httpClient; - private async Task ExecutePostRequestAsync(string url, string requestBody, CancellationToken cancellationToken = default) + private protected async Task ExecutePostRequestAsync(string url, string requestBody, CancellationToken cancellationToken = default) + { + try + { + using var content = new StringContent(requestBody, Encoding.UTF8, "application/json"); + using var response = await this.ExecuteRequestAsync(url, HttpMethod.Post, content, cancellationToken).ConfigureAwait(false); + string responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false); + T result = this.JsonDeserialize(responseJson); + return result; + } + catch (Exception e) when (e is not AIException) + { + throw new AIException( + AIException.ErrorCodes.UnknownError, + $"Something went wrong: {e.Message}", e); + } + } + + private protected T JsonDeserialize(string responseJson) + { + var result = Json.Deserialize(responseJson); + if (result is null) + { + throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Response JSON parse error"); + } + + return result; + } + + private protected async Task ExecuteRequestAsync(string url, HttpMethod method, HttpContent? content, CancellationToken cancellationToken = default) { HttpResponseMessage? 
response = null; try { - using (var request = new HttpRequestMessage(HttpMethod.Post, url)) + using (var request = new HttpRequestMessage(method, url)) { this.AddRequestHeaders(request); - request.Content = new StringContent(requestBody, Encoding.UTF8, "application/json"); + if (content != null) + { + request.Content = content; + } + response = await this._httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); } this._log.LogTrace("HTTP response: {0} {1}", (int)response.StatusCode, response.StatusCode.ToString("G")); - string responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false); - string? errorDetail = this.GetErrorMessageFromResponse(responseJson); - - if (!response.IsSuccessStatusCode) + if (response.IsSuccessStatusCode) { - switch ((HttpStatusCodeType)response.StatusCode) - { - case HttpStatusCodeType.BadRequest: - case HttpStatusCodeType.MethodNotAllowed: - case HttpStatusCodeType.NotFound: - case HttpStatusCodeType.NotAcceptable: - case HttpStatusCodeType.Conflict: - case HttpStatusCodeType.Gone: - case HttpStatusCodeType.LengthRequired: - case HttpStatusCodeType.PreconditionFailed: - case HttpStatusCodeType.RequestEntityTooLarge: - case HttpStatusCodeType.RequestUriTooLong: - case HttpStatusCodeType.UnsupportedMediaType: - case HttpStatusCodeType.RequestedRangeNotSatisfiable: - case HttpStatusCodeType.ExpectationFailed: - case HttpStatusCodeType.HttpVersionNotSupported: - case HttpStatusCodeType.UpgradeRequired: - case HttpStatusCodeType.MisdirectedRequest: - case HttpStatusCodeType.UnprocessableEntity: - case HttpStatusCodeType.Locked: - case HttpStatusCodeType.FailedDependency: - case HttpStatusCodeType.PreconditionRequired: - case HttpStatusCodeType.RequestHeaderFieldsTooLarge: - throw new AIException( - AIException.ErrorCodes.InvalidRequest, - $"The request is not valid, HTTP status: {response.StatusCode:G}", - errorDetail); - - case HttpStatusCodeType.Unauthorized: - case HttpStatusCodeType.Forbidden: - case 
HttpStatusCodeType.ProxyAuthenticationRequired: - case HttpStatusCodeType.UnavailableForLegalReasons: - case HttpStatusCodeType.NetworkAuthenticationRequired: - throw new AIException( - AIException.ErrorCodes.AccessDenied, - $"The request is not authorized, HTTP status: {response.StatusCode:G}", - errorDetail); - - case HttpStatusCodeType.RequestTimeout: - throw new AIException( - AIException.ErrorCodes.RequestTimeout, - $"The request timed out, HTTP status: {response.StatusCode:G}"); - - case HttpStatusCodeType.TooManyRequests: - throw new AIException( - AIException.ErrorCodes.Throttling, - $"Too many requests, HTTP status: {response.StatusCode:G}", - errorDetail); - - case HttpStatusCodeType.InternalServerError: - case HttpStatusCodeType.NotImplemented: - case HttpStatusCodeType.BadGateway: - case HttpStatusCodeType.ServiceUnavailable: - case HttpStatusCodeType.GatewayTimeout: - case HttpStatusCodeType.InsufficientStorage: - throw new AIException( - AIException.ErrorCodes.ServiceError, - $"The service failed to process the request, HTTP status: {response.StatusCode:G}", - errorDetail); - - default: - throw new AIException( - AIException.ErrorCodes.UnknownError, - $"Unexpected HTTP response, status: {response.StatusCode:G}", - errorDetail); - } + return response; } - var result = Json.Deserialize(responseJson); - if (result is null) + string responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false); + string? 
errorDetail = this.GetErrorMessageFromResponse(responseJson); + switch ((HttpStatusCodeType)response.StatusCode) { - throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Response JSON parse error"); + case HttpStatusCodeType.BadRequest: + case HttpStatusCodeType.MethodNotAllowed: + case HttpStatusCodeType.NotFound: + case HttpStatusCodeType.NotAcceptable: + case HttpStatusCodeType.Conflict: + case HttpStatusCodeType.Gone: + case HttpStatusCodeType.LengthRequired: + case HttpStatusCodeType.PreconditionFailed: + case HttpStatusCodeType.RequestEntityTooLarge: + case HttpStatusCodeType.RequestUriTooLong: + case HttpStatusCodeType.UnsupportedMediaType: + case HttpStatusCodeType.RequestedRangeNotSatisfiable: + case HttpStatusCodeType.ExpectationFailed: + case HttpStatusCodeType.HttpVersionNotSupported: + case HttpStatusCodeType.UpgradeRequired: + case HttpStatusCodeType.MisdirectedRequest: + case HttpStatusCodeType.UnprocessableEntity: + case HttpStatusCodeType.Locked: + case HttpStatusCodeType.FailedDependency: + case HttpStatusCodeType.PreconditionRequired: + case HttpStatusCodeType.RequestHeaderFieldsTooLarge: + throw new AIException( + AIException.ErrorCodes.InvalidRequest, + $"The request is not valid, HTTP status: {response.StatusCode:G}", + errorDetail); + + case HttpStatusCodeType.Unauthorized: + case HttpStatusCodeType.Forbidden: + case HttpStatusCodeType.ProxyAuthenticationRequired: + case HttpStatusCodeType.UnavailableForLegalReasons: + case HttpStatusCodeType.NetworkAuthenticationRequired: + throw new AIException( + AIException.ErrorCodes.AccessDenied, + $"The request is not authorized, HTTP status: {response.StatusCode:G}", + errorDetail); + + case HttpStatusCodeType.RequestTimeout: + throw new AIException( + AIException.ErrorCodes.RequestTimeout, + $"The request timed out, HTTP status: {response.StatusCode:G}"); + + case HttpStatusCodeType.TooManyRequests: + throw new AIException( + AIException.ErrorCodes.Throttling, + $"Too many requests, 
HTTP status: {response.StatusCode:G}", + errorDetail); + + case HttpStatusCodeType.InternalServerError: + case HttpStatusCodeType.NotImplemented: + case HttpStatusCodeType.BadGateway: + case HttpStatusCodeType.ServiceUnavailable: + case HttpStatusCodeType.GatewayTimeout: + case HttpStatusCodeType.InsufficientStorage: + throw new AIException( + AIException.ErrorCodes.ServiceError, + $"The service failed to process the request, HTTP status: {response.StatusCode:G}", + errorDetail); + + default: + throw new AIException( + AIException.ErrorCodes.UnknownError, + $"Unexpected HTTP response, status: {response.StatusCode:G}", + errorDetail); } - - return result; } catch (Exception e) when (e is not AIException) { @@ -215,10 +241,6 @@ private async Task ExecutePostRequestAsync(string url, string requestBody, AIException.ErrorCodes.UnknownError, $"Something went wrong: {e.Message}", e); } - finally - { - response?.Dispose(); - } } #endregion diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/ImageGeneration/AzureImageGenerationResponse.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/ImageGeneration/AzureImageGenerationResponse.cs new file mode 100644 index 000000000000..961a5aa361ec --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/ImageGeneration/AzureImageGenerationResponse.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json.Serialization; + +namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.ImageGeneration; + +/// +/// Image generation response +/// +public class AzureImageGenerationResponse +{ + /// + /// Image generation result + /// + [JsonPropertyName("result")] + public ImageGenerationResponse? 
Result { get; set; } + + /// + /// Request Id + /// + [JsonPropertyName("id")] + public string Id { get; set; } = string.Empty; + + /// + /// Request Status + /// + [JsonPropertyName("status")] + public string Status { get; set; } = string.Empty; + + /// + /// Creation time + /// + [JsonPropertyName("created")] + public int Created { get; set; } + + /// + /// Expiration time of the URL + /// + [JsonPropertyName("expires")] + public int Expires { get; set; } +} diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/ImageGeneration/AzureImageOperationStatus.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/ImageGeneration/AzureImageOperationStatus.cs new file mode 100644 index 000000000000..1abe033b2780 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/ImageGeneration/AzureImageOperationStatus.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.ImageGeneration; + +/// +/// Azure image generation response status +/// +/// +public static class AzureImageOperationStatus +{ + /// + /// Image generation Succeeded + /// + public const string Succeeded = "succeeded"; + + /// + /// Image generation Failed + /// + public const string Failed = "failed"; + + /// + /// Task is running + /// + public const string Running = "running"; + + /// + /// Task is queued but hasn't started yet + /// + public const string NotRunning = "notRunning"; + + /// + /// The image has been removed from Azure's server. 
+ /// + public const string Deleted = "deleted"; + + /// + /// Task has timed out + /// + public const string Cancelled = "cancelled"; +} diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/ImageGeneration/AzureOpenAIImageGeneration.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/ImageGeneration/AzureOpenAIImageGeneration.cs new file mode 100644 index 000000000000..910ae36dbb99 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/ImageGeneration/AzureOpenAIImageGeneration.cs @@ -0,0 +1,232 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.AI; +using Microsoft.SemanticKernel.AI.ImageGeneration; +using Microsoft.SemanticKernel.Connectors.AI.OpenAI.CustomClient; +using Microsoft.SemanticKernel.Diagnostics; +using Microsoft.SemanticKernel.Text; + +namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.ImageGeneration; + +/// +/// Azure OpenAI Image generation +/// +/// +public class AzureOpenAIImageGeneration : OpenAIClientBase, IImageGeneration +{ + /// + /// Generation Image Operation path + /// + private const string GenerationImageOperation = "openai/images/generations:submit"; + + /// + /// Get Image Operation path + /// + private const string GetImageOperation = "openai/operations/images"; + + /// + /// Azure OpenAI REST API endpoint + /// + private readonly string _endpoint; + + /// + /// Azure OpenAI API key + /// + private readonly string _apiKey; + + /// + /// Maximum number of attempts to retrieve the image generation operation result. 
+ /// + private readonly int _maxRetryCount; + + /// + /// Azure OpenAI Endpoint ApiVersion + /// + private readonly string _apiVersion; + + /// + /// Create a new instance of Azure OpenAI image generation service + /// + /// Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Azure OpenAI API key, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Custom for HTTP requests. + /// Application logger + /// Maximum number of attempts to retrieve the image generation operation result. + /// Azure OpenAI Endpoint ApiVersion + public AzureOpenAIImageGeneration(string endpoint, string apiKey, HttpClient? httpClient = null, ILogger? logger = null, int maxRetryCount = 5, string apiVersion = "2023-06-01-preview") : base(httpClient, logger) + { + Verify.NotNullOrWhiteSpace(endpoint); + Verify.NotNullOrWhiteSpace(apiKey); + Verify.StartsWith(endpoint, "https://", "The Azure OpenAI endpoint must start with 'https://'"); + + this._endpoint = endpoint; + this._apiKey = apiKey; + this._maxRetryCount = maxRetryCount; + this._apiVersion = apiVersion; + } + + /// + /// Create a new instance of Azure OpenAI image generation service + /// + /// Azure OpenAI API key, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Custom for HTTP requests. + /// Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Application logger + /// Maximum number of attempts to retrieve the image generation operation result. + /// Azure OpenAI Endpoint ApiVersion + public AzureOpenAIImageGeneration(string apiKey, HttpClient httpClient, string? endpoint = null, ILogger? 
logger = null, int maxRetryCount = 5, string apiVersion = "2023-06-01-preview") : base(httpClient, logger) + { + Verify.NotNull(httpClient); + Verify.NotNullOrWhiteSpace(apiKey); + + if (httpClient.BaseAddress == null && string.IsNullOrEmpty(endpoint)) + { + throw new AIException( + AIException.ErrorCodes.InvalidConfiguration, + "The HttpClient BaseAddress and endpoint are both null or empty. Please ensure at least one is provided."); + } + + endpoint = !string.IsNullOrEmpty(endpoint) ? endpoint! : httpClient.BaseAddress!.AbsoluteUri; + Verify.StartsWith(endpoint, "https://", "The Azure OpenAI endpoint must start with 'https://'"); + + this._endpoint = endpoint; + this._apiKey = apiKey; + this._maxRetryCount = maxRetryCount; + this._apiVersion = apiVersion; + } + + /// + public async Task GenerateImageAsync(string description, int width, int height, CancellationToken cancellationToken = default) + { + var operationId = await this.StartImageGenerationAsync(description, width, height, cancellationToken).ConfigureAwait(false); + var result = await this.GetImageGenerationResultAsync(operationId, cancellationToken).ConfigureAwait(false); + + if (result.Result == null) + { + throw new AzureSdk.OpenAIInvalidResponseException(null, "Azure Image Generation null response"); + } + + if (result.Result.Images.Count == 0) + { + throw new AzureSdk.OpenAIInvalidResponseException(result, "Azure Image Generation result not found"); + } + + return result.Result.Images.First().Url; + } + + /// + /// Start an image generation task + /// + /// Image description + /// Image width in pixels + /// Image height in pixels + /// The to monitor for cancellation requests. The default is . + /// The operationId that identifies the original image generation request. 
+    private async Task<string> StartImageGenerationAsync(string description, int width, int height, CancellationToken cancellationToken = default)
+    {
+        Verify.NotNull(description);
+
+        // Only square images with a side of 256, 512 or 1024 pixels pass validation.
+        bool isSupportedSquare = width == height && (width is 256 or 512 or 1024);
+        if (!isSupportedSquare)
+        {
+            throw new ArgumentOutOfRangeException(nameof(width), width, "OpenAI can generate only square images of size 256x256, 512x512, or 1024x1024.");
+        }
+
+        // Serialize a request for exactly one image of the requested size.
+        var generationRequest = new ImageGenerationRequest
+        {
+            Prompt = description,
+            Size = $"{width}x{height}",
+            Count = 1
+        };
+        var payload = Json.Serialize(generationRequest);
+
+        // Submit the job; the response carries the id used later to poll for the result.
+        var submitUri = this.GetUri(GenerationImageOperation);
+        var submission = await this.ExecutePostRequestAsync<AzureImageGenerationResponse>(submitUri, payload, cancellationToken).ConfigureAwait(false);
+
+        if (submission == null || string.IsNullOrWhiteSpace(submission.Id))
+        {
+            throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Response not contains result");
+        }
+
+        return submission.Id;
+    }
+
+    /// <summary>
+    /// Retrieve the results of an image generation operation.
+    /// </summary>
+    /// <param name="operationId">The operationId that identifies the original image generation request.</param>
+    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
+    /// <returns>The operation response whose status is <c>succeeded</c>.</returns>
+    /// <exception cref="AIException">
+    /// Thrown with <c>RequestTimeout</c> when the retry budget is exhausted, with an invalid-response
+    /// error when the operation reports failed, cancelled or deleted, and with <c>UnknownError</c>
+    /// wrapping any other exception raised while polling.
+    /// </exception>
+    private async Task<AzureImageGenerationResponse> GetImageGenerationResultAsync(string operationId, CancellationToken cancellationToken = default)
+    {
+        // Wait applied between polls when the response carries no usable Retry-After delta.
+        // Without it, a missing header made the loop re-poll immediately and spend the
+        // whole retry budget in a burst.
+        const int DefaultPollingDelaySeconds = 1;
+
+        var operationLocation = this.GetUri(GetImageOperation, operationId);
+
+        try
+        {
+            // A bounded loop terminates for every budget, including zero and negative values
+            // (the previous equality check never matched a negative budget).
+            for (var attempt = 0; attempt < this._maxRetryCount; attempt++)
+            {
+                using var response = await this.ExecuteRequestAsync(operationLocation, HttpMethod.Get, null, cancellationToken).ConfigureAwait(false);
+                var responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
+                var result = this.JsonDeserialize<AzureImageGenerationResponse>(responseJson);
+
+                if (result.Status.Equals(AzureImageOperationStatus.Succeeded, StringComparison.OrdinalIgnoreCase))
+                {
+                    return result;
+                }
+
+                if (this.IsFailedOrCancelled(result.Status))
+                {
+                    throw new AzureSdk.OpenAIInvalidResponseException<AzureImageGenerationResponse>(result, $"Azure OpenAI image generation {result.Status}");
+                }
+
+                // Still queued or running: wait before the next poll, but do not wait
+                // after the last attempt since no further request will be sent.
+                if (attempt + 1 < this._maxRetryCount)
+                {
+                    var delay = response.Headers.RetryAfter?.Delta ?? TimeSpan.FromSeconds(DefaultPollingDelaySeconds);
+                    await Task.Delay(delay, cancellationToken).ConfigureAwait(false);
+                }
+            }
+
+            throw new AIException(AIException.ErrorCodes.RequestTimeout, "Reached maximum retry attempts");
+        }
+        catch (Exception e) when (e is not AIException)
+        {
+            throw new AIException(
+                AIException.ErrorCodes.UnknownError,
+                $"Something went wrong: {e.Message}", e);
+        }
+    }
+
+    /// <summary>
+    /// Builds the request URI from the configured endpoint, the operation path, one
+    /// "/"-prefixed segment per extra parameter, and the api-version query parameter.
+    /// </summary>
+    private string GetUri(string operation, params string[] parameters)
+    {
+        var uri = new Azure.Core.RequestUriBuilder();
+        uri.Reset(new Uri(this._endpoint));
+        uri.AppendPath(operation, false);
+        foreach (var parameter in parameters)
+        {
+            uri.AppendPath("/" + parameter, false);
+        }
+        uri.AppendQuery("api-version", this._apiVersion);
+        return uri.ToString();
+    }
+
+    private bool IsFailedOrCancelled(string status)
+    {
+        return status.Equals(AzureImageOperationStatus.Failed, StringComparison.OrdinalIgnoreCase)
+            || 
status.Equals(AzureImageOperationStatus.Cancelled, StringComparison.OrdinalIgnoreCase) + || status.Equals(AzureImageOperationStatus.Deleted, StringComparison.OrdinalIgnoreCase); + } + + /// Adds headers to use for Azure OpenAI HTTP requests. + private protected override void AddRequestHeaders(HttpRequestMessage request) + { + request.Headers.Add("api-key", this._apiKey); + } +} diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/OpenAIKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/OpenAIKernelBuilderExtensions.cs index b33ee5021816..ab41b39e3c6f 100644 --- a/dotnet/src/Connectors/Connectors.AI.OpenAI/OpenAIKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/OpenAIKernelBuilderExtensions.cs @@ -369,6 +369,37 @@ public static KernelBuilder WithOpenAIImageGenerationService(this KernelBuilder return builder; } + /// + /// Add the Azure OpenAI DallE image generation service to the list + /// + /// The instance + /// Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Azure OpenAI API key, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// A local identifier for the given AI service + /// Whether the service should be the default for its type. + /// Custom for HTTP requests. + /// Maximum number of attempts to retrieve the image generation operation result. + /// Self instance + public static KernelBuilder WithAzureOpenAIImageGenerationService(this KernelBuilder builder, + string endpoint, + string apiKey, + string? serviceId = null, + bool setAsDefault = false, + HttpClient? 
httpClient = null, + int maxRetryCount = 5) + { + builder.WithAIService(serviceId, ((ILogger Logger, KernelConfig Config) parameters) => + new AzureOpenAIImageGeneration( + endpoint, + apiKey, + GetHttpClient(parameters.Config, httpClient, parameters.Logger), + parameters.Logger, + maxRetryCount), + setAsDefault); + + return builder; + } + /// /// Retrieves an instance of HttpClient. /// @@ -387,6 +418,5 @@ private static HttpClient GetHttpClient(KernelConfig config, HttpClient? httpCli return httpClient; } - #endregion } diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj b/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj index a8e186eb6a23..bd8f9069ba7f 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj @@ -41,6 +41,12 @@ Always + + Always + + + Always + diff --git a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ImageGeneration/AzureOpenAIImageGenerationTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ImageGeneration/AzureOpenAIImageGenerationTests.cs new file mode 100644 index 000000000000..4c8ce1784dd3 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ImageGeneration/AzureOpenAIImageGenerationTests.cs @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.AI.OpenAI.ImageGeneration; +using Moq; +using Moq.Protected; +using Xunit; + +namespace SemanticKernel.Connectors.UnitTests.OpenAI.ImageGeneration; + +/// +/// Unit tests for class. +/// +public sealed class AzureOpenAIImageGenerationTests +{ + /// + /// Returns a mocked instance of . + /// + /// The to return for image generation. + /// The to return for image result. + /// A mocked instance. 
+    private static HttpClient GetHttpClientMock(HttpResponseMessage generationResult, HttpResponseMessage imageResult)
+    {
+        var httpClientHandler = new Mock<HttpMessageHandler>();
+
+        httpClientHandler
+            .Protected()
+            .Setup<Task<HttpResponseMessage>>(
+                "SendAsync",
+                ItExpr.Is<HttpRequestMessage>(request => request.RequestUri!.AbsolutePath.Contains("openai/images/generations:submit")),
+                ItExpr.IsAny<CancellationToken>())
+            .ReturnsAsync(generationResult);
+
+        httpClientHandler
+            .Protected()
+            .Setup<Task<HttpResponseMessage>>(
+                "SendAsync",
+                ItExpr.Is<HttpRequestMessage>(request => request.RequestUri!.AbsolutePath.Contains("openai/operations/images")),
+                ItExpr.IsAny<CancellationToken>())
+            .ReturnsAsync(imageResult);
+
+        return new HttpClient(httpClientHandler.Object);
+    }
+
+    /// <summary>
+    /// Creates an instance of <see cref="HttpResponseMessage"/> to return with test data.
+    /// </summary>
+    /// <param name="statusCode">The HTTP status code for the response.</param>
+    /// <param name="fileName">The name of the test response file.</param>
+    /// <returns>An instance of <see cref="HttpResponseMessage"/> with the specified test data.</returns>
+    private static HttpResponseMessage CreateResponseMessage(HttpStatusCode statusCode, string fileName)
+    {
+        var response = new HttpResponseMessage(statusCode);
+        response.Content = new StringContent(OpenAITestHelper.GetTestResponse(fileName), Encoding.UTF8, "application/json");
+        return response;
+    }
+
+    [Fact]
+    public async Task ItShouldGenerateImageSuccussedAsync()
+    {
+        //Arrange
+        using var generateResult = CreateResponseMessage(HttpStatusCode.Accepted, "image_generation_test_response.json");
+        using var imageResult = CreateResponseMessage(HttpStatusCode.OK, "image_result_test_response.json");
+        using var mockHttpClient = GetHttpClientMock(generateResult, imageResult);
+
+        var generation = new AzureOpenAIImageGeneration("https://fake-endpoint/", "fake-api-key", mockHttpClient);
+
+        //Act
+        var result = await generation.GenerateImageAsync("description", 256, 256);
+
+        //Assert: the URL must be the one in image_result_test_response.json (a bare non-null check accepts any string)
+        Assert.Equal("https://dalleproduse.blob.core.windows.net/private/images/generated_00.png", result);
+    }
+}
diff --git a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/OpenAITestHelper.cs b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/OpenAITestHelper.cs
new file mode 100644
index 000000000000..f6ee6bb93a11
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/OpenAITestHelper.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.IO; + +namespace SemanticKernel.Connectors.UnitTests.OpenAI; + +/// +/// Helper for OpenAI test purposes. +/// +internal static class OpenAITestHelper +{ + /// + /// Reads test response from file for mocking purposes. + /// + /// Name of the file with test response. + internal static string GetTestResponse(string fileName) + { + return File.ReadAllText($"./OpenAI/TestData/{fileName}"); + } +} diff --git a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/TestData/image_generation_test_response.json b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/TestData/image_generation_test_response.json new file mode 100644 index 000000000000..87b9ab7d7cce --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/TestData/image_generation_test_response.json @@ -0,0 +1,4 @@ +{ + "id": "32ba9f77-d620-4b6c-9265-ad50cb314a5c", + "status": "notRunning" +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/TestData/image_result_test_response.json b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/TestData/image_result_test_response.json new file mode 100644 index 000000000000..61904f1b0a02 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/TestData/image_result_test_response.json @@ -0,0 +1,12 @@ +{ + "created": 1686192127, + "expires": 1686278532, + "id": "32ba9f77-d620-4b6c-9265-ad50cb314a5c", + "result": { + "created": 1686192127, + "data": [ + { "url": "https://dalleproduse.blob.core.windows.net/private/images/generated_00.png" } + ] + }, + "status": "succeeded" +} \ No newline at end of file diff --git a/samples/dotnet/kernel-syntax-examples/Example18_DallE.cs b/samples/dotnet/kernel-syntax-examples/Example18_DallE.cs index 1cdfa1025e45..de3ce14612e3 100644 --- a/samples/dotnet/kernel-syntax-examples/Example18_DallE.cs +++ 
b/samples/dotnet/kernel-syntax-examples/Example18_DallE.cs @@ -5,6 +5,7 @@ using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.AI.ChatCompletion; using Microsoft.SemanticKernel.AI.ImageGeneration; +using Microsoft.SemanticKernel.Connectors.AI.OpenAI.ChatCompletion; using RepoUtils; /** @@ -16,12 +17,18 @@ public static class Example18_DallE { public static async Task RunAsync() { - Console.WriteLine("======== Dall-E 2 Image Generation ========"); + await OpenAIDallEAsync(); + await AzureOpenAIDallEAsync(); + } + + public static async Task OpenAIDallEAsync() + { + Console.WriteLine("======== OpenAI Dall-E 2 Image Generation ========"); IKernel kernel = new KernelBuilder() .WithLogger(ConsoleLogger.Log) // Add your image generation service - .WithOpenAIImageGenerationService("dallE", Env.Var("OPENAI_API_KEY")) + .WithOpenAIImageGenerationService(Env.Var("OPENAI_API_KEY")) // Add your chat completion service .WithOpenAIChatCompletionService("gpt-3.5-turbo", Env.Var("OPENAI_API_KEY")) .Build(); @@ -82,4 +89,72 @@ A cute baby sea otter */ } + + public static async Task AzureOpenAIDallEAsync() + { + Console.WriteLine("========Azure OpenAI Dall-E 2 Image Generation ========"); + + IKernel kernel = new KernelBuilder() + .WithLogger(ConsoleLogger.Log) + // Add your image generation service + .WithAzureOpenAIImageGenerationService(Env.Var("AZURE_OPENAI_ENDPOINT"), Env.Var("AZURE_OPENAI_API_KEY")) + // Add your chat completion service + .WithAzureChatCompletionService("gpt-35-turbo", Env.Var("AZURE_OPENAI_ENDPOINT"), Env.Var("AZURE_OPENAI_API_KEY")) + .Build(); + + IImageGeneration dallE = kernel.GetService(); + var imageDescription = "A cute baby sea otter"; + var image = await dallE.GenerateImageAsync(imageDescription, 256, 256); + + Console.WriteLine(imageDescription); + Console.WriteLine("Image URL: " + image); + + /* Output: + + A cute baby sea otter + Image URL: https://dalleproduse.blob.core.windows.net/private/images/.... 
+ + */ + + Console.WriteLine("======== Chat with images ========"); + + IChatCompletion chatGPT = kernel.GetService(); + var chatHistory = (OpenAIChatHistory)chatGPT.CreateNewChat( + "You're chatting with a user. Instead of replying directly to the user" + + " provide the description of an image that expresses what you want to say." + + " The user won't see your message, they will see only the image. The system " + + " generates an image using your description, so it's important you describe the image with details."); + + var msg = "Hi, I'm from Tokyo, where are you from?"; + chatHistory.AddUserMessage(msg); + Console.WriteLine("User: " + msg); + + string reply = await chatGPT.GenerateMessageAsync(chatHistory); + chatHistory.AddAssistantMessage(reply); + image = await dallE.GenerateImageAsync(reply, 256, 256); + Console.WriteLine("Bot: " + image); + Console.WriteLine("Img description: " + reply); + + msg = "Oh, wow. Not sure where that is, could you provide more details?"; + chatHistory.AddUserMessage(msg); + Console.WriteLine("User: " + msg); + + reply = await chatGPT.GenerateMessageAsync(chatHistory); + chatHistory.AddAssistantMessage(reply); + image = await dallE.GenerateImageAsync(reply, 256, 256); + Console.WriteLine("Bot: " + image); + Console.WriteLine("Img description: " + reply); + + /* Output: + + User: Hi, I'm from Tokyo, where are you from? + Bot: https://dalleproduse.blob.core.windows.net/private/images/...... + Img description: [An image of a globe with a pin dropped on a location in the middle of the ocean] + + User: Oh, wow. Not sure where that is, could you provide more details? + Bot: https://dalleproduse.blob.core.windows.net/private/images/...... + Img description: [An image of a map zooming in on the pin location, revealing a small island with a palm tree on it] + + */ + } }