Skip to content

Commit

Permalink
.Net - Add support for Name property to ChatMessageContent (#5666)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

Tracking author identity is critical for multi-agent conversations. Identity is also
supported by the core chat-completion API via the `name` field on each message:
https://platform.openai.com/docs/api-reference/chat/create

```
 {
    "messages": [
        {
            "content": "Write one paragraph in response to the user that rhymes",
            "name": "Echo",
            "role": "system"
        },
        {
            "content": "Why is AI awesome",
            "name": "Ralph",
            "role": "user"
        }
    ],
    "temperature": 1,
    "top_p": 0.5,
    "n": 3,
    "presence_penalty": 0,
    "frequency_penalty": 0,
    "model": "gpt-4"
}
```

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

Add support for `ChatMessageContent.Name` property with optional,
non-breaking patterns.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---------

Co-authored-by: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com>
  • Loading branch information
crickman and markwallace-microsoft committed Apr 3, 2024
1 parent 67eda98 commit ca9e3ae
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 42 deletions.
124 changes: 124 additions & 0 deletions dotnet/samples/KernelSyntaxExamples/Example37_CompletionIdentity.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Xunit;
using Xunit.Abstractions;

namespace Examples;

// The following example shows how to use Semantic Kernel with identity associated with each chat message.
public class Example37_CompletionIdentity : BaseTest
{
    /// <summary>
    /// Flag to force usage of OpenAI configuration if both <see cref="TestConfiguration.OpenAI"/>
    /// and <see cref="TestConfiguration.AzureOpenAI"/> are defined.
    /// If 'false', Azure takes precedence.
    /// </summary>
    /// <remarks>
    /// NOTE: Retrieval tools are not currently available on Azure.
    /// </remarks>
    private const bool ForceOpenAI = true;

    /// <summary>
    /// Execution settings shared by every request in this example; mirrors the raw
    /// chat-completion request shown in the PR description.
    /// </summary>
    private static readonly OpenAIPromptExecutionSettings s_executionSettings =
        new()
        {
            FrequencyPenalty = 0,
            PresencePenalty = 0,
            Temperature = 1,
            TopP = 0.5,
        };

    /// <summary>
    /// Demonstrates a non-streaming chat completion where each input message optionally
    /// carries an author identity via <c>ChatMessageContent.AuthorName</c>.
    /// </summary>
    /// <param name="withName">When 'true', author names are assigned to the input messages.</param>
    [Theory]
    [InlineData(false)]
    [InlineData(true)]
    public async Task CompletionIdentityAsync(bool withName)
    {
        WriteLine("======== Completion Identity ========");

        IChatCompletionService chatService = CreateCompletionService();

        ChatHistory chatHistory = CreateHistory(withName);

        WriteMessages(chatHistory);

        WriteMessages(await chatService.GetChatMessageContentsAsync(chatHistory, s_executionSettings), chatHistory);

        ValidateMessages(chatHistory, withName);
    }

    /// <summary>
    /// Demonstrates a streaming chat completion where each input message optionally
    /// carries an author identity via <c>ChatMessageContent.AuthorName</c>.
    /// </summary>
    /// <param name="withName">When 'true', author names are assigned to the input messages.</param>
    [Theory]
    [InlineData(false)]
    [InlineData(true)]
    public async Task StreamingIdentityAsync(bool withName)
    {
        // FIX: banner previously read "Completion Identity" (copy/paste from the test above).
        WriteLine("======== Streaming Identity ========");

        IChatCompletionService chatService = CreateCompletionService();

        ChatHistory chatHistory = CreateHistory(withName);

        // Drain the stream; AddStreamingMessageAsync appends the aggregated assistant
        // message to the history once fully consumed. The chunk array itself is not needed.
        await chatHistory.AddStreamingMessageAsync(
            chatService.GetStreamingChatMessageContentsAsync(chatHistory, s_executionSettings)
                .Cast<OpenAIStreamingChatMessageContent>()).ToArrayAsync();

        WriteMessages(chatHistory);

        ValidateMessages(chatHistory, withName);
    }

    // Builds a two-message history; author names ("Echo" / "Ralph") are set only when withName is true.
    private static ChatHistory CreateHistory(bool withName)
    {
        return
            new ChatHistory()
            {
                new ChatMessageContent(AuthorRole.System, "Write one paragraph in response to the user that rhymes") { AuthorName = withName ? "Echo" : null },
                new ChatMessageContent(AuthorRole.User, "Why is AI awesome") { AuthorName = withName ? "Ralph" : null },
            };
    }

    // Asserts that every non-assistant message carries an author name exactly when one was requested.
    // Assistant messages come from the service and are never expected to carry a name.
    private static void ValidateMessages(ChatHistory chatHistory, bool expectName)
    {
        foreach (var message in chatHistory)
        {
            if (expectName && message.Role != AuthorRole.Assistant)
            {
                Assert.NotNull(message.AuthorName);
            }
            else
            {
                Assert.Null(message.AuthorName);
            }
        }
    }

    // Writes each message as "# role:name - content" (with "?" / "-" placeholders for missing
    // name/content); when a history is provided, also appends the messages to it.
    private void WriteMessages(IReadOnlyList<ChatMessageContent> messages, ChatHistory? history = null)
    {
        foreach (var message in messages)
        {
            WriteLine($"# {message.Role}:{message.AuthorName ?? "?"} - {message.Content ?? "-"}");
        }

        history?.AddRange(messages);
    }

    // Selects OpenAI or Azure OpenAI based on available configuration and the ForceOpenAI flag.
    private static IChatCompletionService CreateCompletionService()
    {
        return
            ForceOpenAI || string.IsNullOrEmpty(TestConfiguration.AzureOpenAI.Endpoint) ?
                new OpenAIChatCompletionService(
                    TestConfiguration.OpenAI.ChatModelId,
                    TestConfiguration.OpenAI.ApiKey) :
                new AzureOpenAIChatCompletionService(
                    deploymentName: TestConfiguration.AzureOpenAI.ChatDeploymentName,
                    endpoint: TestConfiguration.AzureOpenAI.Endpoint,
                    apiKey: TestConfiguration.AzureOpenAI.ApiKey,
                    modelId: TestConfiguration.AzureOpenAI.ChatModelId);
    }

    public Example37_CompletionIdentity(ITestOutputHelper output) : base(output)
    {
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> AddStreamingMe
Dictionary<int, string>? functionNamesByIndex = null;
Dictionary<int, StringBuilder>? functionArgumentBuildersByIndex = null;
Dictionary<string, object?>? metadata = null;
AuthorRole? streamedRole = default;
AuthorRole? streamedRole = null;
string? streamedName = null;

await foreach (var chatMessage in streamingMessageContents.ConfigureAwait(false))
{
metadata ??= (Dictionary<string, object?>?)chatMessage.Metadata;
Expand All @@ -45,19 +47,24 @@ await foreach (var chatMessage in streamingMessageContents.ConfigureAwait(false)

// Is always expected to have at least one chunk with the role provided from a streaming message
streamedRole ??= chatMessage.Role;
streamedName ??= chatMessage.AuthorName;

messageContents.Add(chatMessage);
yield return chatMessage;
}

if (messageContents.Count != 0)
{
chatHistory.Add(new OpenAIChatMessageContent(
streamedRole ?? AuthorRole.Assistant,
contentBuilder?.ToString() ?? string.Empty,
messageContents[0].ModelId!,
OpenAIFunctionToolCall.ConvertToolCallUpdatesToChatCompletionsFunctionToolCalls(ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex),
metadata));
var role = streamedRole ?? AuthorRole.Assistant;

chatHistory.Add(
new OpenAIChatMessageContent(
role,
contentBuilder?.ToString() ?? string.Empty,
messageContents[0].ModelId!,
OpenAIFunctionToolCall.ConvertToolCallUpdatesToChatCompletionsFunctionToolCalls(ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex),
metadata)
{ AuthorName = streamedName });
}
}
}
31 changes: 17 additions & 14 deletions dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ await foreach (Completions completions in response)
}
#pragma warning disable CA1031 // Do not catch general exception types
catch (Exception e)
#pragma warning restore CA1031
#pragma warning restore CA1031 // Do not catch general exception types
{
AddResponseMessage(chatOptions, chat, null, $"Error: Exception while invoking function. {e.Message}", toolCall.Id, this.Logger);
continue;
Expand Down Expand Up @@ -520,12 +520,14 @@ static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory c

// Stream the response.
IReadOnlyDictionary<string, object?>? metadata = null;
string? streamedName = null;
ChatRole? streamedRole = default;
CompletionsFinishReason finishReason = default;
await foreach (StreamingChatCompletionsUpdate update in response.ConfigureAwait(false))
{
metadata = GetResponseMetadata(update);
streamedRole ??= update.Role;
streamedName ??= update.AuthorName;
finishReason = update.FinishReason ?? default;

// If we're intending to invoke function calls, we need to consume that function call information.
Expand All @@ -539,7 +541,7 @@ await foreach (StreamingChatCompletionsUpdate update in response.ConfigureAwait(
OpenAIFunctionToolCall.TrackStreamingToolingUpdate(update.ToolCallUpdate, ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex);
}

yield return new OpenAIStreamingChatMessageContent(update, update.ChoiceIndex ?? 0, this.DeploymentOrModelName, metadata);
yield return new OpenAIStreamingChatMessageContent(update, update.ChoiceIndex ?? 0, this.DeploymentOrModelName, metadata) { AuthorName = streamedName };
}

// If we don't have a function to invoke, we're done.
Expand Down Expand Up @@ -571,8 +573,8 @@ await foreach (StreamingChatCompletionsUpdate update in response.ConfigureAwait(

// Add the original assistant message to the chatOptions; this is required for the service
// to understand the tool call responses.
chatOptions.Messages.Add(GetRequestMessage(streamedRole ?? default, content, toolCalls));
chat.Add(new OpenAIChatMessageContent(streamedRole ?? default, content, this.DeploymentOrModelName, toolCalls, metadata));
chatOptions.Messages.Add(GetRequestMessage(streamedRole ?? default, content, streamedName, toolCalls));
chat.Add(new OpenAIChatMessageContent(streamedRole ?? default, content, this.DeploymentOrModelName, toolCalls, metadata) { AuthorName = streamedName });

// Respond to each tooling request.
foreach (ChatCompletionsFunctionToolCall toolCall in toolCalls)
Expand Down Expand Up @@ -625,7 +627,7 @@ await foreach (StreamingChatCompletionsUpdate update in response.ConfigureAwait(
}
#pragma warning disable CA1031 // Do not catch general exception types
catch (Exception e)
#pragma warning restore CA1031
#pragma warning restore CA1031 // Do not catch general exception types
{
AddResponseMessage(chatOptions, chat, streamedRole, toolCall, metadata, result: null, $"Error: Exception while invoking function. {e.Message}", this.Logger);
continue;
Expand Down Expand Up @@ -780,7 +782,7 @@ internal static OpenAIClientOptions GetOpenAIClientOptions(HttpClient? httpClien
/// <param name="text">Optional chat instructions for the AI service</param>
/// <param name="executionSettings">Execution settings</param>
/// <returns>Chat object</returns>
internal static ChatHistory CreateNewChat(string? text = null, OpenAIPromptExecutionSettings? executionSettings = null)
private static ChatHistory CreateNewChat(string? text = null, OpenAIPromptExecutionSettings? executionSettings = null)
{
var chat = new ChatHistory();

Expand Down Expand Up @@ -938,21 +940,21 @@ private static CompletionsOptions CreateCompletionsOptions(string text, OpenAIPr
return options;
}

private static ChatRequestMessage GetRequestMessage(ChatRole chatRole, string contents, ChatCompletionsFunctionToolCall[]? tools)
private static ChatRequestMessage GetRequestMessage(ChatRole chatRole, string contents, string? name, ChatCompletionsFunctionToolCall[]? tools)
{
if (chatRole == ChatRole.User)
{
return new ChatRequestUserMessage(contents);
return new ChatRequestUserMessage(contents) { Name = name };
}

if (chatRole == ChatRole.System)
{
return new ChatRequestSystemMessage(contents);
return new ChatRequestSystemMessage(contents) { Name = name };
}

if (chatRole == ChatRole.Assistant)
{
var msg = new ChatRequestAssistantMessage(contents);
var msg = new ChatRequestAssistantMessage(contents) { Name = name };
if (tools is not null)
{
foreach (ChatCompletionsFunctionToolCall tool in tools)
Expand All @@ -970,7 +972,7 @@ private static ChatRequestMessage GetRequestMessage(ChatMessageContent message)
{
if (message.Role == AuthorRole.System)
{
return new ChatRequestSystemMessage(message.Content);
return new ChatRequestSystemMessage(message.Content) { Name = message.AuthorName };
}

if (message.Role == AuthorRole.User || message.Role == AuthorRole.Tool)
Expand All @@ -983,20 +985,21 @@ private static ChatRequestMessage GetRequestMessage(ChatMessageContent message)

if (message.Items is { Count: 1 } && message.Items.FirstOrDefault() is TextContent textContent)
{
return new ChatRequestUserMessage(textContent.Text);
return new ChatRequestUserMessage(textContent.Text) { Name = message.AuthorName };
}

return new ChatRequestUserMessage(message.Items.Select(static (KernelContent item) => (ChatMessageContentItem)(item switch
{
TextContent textContent => new ChatMessageTextContentItem(textContent.Text),
ImageContent imageContent => new ChatMessageImageContentItem(imageContent.Uri),
_ => throw new NotSupportedException($"Unsupported chat message content type '{item.GetType()}'.")
})));
})))
{ Name = message.AuthorName };
}

if (message.Role == AuthorRole.Assistant)
{
var asstMessage = new ChatRequestAssistantMessage(message.Content);
var asstMessage = new ChatRequestAssistantMessage(message.Content) { Name = message.AuthorName };

IEnumerable<ChatCompletionsToolCall>? tools = (message as OpenAIChatMessageContent)?.ToolCalls;
if (tools is null && message.Metadata?.TryGetValue(OpenAIChatMessageContent.FunctionToolCallsProperty, out object? toolCallsObject) is true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ public sealed class OpenAIChatMessageContent : ChatMessageContent
/// <summary>
/// Initializes a new instance of the <see cref="OpenAIChatMessageContent"/> class.
/// </summary>
/// <param name="chatMessage">Azure SDK chat message</param>
/// <param name="modelId">The model ID used to generate the content</param>
/// <param name="metadata">Additional metadata</param>
internal OpenAIChatMessageContent(ChatResponseMessage chatMessage, string modelId, IReadOnlyDictionary<string, object?>? metadata = null)
: base(new AuthorRole(chatMessage.Role.ToString()), chatMessage.Content, modelId, chatMessage, System.Text.Encoding.UTF8, CreateMetadataDictionary(chatMessage.ToolCalls, metadata))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ public void ConstructorsWorkCorrectly()
List<ChatCompletionsToolCall> toolCalls = [new FakeChatCompletionsToolCall("id")];

// Act
var content1 = new OpenAIChatMessageContent(new ChatRole("user"), "content1", "model-id1", toolCalls);
var content1 = new OpenAIChatMessageContent(new ChatRole("user"), "content1", "model-id1", toolCalls) { AuthorName = "Fred" };
var content2 = new OpenAIChatMessageContent(AuthorRole.User, "content2", "model-id2", toolCalls);

// Assert
this.AssertChatMessageContent(AuthorRole.User, "content1", "model-id1", toolCalls, content1);
this.AssertChatMessageContent(AuthorRole.User, "content1", "model-id1", toolCalls, content1, "Fred");
this.AssertChatMessageContent(AuthorRole.User, "content2", "model-id2", toolCalls, content2);
}

Expand Down Expand Up @@ -91,10 +91,12 @@ public void MetadataIsInitializedCorrectly()
string expectedContent,
string expectedModelId,
IReadOnlyList<ChatCompletionsToolCall> expectedToolCalls,
OpenAIChatMessageContent actualContent)
OpenAIChatMessageContent actualContent,
string? expectedName = null)
{
Assert.Equal(expectedRole, actualContent.Role);
Assert.Equal(expectedContent, actualContent.Content);
Assert.Equal(expectedName, actualContent.AuthorName);
Assert.Equal(expectedModelId, actualContent.ModelId);
Assert.Same(expectedToolCalls, actualContent.ToolCalls);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ namespace Microsoft.SemanticKernel;
/// </summary>
public class ChatMessageContent : KernelContent
{
/// <summary>
/// Name of the author of the message
/// </summary>
[Experimental("SKEXP0001")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? AuthorName { get; set; }

/// <summary>
/// Role of the author of the message
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text;
using System.Text.Json.Serialization;
using Microsoft.SemanticKernel.ChatCompletion;
Expand All @@ -20,6 +21,13 @@ public class StreamingChatMessageContent : StreamingKernelContent
/// </summary>
public string? Content { get; set; }

/// <summary>
/// Name of the author of the message
/// </summary>
[Experimental("SKEXP0001")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? AuthorName { get; set; }

/// <summary>
/// Role of the author of the message
/// </summary>
Expand All @@ -42,7 +50,8 @@ public class StreamingChatMessageContent : StreamingKernelContent
/// <param name="encoding">Encoding of the chat</param>
/// <param name="metadata">Additional metadata</param>
[JsonConstructor]
public StreamingChatMessageContent(AuthorRole? role, string? content, object? innerContent = null, int choiceIndex = 0, string? modelId = null, Encoding? encoding = null, IReadOnlyDictionary<string, object?>? metadata = null) : base(innerContent, choiceIndex, modelId, metadata)
public StreamingChatMessageContent(AuthorRole? role, string? content, object? innerContent = null, int choiceIndex = 0, string? modelId = null, Encoding? encoding = null, IReadOnlyDictionary<string, object?>? metadata = null)
: base(innerContent, choiceIndex, modelId, metadata)
{
this.Role = role;
this.Content = content;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ public void ItCanBeSerializedAndDeserialized()
{
// Arrange
var options = new JsonSerializerOptions();
var chatHistory = new ChatHistory();
chatHistory.AddMessage(AuthorRole.User, "Hello");
chatHistory.AddMessage(AuthorRole.Assistant, "Hi");
var chatHistory = new ChatHistory()
{
new ChatMessageContent(AuthorRole.System, "You are a polite bot.") { AuthorName = "ChatBot" },
new ChatMessageContent(AuthorRole.User, "Hello") { AuthorName = "ChatBot" },
new ChatMessageContent(AuthorRole.Assistant, "Hi") { AuthorName = "ChatBot" },
};
var chatHistoryJson = JsonSerializer.Serialize(chatHistory, options);

// Act
Expand All @@ -33,6 +36,7 @@ public void ItCanBeSerializedAndDeserialized()
{
Assert.Equal(chatHistory[i].Role.Label, chatHistoryDeserialized[i].Role.Label);
Assert.Equal(chatHistory[i].Content, chatHistoryDeserialized[i].Content);
Assert.Equal(chatHistory[i].AuthorName, chatHistoryDeserialized[i].AuthorName);
Assert.Equal(chatHistory[i].Items.Count, chatHistoryDeserialized[i].Items.Count);
Assert.Equal(
chatHistory[i].Items.OfType<TextContent>().Single().Text,
Expand Down
Loading

0 comments on commit ca9e3ae

Please sign in to comment.