Skip to content

Commit

Permalink
Add Azure Dall-E (#1209)
Browse files Browse the repository at this point in the history
### Motivation and Context

Implement Azure OpenAI Dall E

### Description
Implement Azure OpenAI Image Generation according to the official
[documentation](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#image-generation)
and the examples provided by Azure OpenAI Studio, so that it can be used
in SK.

Due to the differences between the API of Azure OpenAI Dall-E and OpenAI
Dall-E, some modifications have been made to OpenAIClientBase under
CustomClient to ensure compatibility.

Co-authored-by: Lee Miller <lemiller@microsoft.com>
Co-authored-by: Shawn Callegari <36091529+shawncal@users.noreply.github.com>
Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>
  • Loading branch information
4 people committed Jun 9, 2023
1 parent fa3eb27 commit 471ec9c
Show file tree
Hide file tree
Showing 11 changed files with 645 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,111 +114,133 @@ private protected virtual void AddRequestHeaders(HttpRequestMessage request)
/// </summary>
private readonly HttpClient _httpClient;

private async Task<T> ExecutePostRequestAsync<T>(string url, string requestBody, CancellationToken cancellationToken = default)
private protected async Task<T> ExecutePostRequestAsync<T>(string url, string requestBody, CancellationToken cancellationToken = default)
{
try
{
using var content = new StringContent(requestBody, Encoding.UTF8, "application/json");
using var response = await this.ExecuteRequestAsync(url, HttpMethod.Post, content, cancellationToken).ConfigureAwait(false);
string responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
T result = this.JsonDeserialize<T>(responseJson);
return result;
}
catch (Exception e) when (e is not AIException)
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);
}
}

private protected T JsonDeserialize<T>(string responseJson)
{
var result = Json.Deserialize<T>(responseJson);
if (result is null)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Response JSON parse error");
}

return result;
}

private protected async Task<HttpResponseMessage> ExecuteRequestAsync(string url, HttpMethod method, HttpContent? content, CancellationToken cancellationToken = default)
{
HttpResponseMessage? response = null;
try
{
using (var request = new HttpRequestMessage(HttpMethod.Post, url))
using (var request = new HttpRequestMessage(method, url))
{
this.AddRequestHeaders(request);
request.Content = new StringContent(requestBody, Encoding.UTF8, "application/json");
if (content != null)
{
request.Content = content;
}

response = await this._httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false);
}

this._log.LogTrace("HTTP response: {0} {1}", (int)response.StatusCode, response.StatusCode.ToString("G"));

string responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
string? errorDetail = this.GetErrorMessageFromResponse(responseJson);

if (!response.IsSuccessStatusCode)
if (response.IsSuccessStatusCode)
{
switch ((HttpStatusCodeType)response.StatusCode)
{
case HttpStatusCodeType.BadRequest:
case HttpStatusCodeType.MethodNotAllowed:
case HttpStatusCodeType.NotFound:
case HttpStatusCodeType.NotAcceptable:
case HttpStatusCodeType.Conflict:
case HttpStatusCodeType.Gone:
case HttpStatusCodeType.LengthRequired:
case HttpStatusCodeType.PreconditionFailed:
case HttpStatusCodeType.RequestEntityTooLarge:
case HttpStatusCodeType.RequestUriTooLong:
case HttpStatusCodeType.UnsupportedMediaType:
case HttpStatusCodeType.RequestedRangeNotSatisfiable:
case HttpStatusCodeType.ExpectationFailed:
case HttpStatusCodeType.HttpVersionNotSupported:
case HttpStatusCodeType.UpgradeRequired:
case HttpStatusCodeType.MisdirectedRequest:
case HttpStatusCodeType.UnprocessableEntity:
case HttpStatusCodeType.Locked:
case HttpStatusCodeType.FailedDependency:
case HttpStatusCodeType.PreconditionRequired:
case HttpStatusCodeType.RequestHeaderFieldsTooLarge:
throw new AIException(
AIException.ErrorCodes.InvalidRequest,
$"The request is not valid, HTTP status: {response.StatusCode:G}",
errorDetail);

case HttpStatusCodeType.Unauthorized:
case HttpStatusCodeType.Forbidden:
case HttpStatusCodeType.ProxyAuthenticationRequired:
case HttpStatusCodeType.UnavailableForLegalReasons:
case HttpStatusCodeType.NetworkAuthenticationRequired:
throw new AIException(
AIException.ErrorCodes.AccessDenied,
$"The request is not authorized, HTTP status: {response.StatusCode:G}",
errorDetail);

case HttpStatusCodeType.RequestTimeout:
throw new AIException(
AIException.ErrorCodes.RequestTimeout,
$"The request timed out, HTTP status: {response.StatusCode:G}");

case HttpStatusCodeType.TooManyRequests:
throw new AIException(
AIException.ErrorCodes.Throttling,
$"Too many requests, HTTP status: {response.StatusCode:G}",
errorDetail);

case HttpStatusCodeType.InternalServerError:
case HttpStatusCodeType.NotImplemented:
case HttpStatusCodeType.BadGateway:
case HttpStatusCodeType.ServiceUnavailable:
case HttpStatusCodeType.GatewayTimeout:
case HttpStatusCodeType.InsufficientStorage:
throw new AIException(
AIException.ErrorCodes.ServiceError,
$"The service failed to process the request, HTTP status: {response.StatusCode:G}",
errorDetail);

default:
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Unexpected HTTP response, status: {response.StatusCode:G}",
errorDetail);
}
return response;
}

var result = Json.Deserialize<T>(responseJson);
if (result is null)
string responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
string? errorDetail = this.GetErrorMessageFromResponse(responseJson);
switch ((HttpStatusCodeType)response.StatusCode)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Response JSON parse error");
case HttpStatusCodeType.BadRequest:
case HttpStatusCodeType.MethodNotAllowed:
case HttpStatusCodeType.NotFound:
case HttpStatusCodeType.NotAcceptable:
case HttpStatusCodeType.Conflict:
case HttpStatusCodeType.Gone:
case HttpStatusCodeType.LengthRequired:
case HttpStatusCodeType.PreconditionFailed:
case HttpStatusCodeType.RequestEntityTooLarge:
case HttpStatusCodeType.RequestUriTooLong:
case HttpStatusCodeType.UnsupportedMediaType:
case HttpStatusCodeType.RequestedRangeNotSatisfiable:
case HttpStatusCodeType.ExpectationFailed:
case HttpStatusCodeType.HttpVersionNotSupported:
case HttpStatusCodeType.UpgradeRequired:
case HttpStatusCodeType.MisdirectedRequest:
case HttpStatusCodeType.UnprocessableEntity:
case HttpStatusCodeType.Locked:
case HttpStatusCodeType.FailedDependency:
case HttpStatusCodeType.PreconditionRequired:
case HttpStatusCodeType.RequestHeaderFieldsTooLarge:
throw new AIException(
AIException.ErrorCodes.InvalidRequest,
$"The request is not valid, HTTP status: {response.StatusCode:G}",
errorDetail);

case HttpStatusCodeType.Unauthorized:
case HttpStatusCodeType.Forbidden:
case HttpStatusCodeType.ProxyAuthenticationRequired:
case HttpStatusCodeType.UnavailableForLegalReasons:
case HttpStatusCodeType.NetworkAuthenticationRequired:
throw new AIException(
AIException.ErrorCodes.AccessDenied,
$"The request is not authorized, HTTP status: {response.StatusCode:G}",
errorDetail);

case HttpStatusCodeType.RequestTimeout:
throw new AIException(
AIException.ErrorCodes.RequestTimeout,
$"The request timed out, HTTP status: {response.StatusCode:G}");

case HttpStatusCodeType.TooManyRequests:
throw new AIException(
AIException.ErrorCodes.Throttling,
$"Too many requests, HTTP status: {response.StatusCode:G}",
errorDetail);

case HttpStatusCodeType.InternalServerError:
case HttpStatusCodeType.NotImplemented:
case HttpStatusCodeType.BadGateway:
case HttpStatusCodeType.ServiceUnavailable:
case HttpStatusCodeType.GatewayTimeout:
case HttpStatusCodeType.InsufficientStorage:
throw new AIException(
AIException.ErrorCodes.ServiceError,
$"The service failed to process the request, HTTP status: {response.StatusCode:G}",
errorDetail);

default:
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Unexpected HTTP response, status: {response.StatusCode:G}",
errorDetail);
}

return result;
}
catch (Exception e) when (e is not AIException)
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);
}
finally
{
response?.Dispose();
}
}

#endregion
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.ImageGeneration;

/// <summary>
/// Image generation response
/// </summary>
public class AzureImageGenerationResponse
{
/// <summary>
/// Image generation result
/// </summary>
[JsonPropertyName("result")]
public ImageGenerationResponse? Result { get; set; }

/// <summary>
/// Request Id
/// </summary>
[JsonPropertyName("id")]
public string Id { get; set; } = string.Empty;

/// <summary>
/// Request Status
/// </summary>
[JsonPropertyName("status")]
public string Status { get; set; } = string.Empty;

/// <summary>
/// Creation time
/// </summary>
[JsonPropertyName("created")]
public int Created { get; set; }

/// <summary>
/// Expiration time of the URL
/// </summary>
[JsonPropertyName("expires")]
public int Expires { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) Microsoft. All rights reserved.

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.ImageGeneration;

/// <summary>
/// Azure image generation response status
/// <see herf="https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#image-generation" />
/// </summary>
public static class AzureImageOperationStatus
{
/// <summary>
/// Image generation Succeeded
/// </summary>
public const string Succeeded = "succeeded";

/// <summary>
/// Image generation Failed
/// </summary>
public const string Failed = "failed";

/// <summary>
/// Task is running
/// </summary>
public const string Running = "running";

/// <summary>
/// Task is queued but hasn't started yet
/// </summary>
public const string NotRunning = "notRunning";

/// <summary>
/// The image has been removed from Azure's server.
/// </summary>
public const string Deleted = "deleted";

/// <summary>
/// Task has timed out
/// </summary>
public const string Cancelled = "cancelled";
}
Loading

0 comments on commit 471ec9c

Please sign in to comment.