Skip to content

Commit

Permalink
.Net: Add Support for 'dall-e-3' Model in OpenAIImageGeneration Class (
Browse files Browse the repository at this point in the history
…#6623)

### Motivation and Context
Closes #3435 

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄

Co-authored-by: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com>
  • Loading branch information
markwallace-microsoft and SergeyMenshykh committed Jun 12, 2024
1 parent 3c5e053 commit ddf1d46
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 8 deletions.
1 change: 1 addition & 0 deletions .github/_typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ ans = "ans" # Short for answers
arange = "arange" # Method in Python numpy package
prompty = "prompty" # prompty is a format name.
ist = "ist" # German language
dall = "dall" # OpenAI model name

[default.extend-identifiers]
ags = "ags" # Azure Graph Service
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,53 @@
<Right>lib/net8.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.Connectors.OpenAI.OpenAITextToImageService.#ctor(System.String,System.String,System.Net.Http.HttpClient,Microsoft.Extensions.Logging.ILoggerFactory)</Target>
<Left>lib/net8.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Left>
<Right>lib/net8.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.OpenAIServiceCollectionExtensions.AddOpenAITextToImage(Microsoft.Extensions.DependencyInjection.IServiceCollection,System.String,System.String,System.String)</Target>
<Left>lib/net8.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Left>
<Right>lib/net8.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.OpenAIServiceCollectionExtensions.AddOpenAITextToImage(Microsoft.SemanticKernel.IKernelBuilder,System.String,System.String,System.String,System.Net.Http.HttpClient)</Target>
<Left>lib/net8.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Left>
<Right>lib/net8.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.Connectors.OpenAI.OpenAIFileService.GetFileContent(System.String,System.Threading.CancellationToken)</Target>
<Left>lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Left>
<Right>lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.Connectors.OpenAI.OpenAITextToImageService.#ctor(System.String,System.String,System.Net.Http.HttpClient,Microsoft.Extensions.Logging.ILoggerFactory)</Target>
<Left>lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Left>
<Right>lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.OpenAIServiceCollectionExtensions.AddOpenAITextToImage(Microsoft.Extensions.DependencyInjection.IServiceCollection,System.String,System.String,System.String)</Target>
<Left>lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Left>
<Right>lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.OpenAIServiceCollectionExtensions.AddOpenAITextToImage(Microsoft.SemanticKernel.IKernelBuilder,System.String,System.String,System.String,System.Net.Http.HttpClient)</Target>
<Left>lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Left>
<Right>lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
</Suppressions>
Original file line number Diff line number Diff line change
Expand Up @@ -1309,6 +1309,7 @@ public static IServiceCollection AddAzureOpenAITextToImage(
/// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
/// <param name="apiKey">OpenAI API key, see https://platform.openai.com/account/api-keys</param>
/// <param name="orgId">OpenAI organization id. This is usually optional unless your account belongs to multiple organizations.</param>
/// <param name="modelId">The model to use for image generation.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <param name="httpClient">The HttpClient to use with this service.</param>
/// <returns>The same instance as <paramref name="builder"/>.</returns>
Expand All @@ -1317,6 +1318,7 @@ public static IKernelBuilder AddOpenAITextToImage(
this IKernelBuilder builder,
string apiKey,
string? orgId = null,
string? modelId = null,
string? serviceId = null,
HttpClient? httpClient = null)
{
Expand All @@ -1327,6 +1329,7 @@ public static IKernelBuilder AddOpenAITextToImage(
new OpenAITextToImageService(
apiKey,
orgId,
modelId,
HttpClientProvider.GetHttpClient(httpClient, serviceProvider),
serviceProvider.GetService<ILoggerFactory>()));

Expand All @@ -1339,12 +1342,14 @@ public static IKernelBuilder AddOpenAITextToImage(
/// <param name="services">The <see cref="IServiceCollection"/> instance to augment.</param>
/// <param name="apiKey">OpenAI API key, see https://platform.openai.com/account/api-keys</param>
/// <param name="orgId">OpenAI organization id. This is usually optional unless your account belongs to multiple organizations.</param>
/// <param name="modelId">The model to use for image generation.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <returns>The same instance as <paramref name="services"/>.</returns>
[Experimental("SKEXP0010")]
public static IServiceCollection AddOpenAITextToImage(this IServiceCollection services,
string apiKey,
string? orgId = null,
string? modelId = null,
string? serviceId = null)
{
Verify.NotNull(services);
Expand All @@ -1354,6 +1359,7 @@ public static IServiceCollection AddOpenAITextToImage(this IServiceCollection se
new OpenAITextToImageService(
apiKey,
orgId,
modelId,
HttpClientProvider.GetHttpClient(serviceProvider),
serviceProvider.GetService<ILoggerFactory>()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Services;
using Microsoft.SemanticKernel.TextToImage;

namespace Microsoft.SemanticKernel.Connectors.OpenAI;
Expand Down Expand Up @@ -35,25 +36,37 @@ public sealed class OpenAITextToImageService : ITextToImageService
/// </summary>
private readonly string _authorizationHeaderValue;

/// <summary>
/// The model to use for image generation.
/// </summary>
private readonly string? _modelId;

/// <summary>
/// Initializes a new instance of the <see cref="OpenAITextToImageService"/> class.
/// </summary>
/// <param name="apiKey">OpenAI API key, see https://platform.openai.com/account/api-keys</param>
/// <param name="organization">OpenAI organization id. This is usually optional unless your account belongs to multiple organizations.</param>
/// <param name="modelId">The model to use for image generation.</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use for logging. If null, no logging will be performed.</param>
public OpenAITextToImageService(
string apiKey,
string? organization = null,
string? modelId = null,
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
{
Verify.NotNullOrWhiteSpace(apiKey);
this._authorizationHeaderValue = $"Bearer {apiKey}";
this._organizationHeaderValue = organization;
this._modelId = modelId;

this._core = new(httpClient, loggerFactory?.CreateLogger(this.GetType()));
this._core.AddAttribute(OpenAIClientCore.OrganizationKey, organization);
if (modelId is not null)
{
this._core.AddAttribute(AIServiceExtensions.ModelIdKey, modelId);
}

this._core.RequestCreated += (_, request) =>
{
Expand All @@ -77,10 +90,11 @@ public Task<string> GenerateImageAsync(string description, int width, int height
throw new ArgumentOutOfRangeException(nameof(width), width, "OpenAI can generate only square images of size 256x256, 512x512, or 1024x1024.");
}

return this.GenerateImageAsync(description, width, height, "url", x => x.Url, cancellationToken);
return this.GenerateImageAsync(this._modelId, description, width, height, "url", x => x.Url, cancellationToken);
}

private async Task<string> GenerateImageAsync(
string? model,
string description,
int width, int height,
string format, Func<TextToImageResponse.Image, string> extractResponse,
Expand All @@ -90,6 +104,7 @@ private async Task<string> GenerateImageAsync(

var requestBody = JsonSerializer.Serialize(new TextToImageRequest
{
Model = model,
Prompt = description,
Size = $"{width}x{height}",
Count = 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,34 @@ namespace Microsoft.SemanticKernel.Connectors.OpenAI;
/// </summary>
internal sealed class TextToImageRequest
{
/// <summary>
/// Model to use for image generation
/// </summary>
[JsonPropertyName("model")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Model { get; set; }

/// <summary>
/// Image prompt
/// </summary>
[JsonPropertyName("prompt")]
[JsonPropertyOrder(1)]
public string Prompt { get; set; } = string.Empty;

/// <summary>
/// Image size
/// </summary>
[JsonPropertyName("size")]
[JsonPropertyOrder(2)]
public string Size { get; set; } = "256x256";

/// <summary>
/// How many images to generate
/// </summary>
[JsonPropertyName("n")]
[JsonPropertyOrder(3)]
public int Count { get; set; } = 1;

/// <summary>
/// Image format, "url" or "b64_json"
/// </summary>
[JsonPropertyName("response_format")]
[JsonPropertyOrder(4)]
public string Format { get; set; } = "url";
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public void ConstructorWorksCorrectly(bool includeLoggerFactory)
// Assert
Assert.NotNull(service);
Assert.Equal("organization", service.Attributes["Organization"]);
Assert.False(service.Attributes.ContainsKey("ModelId"));
}

[Theory]
Expand All @@ -51,7 +52,8 @@ public void ConstructorWorksCorrectly(bool includeLoggerFactory)
public async Task GenerateImageWorksCorrectlyAsync(int width, int height, bool expectedException)
{
// Arrange
var service = new OpenAITextToImageService("api-key", "organization", this._httpClient);
var service = new OpenAITextToImageService("api-key", "organization", "dall-e-3", this._httpClient);
Assert.Equal("dall-e-3", service.Attributes["ModelId"]);
this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK)
{
Content = new StringContent("""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Threading.Tasks;
using Microsoft.Extensions.Configuration;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.TextToImage;
using SemanticKernel.IntegrationTests.TestSettings;
using Xunit;

namespace SemanticKernel.IntegrationTests.Connectors.OpenAI;
public sealed class OpenAITextToImageTests
{
private readonly IConfigurationRoot _configuration = new ConfigurationBuilder()
.AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true)
.AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true)
.AddEnvironmentVariables()
.AddUserSecrets<OpenAITextToAudioTests>()
.Build();

[Fact(Skip = "This test is for manual verification.")]
public async Task OpenAITextToImageTestAsync()
{
// Arrange
OpenAIConfiguration? openAIConfiguration = this._configuration.GetSection("OpenAITextToImage").Get<OpenAIConfiguration>();
Assert.NotNull(openAIConfiguration);

var kernel = Kernel.CreateBuilder()
.AddOpenAITextToImage(apiKey: openAIConfiguration.ApiKey)
.Build();

var service = kernel.GetRequiredService<ITextToImageService>();

// Act
var result = await service.GenerateImageAsync("The sun rises in the east and sets in the west.", 512, 512);

// Assert
Assert.NotNull(result);
Assert.NotEmpty(result);
}

[Fact(Skip = "This test is for manual verification.")]
public async Task OpenAITextToImageByModelTestAsync()
{
// Arrange
OpenAIConfiguration? openAIConfiguration = this._configuration.GetSection("OpenAITextToImage").Get<OpenAIConfiguration>();
Assert.NotNull(openAIConfiguration);

var kernel = Kernel.CreateBuilder()
.AddOpenAITextToImage(apiKey: openAIConfiguration.ApiKey, modelId: openAIConfiguration.ModelId)
.Build();

var service = kernel.GetRequiredService<ITextToImageService>();

// Act
var result = await service.GenerateImageAsync("The sun rises in the east and sets in the west.", 1024, 1024);

// Assert
Assert.NotNull(result);
Assert.NotEmpty(result);
}

[Fact(Skip = "This test is for manual verification.")]
public async Task AzureOpenAITextToImageTestAsync()
{
// Arrange
AzureOpenAIConfiguration? azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAITextToImage").Get<AzureOpenAIConfiguration>();
Assert.NotNull(azureOpenAIConfiguration);

var kernel = Kernel.CreateBuilder()
.AddAzureOpenAITextToImage(
azureOpenAIConfiguration.DeploymentName,
azureOpenAIConfiguration.Endpoint,
azureOpenAIConfiguration.ApiKey)
.Build();

var service = kernel.GetRequiredService<ITextToImageService>();

// Act
var result = await service.GenerateImageAsync("The sun rises in the east and sets in the west.", 1024, 1024);

// Assert
Assert.NotNull(result);
Assert.NotEmpty(result);
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.SemanticKernel;
using System.IO;
using System.Text.Json.Nodes;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Xunit;

namespace SemanticKernel.IntegrationTests.CrossLanguage;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.SemanticKernel;
using System;
using System.IO;
using System.Text.Json.Nodes;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Xunit;

namespace SemanticKernel.IntegrationTests.CrossLanguage;
Expand Down
12 changes: 12 additions & 0 deletions dotnet/src/IntegrationTests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,16 @@ To set your secrets with Secret Manager:
cd dotnet/src/IntegrationTests
dotnet user-secrets init
dotnet user-secrets set "OpenAI:ServiceId" "text-davinci-003"
dotnet user-secrets set "OpenAI:ModelId" "text-davinci-003"
dotnet user-secrets set "OpenAI:ChatModelId" "gpt-4"
dotnet user-secrets set "OpenAI:ApiKey" "..."
dotnet user-secrets set "OpenAITextToImage:ServiceId" "dall-e-3"
dotnet user-secrets set "OpenAITextToImage:ModelId" "dall-e-3"
dotnet user-secrets set "OpenAITextToImage:ApiKey" "..."
dotnet user-secrets set "AzureOpenAI:ServiceId" "azure-text-davinci-003"
dotnet user-secrets set "AzureOpenAI:DeploymentName" "text-davinci-003"
dotnet user-secrets set "AzureOpenAI:ChatDeploymentName" "gpt-4"
Expand All @@ -45,14 +50,21 @@ dotnet user-secrets set "AzureOpenAIEmbeddings:DeploymentName" "text-embedding-a
dotnet user-secrets set "AzureOpenAIEmbeddings:Endpoint" "https://contoso.openai.azure.com/"
dotnet user-secrets set "AzureOpenAIEmbeddings:ApiKey" "..."
dotnet user-secrets set "AzureOpenAIAudioToText:ServiceId" "azure-audio-to-text"
dotnet user-secrets set "AzureOpenAIAudioToText:DeploymentName" "whisper-1"
dotnet user-secrets set "AzureOpenAIAudioToText:Endpoint" "https://contoso.openai.azure.com/"
dotnet user-secrets set "AzureOpenAIAudioToText:ApiKey" "..."
dotnet user-secrets set "AzureOpenAITextToAudio:ServiceId" "azure-text-to-audio"
dotnet user-secrets set "AzureOpenAITextToAudio:DeploymentName" "tts-1"
dotnet user-secrets set "AzureOpenAITextToAudio:Endpoint" "https://contoso.openai.azure.com/"
dotnet user-secrets set "AzureOpenAITextToAudio:ApiKey" "..."
dotnet user-secrets set "AzureOpenAITextToImage:ServiceId" "azure-text-to-image"
dotnet user-secrets set "AzureOpenAITextToImage:DeploymentName" "dall-e-3"
dotnet user-secrets set "AzureOpenAITextToImage:Endpoint" "https://contoso.openai.azure.com/"
dotnet user-secrets set "AzureOpenAITextToImage:ApiKey" "..."
dotnet user-secrets set "MistralAI:ChatModel" "mistral-large-latest"
dotnet user-secrets set "MistralAI:EmbeddingModel" "mistral-embed"
dotnet user-secrets set "MistralAI:ApiKey" "..."
Expand Down

0 comments on commit ddf1d46

Please sign in to comment.