
.Net. ContentFilterResults added to text/message content metadata (#5020)

### Motivation and Context
Today, the code that populates metadata for message/text content does not add the ContentFilterResults property value from the individual choices to the content metadata.

Closes #4996
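
With this change in place, the per-choice filter results surface through the content metadata, so a caller can read them off the invocation result. A minimal sketch (not part of this PR), assuming an Azure OpenAI service is already registered on the kernel and that the entry is keyed by "ContentFilterResults" as in the integration test below:

```csharp
// Sketch only: read the content filter results that this change adds to the metadata.
FunctionResult result = await kernel.InvokePromptAsync("Write a limerick about a kernel.");

if (result.Metadata is not null &&
    result.Metadata.TryGetValue("ContentFilterResults", out object? contentFilterResults))
{
    // Per-choice content filter payload coming from the Azure OpenAI response.
    Console.WriteLine(contentFilterResults);
}
```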

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
SergeyMenshykh committed Feb 19, 2024
1 parent 4e7009a commit a9c4535
Showing 5 changed files with 58 additions and 27 deletions.
24 changes: 11 additions & 13 deletions dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs
@@ -136,8 +136,8 @@ internal ClientCore(ILogger? logger = null)
}

this.CaptureUsageDetails(responseData.Usage);
IReadOnlyDictionary<string, object?> metadata = GetResponseMetadata(responseData);
return responseData.Choices.Select(choice => new TextContent(choice.Text, this.DeploymentOrModelName, choice, Encoding.UTF8, metadata)).ToList();

return responseData.Choices.Select(choice => new TextContent(choice.Text, this.DeploymentOrModelName, choice, Encoding.UTF8, GetChoiceMetadata(responseData, choice))).ToList();
}

internal async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAsync(
@@ -154,37 +154,37 @@ internal ClientCore(ILogger? logger = null)

StreamingResponse<Completions>? response = await RunRequestAsync(() => this.Client.GetCompletionsStreamingAsync(options, cancellationToken)).ConfigureAwait(false);

IReadOnlyDictionary<string, object?>? metadata = null;
await foreach (Completions completions in response)
{
metadata ??= GetResponseMetadata(completions);
foreach (Choice choice in completions.Choices)
{
yield return new OpenAIStreamingTextContent(choice.Text, choice.Index, this.DeploymentOrModelName, choice, metadata);
yield return new OpenAIStreamingTextContent(choice.Text, choice.Index, this.DeploymentOrModelName, choice, GetChoiceMetadata(completions, choice));
}
}
}
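
With this change each streamed chunk carries metadata built from its own choice rather than a single response-level dictionary. A minimal consumption sketch (not part of this commit), assuming an ITextGenerationService resolved from the kernel:

// Sketch only: every chunk's metadata now includes the ContentFilterResults of its choice.
var textService = kernel.GetRequiredService<ITextGenerationService>();

await foreach (StreamingTextContent chunk in textService.GetStreamingTextContentsAsync("Tell me a story."))
{
    if (chunk.Metadata is not null &&
        chunk.Metadata.TryGetValue("ContentFilterResults", out object? filterResults))
    {
        // Filter verdict for the choice this chunk belongs to.
    }
}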

private static Dictionary<string, object?> GetResponseMetadata(Completions completions)
private static Dictionary<string, object?> GetChoiceMetadata(Completions completions, Choice choice)
{
return new Dictionary<string, object?>(4)
return new Dictionary<string, object?>(5)
{
{ nameof(completions.Id), completions.Id },
{ nameof(completions.Created), completions.Created },
{ nameof(completions.PromptFilterResults), completions.PromptFilterResults },
{ nameof(completions.Usage), completions.Usage },
{ nameof(choice.ContentFilterResults), choice.ContentFilterResults },
};
}

private static Dictionary<string, object?> GetResponseMetadata(ChatCompletions completions)
private static Dictionary<string, object?> GetChatChoiceMetadata(ChatCompletions completions, ChatChoice chatChoice)
{
return new Dictionary<string, object?>(5)
return new Dictionary<string, object?>(6)
{
{ nameof(completions.Id), completions.Id },
{ nameof(completions.Created), completions.Created },
{ nameof(completions.PromptFilterResults), completions.PromptFilterResults },
{ nameof(completions.SystemFingerprint), completions.SystemFingerprint },
{ nameof(completions.Usage), completions.Usage },
{ nameof(chatChoice.ContentFilterResults), chatChoice.ContentFilterResults },
};
}

@@ -303,13 +303,11 @@ await foreach (Completions completions in response)
throw new KernelException("Chat completions not found");
}

IReadOnlyDictionary<string, object?> metadata = GetResponseMetadata(responseData);

// If we don't want to attempt to invoke any functions, just return the result.
// Or if we are auto-invoking but we somehow end up with other than 1 choice even though only 1 was requested, similarly bail.
if (!autoInvoke || responseData.Choices.Count != 1)
{
return responseData.Choices.Select(chatChoice => new OpenAIChatMessageContent(chatChoice.Message, this.DeploymentOrModelName, metadata)).ToList();
return responseData.Choices.Select(chatChoice => new OpenAIChatMessageContent(chatChoice.Message, this.DeploymentOrModelName, GetChatChoiceMetadata(responseData, chatChoice))).ToList();
}

Debug.Assert(kernel is not null);
@@ -320,7 +318,7 @@ await foreach (Completions completions in response)
// may return a FinishReason of "stop" even if there are tool calls to be made, in particular if a required tool
// is specified.
ChatChoice resultChoice = responseData.Choices[0];
OpenAIChatMessageContent result = new(resultChoice.Message, this.DeploymentOrModelName, metadata);
OpenAIChatMessageContent result = new(resultChoice.Message, this.DeploymentOrModelName, GetChatChoiceMetadata(responseData, resultChoice));
if (result.ToolCalls.Count == 0)
{
return new[] { result };
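
The chat path changes in the same way: every OpenAIChatMessageContent now gets metadata built from its own ChatChoice, including that choice's ContentFilterResults. A minimal consumption sketch (not part of this commit), assuming an IChatCompletionService resolved from the kernel:

// Sketch only: the message-level metadata now carries the choice's ContentFilterResults.
var chatService = kernel.GetRequiredService<IChatCompletionService>();

var history = new ChatHistory();
history.AddUserMessage("Write a limerick about Semantic Kernel.");

ChatMessageContent message = await chatService.GetChatMessageContentAsync(history);

if (message.Metadata is not null &&
    message.Metadata.TryGetValue("ContentFilterResults", out object? filterResults))
{
    // Azure content filter verdict for this particular chat choice.
}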
27 changes: 27 additions & 0 deletions dotnet/src/IntegrationTests/BaseIntegrationTest.cs
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Http.Resilience;
using Microsoft.SemanticKernel;

namespace SemanticKernel.IntegrationTests;
public class BaseIntegrationTest
{
protected IKernelBuilder CreateKernelBuilder()
{
var builder = Kernel.CreateBuilder();

builder.Services.ConfigureHttpClientDefaults(c =>
{
c.AddStandardResilienceHandler().Configure(o =>
{
o.Retry.ShouldRetryAfterHeader = true;
o.Retry.ShouldHandle = args => ValueTask.FromResult(args.Outcome.Result?.StatusCode is HttpStatusCode.TooManyRequests);
});
});

return builder;
}
}
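
The base class centralizes the HTTP resilience setup (retry only on HTTP 429, honoring the Retry-After header), so integration test classes opt in by deriving from it. A minimal sketch of a consuming test; the class name and OpenAI configuration are illustrative, not part of this commit:

using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Xunit;

// Sketch only: a derived test class building its kernel through the shared resilient builder.
public sealed class MyConnectorTests : BaseIntegrationTest
{
    [Fact]
    public async Task ItRetriesOnTooManyRequestsAsync()
    {
        Kernel kernel = this.CreateKernelBuilder()
            .AddOpenAIChatCompletion(modelId: "gpt-4", apiKey: "<api-key>") // illustrative values
            .Build();

        // Exercise the kernel; 429 responses are retried by the standard resilience handler.
        await kernel.InvokePromptAsync("ping");
    }
}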
@@ -247,39 +247,45 @@ public async Task AzureOpenAIHttpRetryPolicyTestAsync(string prompt, string expe
[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task AzureOpenAIShouldReturnTokenUsageInMetadataAsync(bool useChatModel)
public async Task AzureOpenAIShouldReturnMetadataAsync(bool useChatModel)
{
// Arrange
this._kernelBuilder.Services.AddSingleton<ILoggerFactory>(this._logger);
var builder = this._kernelBuilder;

if (useChatModel)
{
this.ConfigureAzureOpenAIChatAsText(builder);
this.ConfigureAzureOpenAIChatAsText(this._kernelBuilder);
}
else
{
this.ConfigureAzureOpenAI(builder);
this.ConfigureAzureOpenAI(this._kernelBuilder);
}

Kernel target = builder.Build();
var kernel = this._kernelBuilder.Build();

IReadOnlyKernelPluginCollection plugin = TestHelpers.ImportSamplePlugins(target, "FunPlugin");
var plugin = TestHelpers.ImportSamplePlugins(kernel, "FunPlugin");

// Act and Assert
FunctionResult result = await target.InvokeAsync(plugin["FunPlugin"]["Limerick"]);
// Act
var result = await kernel.InvokeAsync(plugin["FunPlugin"]["Limerick"]);

// Assert
Assert.NotNull(result.Metadata);

// Usage
Assert.True(result.Metadata.TryGetValue("Usage", out object? usageObject));
Assert.NotNull(usageObject);

var jsonObject = JsonSerializer.SerializeToElement(usageObject);
Assert.True(jsonObject.TryGetProperty("PromptTokens", out JsonElement promptTokensJson));
Assert.True(promptTokensJson.TryGetInt32(out int promptTokens));
Assert.NotEqual(0, promptTokens);

Assert.True(jsonObject.TryGetProperty("CompletionTokens", out JsonElement completionTokensJson));
Assert.True(completionTokensJson.TryGetInt32(out int completionTokens));
Assert.NotEqual(0, completionTokens);

// ContentFilterResults
Assert.True(result.Metadata.ContainsKey("ContentFilterResults"));
}

[Fact]
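
If the test later needs to assert more than the key's presence, the same JSON round-trip used for Usage above could be applied to the filter results. The property name below (Hate) is an assumption about the serialized Azure OpenAI payload, not something this commit verifies:

// Sketch only: probing the filter payload the same way the Usage object is probed above.
Assert.True(result.Metadata.TryGetValue("ContentFilterResults", out object? filterObject));
Assert.NotNull(filterObject);

var filterJson = JsonSerializer.SerializeToElement(filterObject);
if (filterJson.TryGetProperty("Hate", out JsonElement hateCategory)) // assumed property name
{
    // Assert on the category's Filtered/Severity fields once their serialized shape is confirmed.
}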
@@ -14,7 +14,7 @@

namespace SemanticKernel.IntegrationTests.Connectors.OpenAI;

public sealed class OpenAIToolsTests : IDisposable
public sealed class OpenAIToolsTests : BaseIntegrationTest, IDisposable
{
public OpenAIToolsTests(ITestOutputHelper output)
{
@@ -128,7 +128,7 @@ private Kernel InitializeKernel()
OpenAIConfiguration? openAIConfiguration = this._configuration.GetSection("Planners:OpenAI").Get<OpenAIConfiguration>();
Assert.NotNull(openAIConfiguration);

IKernelBuilder builder = Kernel.CreateBuilder()
IKernelBuilder builder = this.CreateKernelBuilder()
.AddOpenAIChatCompletion(
modelId: openAIConfiguration.ModelId,
apiKey: openAIConfiguration.ApiKey);
@@ -18,7 +18,7 @@
using Xunit.Abstractions;

namespace SemanticKernel.IntegrationTests.Planners.Stepwise;
public sealed class FunctionCallingStepwisePlannerTests : IDisposable
public sealed class FunctionCallingStepwisePlannerTests : BaseIntegrationTest, IDisposable
{
private readonly string _bingApiKey;

@@ -136,11 +136,11 @@ private Kernel InitializeKernel()
OpenAIConfiguration? openAIConfiguration = this._configuration.GetSection("Planners:OpenAI").Get<OpenAIConfiguration>();
Assert.NotNull(openAIConfiguration);

IKernelBuilder builder = Kernel.CreateBuilder();
IKernelBuilder builder = this.CreateKernelBuilder();
builder.Services.AddSingleton<ILoggerFactory>(this._logger);
builder.AddOpenAIChatCompletion(
modelId: openAIConfiguration.ModelId,
apiKey: openAIConfiguration.ApiKey);
modelId: openAIConfiguration.ModelId,
apiKey: openAIConfiguration.ApiKey);

var kernel = builder.Build();

