.Net Fix Add Missing OpenAI Connector Choice properties to Metadata (#5655)

## Description

Resolves #5289 

This pull request updates
`dotnet/samples/KernelSyntaxExamples/Example43_GetModelResult.cs` and
`dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs`, along with the
related unit tests and test data. The changes focus on exposing the missing
choice metadata and simplifying the code.

Changes in `Example43_GetModelResult.cs`:

* `GetTokenUsageMetadataAsync()` now creates the kernel with an implicit type
(`var`) instead of the explicit `Kernel` type, making the sample slightly more
concise.
* A new method, `GetFullModelMetadataAsync()`, was added. It creates a kernel,
defines a prompt function, invokes it through the kernel, and prints the result
together with its full metadata dictionary.
* A new method, `GetMetadataFromStreamAsync()`, was added to show the same
metadata on a streaming invocation, printing each chunk's metadata as it
arrives (see the sketch after this list for reading individual entries).
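
A minimal consumer-side sketch (not part of this commit) of reading a single
metadata entry from a `FunctionResult` instead of dumping the whole dictionary
with `AsJson()`. It reuses the sample's `TestConfiguration` values and assumes
the `FinishReason` key introduced by this change is present.

```csharp
// Not part of this commit: reads one metadata entry from a FunctionResult.
using System;
using Microsoft.SemanticKernel;

var kernel = Kernel.CreateBuilder()
    .AddOpenAIChatCompletion(
        modelId: TestConfiguration.OpenAI.ChatModelId,
        apiKey: TestConfiguration.OpenAI.ApiKey)
    .Build();

// InvokePromptAsync wraps CreateFunctionFromPrompt + InvokeAsync.
FunctionResult result = await kernel.InvokePromptAsync("1 + 1 = ?");

// "FinishReason" is one of the keys this commit adds to the choice metadata.
if (result.Metadata is { } metadata &&
    metadata.TryGetValue("FinishReason", out var finishReason))
{
    Console.WriteLine($"FinishReason: {finishReason}");
}
```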

Changes in `ClientCore.cs`:

* The non-streaming text completion path now builds its result metadata with
`GetTextChoiceMetadata()` (previously `GetChoiceMetadata()`).
* `GetStreamingTextContentsAsync()` was updated to use `GetTextChoiceMetadata()`
as well, so streaming and non-streaming text results carry the same metadata.
* `GetTextChoiceMetadata()`, `GetChatChoiceMetadata()`, and
`GetResponseMetadata()` now include additional fields such as `FinishReason`,
`Index`, `LogProbabilityModel`/`LogProbabilityInfo`, `FinishDetails`,
`Enhancements`, and `SystemFingerprint`, so more of the underlying choice
information reaches the caller (see the sketch after this list).
* In the streaming chat loop, the response metadata is now refreshed from every
`StreamingChatCompletionsUpdate` instead of only the first one, and
`FinishReason` is stored as a string (handling `null` values), since
serializing the underlying struct directly produces an empty object `{}`.
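
Below is a minimal sketch (not part of this commit) of observing the streamed
`FinishReason` from the chat completion service. It assumes a kernel configured
with the same `TestConfiguration` values as the sample and that each chunk's
metadata contains the `FinishReason` key added by `GetResponseMetadata()`.

```csharp
// Not part of this commit: watches the FinishReason surfaced on streaming
// chunks. Intermediate chunks report null; the final chunk carries the value
// returned by the service (for example "stop").
using System;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

var kernel = Kernel.CreateBuilder()
    .AddOpenAIChatCompletion(
        modelId: TestConfiguration.OpenAI.ChatModelId,
        apiKey: TestConfiguration.OpenAI.ApiKey)
    .Build();

var chatService = kernel.GetRequiredService<IChatCompletionService>();

var history = new ChatHistory();
history.AddUserMessage("1 + 1 = ?");

await foreach (var chunk in chatService.GetStreamingChatMessageContentsAsync(history))
{
    var finishReason = chunk.Metadata?.GetValueOrDefault("FinishReason");
    Console.WriteLine($"'{chunk.Content}' | FinishReason: {finishReason ?? "(null)"}");
}
```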

Changes in `AzureOpenAIChatCompletionServiceTests.cs`:

* `GetStreamingChatMessageContentsWorksCorrectlyAsync()` and the other streaming
tests now iterate with an async enumerator instead of a `foreach` loop, so each
chunk can be asserted individually, including the `FinishReason` metadata on
the final chunk.

Changes in `chat_completion_streaming_test_response.txt`:

* The test response now contains a second chunk whose `finish_reason` is `stop`,
allowing the streaming tests to verify that the finish reason is surfaced in
the metadata.
RogerBarreto committed Mar 26, 2024
1 parent 4aeeb9e commit b997dcb
Showing 5 changed files with 127 additions and 31 deletions.
46 changes: 45 additions & 1 deletion dotnet/samples/KernelSyntaxExamples/Example43_GetModelResult.cs
@@ -16,7 +16,7 @@ public async Task GetTokenUsageMetadataAsync()
WriteLine("======== Inline Function Definition + Invocation ========");

// Create kernel
Kernel kernel = Kernel.CreateBuilder()
var kernel = Kernel.CreateBuilder()
.AddOpenAIChatCompletion(
modelId: TestConfiguration.OpenAI.ChatModelId,
apiKey: TestConfiguration.OpenAI.ApiKey)
@@ -35,6 +35,50 @@ public async Task GetTokenUsageMetadataAsync()
WriteLine();
}

[Fact]
public async Task GetFullModelMetadataAsync()
{
WriteLine("======== Inline Function Definition + Invocation ========");

// Create kernel
var kernel = Kernel.CreateBuilder()
.AddOpenAIChatCompletion(
modelId: TestConfiguration.OpenAI.ChatModelId,
apiKey: TestConfiguration.OpenAI.ApiKey)
.Build();

// Create function
const string FunctionDefinition = "1 + 1 = ?";
KernelFunction myFunction = kernel.CreateFunctionFromPrompt(FunctionDefinition);

// Invoke function through kernel
FunctionResult result = await kernel.InvokeAsync(myFunction);

// Display results
WriteLine(result.GetValue<string>());
WriteLine(result.Metadata?.AsJson());
WriteLine();
}

[Fact]
public async Task GetMetadataFromStreamAsync()
{
var kernel = Kernel.CreateBuilder()
.AddOpenAIChatCompletion(
modelId: TestConfiguration.OpenAI.ChatModelId,
apiKey: TestConfiguration.OpenAI.ApiKey)
.Build();

// Create function
const string FunctionDefinition = "1 + 1 = ?";
KernelFunction myFunction = kernel.CreateFunctionFromPrompt(FunctionDefinition);

await foreach (var content in kernel.InvokeStreamingAsync(myFunction))
{
WriteLine(content.Metadata?.AsJson());
}
}

public Example43_GetModelResult(ITestOutputHelper output) : base(output)
{
}
31 changes: 24 additions & 7 deletions dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs
@@ -140,7 +140,7 @@ internal ClientCore(ILogger? logger = null)

this.CaptureUsageDetails(responseData.Usage);

return responseData.Choices.Select(choice => new TextContent(choice.Text, this.DeploymentOrModelName, choice, Encoding.UTF8, GetChoiceMetadata(responseData, choice))).ToList();
return responseData.Choices.Select(choice => new TextContent(choice.Text, this.DeploymentOrModelName, choice, Encoding.UTF8, GetTextChoiceMetadata(responseData, choice))).ToList();
}

internal async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAsync(
@@ -161,43 +161,60 @@ await foreach (Completions completions in response)
{
foreach (Choice choice in completions.Choices)
{
yield return new OpenAIStreamingTextContent(choice.Text, choice.Index, this.DeploymentOrModelName, choice, GetChoiceMetadata(completions, choice));
yield return new OpenAIStreamingTextContent(choice.Text, choice.Index, this.DeploymentOrModelName, choice, GetTextChoiceMetadata(completions, choice));
}
}
}

private static Dictionary<string, object?> GetChoiceMetadata(Completions completions, Choice choice)
private static Dictionary<string, object?> GetTextChoiceMetadata(Completions completions, Choice choice)
{
return new Dictionary<string, object?>(5)
return new Dictionary<string, object?>(8)
{
{ nameof(completions.Id), completions.Id },
{ nameof(completions.Created), completions.Created },
{ nameof(completions.PromptFilterResults), completions.PromptFilterResults },
{ nameof(completions.Usage), completions.Usage },
{ nameof(choice.ContentFilterResults), choice.ContentFilterResults },

// Serialization of this struct behaves as an empty object {}, need to cast to string to avoid it.
{ nameof(choice.FinishReason), choice.FinishReason?.ToString() },

{ nameof(choice.LogProbabilityModel), choice.LogProbabilityModel },
{ nameof(choice.Index), choice.Index },
};
}

private static Dictionary<string, object?> GetChatChoiceMetadata(ChatCompletions completions, ChatChoice chatChoice)
{
return new Dictionary<string, object?>(6)
return new Dictionary<string, object?>(12)
{
{ nameof(completions.Id), completions.Id },
{ nameof(completions.Created), completions.Created },
{ nameof(completions.PromptFilterResults), completions.PromptFilterResults },
{ nameof(completions.SystemFingerprint), completions.SystemFingerprint },
{ nameof(completions.Usage), completions.Usage },
{ nameof(chatChoice.ContentFilterResults), chatChoice.ContentFilterResults },

// Serialization of this struct behaves as an empty object {}, need to cast to string to avoid it.
{ nameof(chatChoice.FinishReason), chatChoice.FinishReason?.ToString() },

{ nameof(chatChoice.FinishDetails), chatChoice.FinishDetails },
{ nameof(chatChoice.LogProbabilityInfo), chatChoice.LogProbabilityInfo },
{ nameof(chatChoice.Index), chatChoice.Index },
{ nameof(chatChoice.Enhancements), chatChoice.Enhancements },
};
}

private static Dictionary<string, object?> GetResponseMetadata(StreamingChatCompletionsUpdate completions)
{
return new Dictionary<string, object?>(3)
return new Dictionary<string, object?>(4)
{
{ nameof(completions.Id), completions.Id },
{ nameof(completions.Created), completions.Created },
{ nameof(completions.SystemFingerprint), completions.SystemFingerprint },

// Serialization of this struct behaves as an empty object {}, need to cast to string to avoid it.
{ nameof(completions.FinishReason), completions.FinishReason?.ToString() },
};
}

@@ -507,7 +524,7 @@ static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory c
CompletionsFinishReason finishReason = default;
await foreach (StreamingChatCompletionsUpdate update in response.ConfigureAwait(false))
{
metadata ??= GetResponseMetadata(update);
metadata = GetResponseMetadata(update);
streamedRole ??= update.Role;
finishReason = update.FinishReason ?? default;

@@ -276,6 +276,8 @@ public async Task GetChatMessageContentsWorksCorrectlyAsync(ToolCallBehavior beh
Assert.Equal(55, usage.PromptTokens);
Assert.Equal(100, usage.CompletionTokens);
Assert.Equal(155, usage.TotalTokens);

Assert.Equal("stop", result[0].Metadata?["FinishReason"]);
}

[Fact]
@@ -417,10 +419,13 @@ public async Task GetStreamingTextContentsWorksCorrectlyAsync()
});

// Act & Assert
await foreach (var chunk in service.GetStreamingTextContentsAsync("Prompt"))
{
Assert.Equal("Test chat streaming response", chunk.Text);
}
var enumerator = service.GetStreamingTextContentsAsync("Prompt").GetAsyncEnumerator();

await enumerator.MoveNextAsync();
Assert.Equal("Test chat streaming response", enumerator.Current.Text);

await enumerator.MoveNextAsync();
Assert.Equal("stop", enumerator.Current.Metadata?["FinishReason"]);
}

[Fact]
@@ -436,10 +441,13 @@ public async Task GetStreamingChatMessageContentsWorksCorrectlyAsync()
});

// Act & Assert
await foreach (var chunk in service.GetStreamingChatMessageContentsAsync([]))
{
Assert.Equal("Test chat streaming response", chunk.Content);
}
var enumerator = service.GetStreamingChatMessageContentsAsync([]).GetAsyncEnumerator();

await enumerator.MoveNextAsync();
Assert.Equal("Test chat streaming response", enumerator.Current.Content);

await enumerator.MoveNextAsync();
Assert.Equal("stop", enumerator.Current.Metadata?["FinishReason"]);
}

[Fact]
@@ -472,9 +480,18 @@ public async Task GetStreamingChatMessageContentsWithFunctionCallAsync()
this._messageHandlerStub.ResponsesToReturn = [response1, response2];

// Act & Assert
await foreach (var chunk in service.GetStreamingChatMessageContentsAsync([], settings, kernel))
var enumerator = service.GetStreamingChatMessageContentsAsync([], settings, kernel).GetAsyncEnumerator();

await enumerator.MoveNextAsync();
Assert.Equal("Test chat streaming response", enumerator.Current.Content);
Assert.Equal("tool_calls", enumerator.Current.Metadata?["FinishReason"]);

await enumerator.MoveNextAsync();
Assert.Equal("tool_calls", enumerator.Current.Metadata?["FinishReason"]);

// Keep looping until the end of stream
while (await enumerator.MoveNextAsync())
{
Assert.Equal("Test chat streaming response", chunk.Content);
}

Assert.Equal(2, functionCallCount);
@@ -546,10 +563,20 @@ public async Task GetStreamingChatMessageContentsWithRequiredFunctionCallAsync()
this._messageHandlerStub.ResponsesToReturn = [response1, response2];

// Act & Assert
await foreach (var chunk in service.GetStreamingChatMessageContentsAsync([], settings, kernel))
{
Assert.Equal("Test chat streaming response", chunk.Content);
}
var enumerator = service.GetStreamingChatMessageContentsAsync([], settings, kernel).GetAsyncEnumerator();

// Function Tool Call Streaming (One Chunk)
await enumerator.MoveNextAsync();
Assert.Equal("Test chat streaming response", enumerator.Current.Content);
Assert.Equal("tool_calls", enumerator.Current.Metadata?["FinishReason"]);

// Chat Completion Streaming (1st Chunk)
await enumerator.MoveNextAsync();
Assert.Null(enumerator.Current.Metadata?["FinishReason"]);

// Chat Completion Streaming (2nd Chunk)
await enumerator.MoveNextAsync();
Assert.Equal("stop", enumerator.Current.Metadata?["FinishReason"]);

Assert.Equal(1, functionCallCount);

@@ -214,10 +214,13 @@ public async Task GetStreamingTextContentsWorksCorrectlyAsync()
};

// Act & Assert
await foreach (var chunk in service.GetStreamingTextContentsAsync("Prompt"))
{
Assert.Equal("Test chat streaming response", chunk.Text);
}
var enumerator = service.GetStreamingTextContentsAsync("Prompt").GetAsyncEnumerator();

await enumerator.MoveNextAsync();
Assert.Equal("Test chat streaming response", enumerator.Current.Text);

await enumerator.MoveNextAsync();
Assert.Equal("stop", enumerator.Current.Metadata?["FinishReason"]);
}

[Fact]
@@ -233,10 +236,13 @@ public async Task GetStreamingChatMessageContentsWorksCorrectlyAsync()
};

// Act & Assert
await foreach (var chunk in service.GetStreamingChatMessageContentsAsync([]))
{
Assert.Equal("Test chat streaming response", chunk.Content);
}
var enumerator = service.GetStreamingChatMessageContentsAsync([]).GetAsyncEnumerator();

await enumerator.MoveNextAsync();
Assert.Equal("Test chat streaming response", enumerator.Current.Content);

await enumerator.MoveNextAsync();
Assert.Equal("stop", enumerator.Current.Metadata?["FinishReason"]);
}

[Fact]
@@ -1,3 +1,5 @@
data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Test chat streaming response"},"finish_reason":null}]}
data: {"id":"chatcmpl-96fqQVHGjG9Yzs4ZMB1K6nfy2oEoo","object":"chat.completion.chunk","created":1711377846,"model":"gpt-4-0125-preview","system_fingerprint":"fp_a7daf7c51e","choices":[{"index":0,"delta":{"content":"Test chat streaming response"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-96fqQVHGjG9Yzs4ZMB1K6nfy2oEoo","object":"chat.completion.chunk","created":1711377846,"model":"gpt-4-0125-preview","system_fingerprint":"fp_a7daf7c51e","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}

data: [DONE]
