.Net Single source of truth for chat message content (#5088)
### Motivation, Context and Description
Today, the ChatMessageContent class has two sources of truth for its
content - the Content property and the Items property. This may be
acceptable for now, as all SK and industry chat completion services
follow the same protocol. They use the Content property for system,
user, assistant, and tool messages and alternatively allow passing image
and text content via the Items collection for user messages only.
   
However, this might not be suitable when there's a chat completion
service that doesn't follow the protocol mentioned above. For example,
it could be a new advanced chat completion service with multimodal
support for assistant messages. In this case, consumer code working
through the IChatCompletionService interface won't be able to handle
content for assistant messages polymorphically, and all consumers that
need to work with both current and new chat completion services will
have to use code like this:
   
```C#  
var message = await chatCompletionService.GetChatMessageContentAsync(...);  
    
if (message.Content != null)  
{  
    // Handle content specified in the Content property of the assistant message  
}  
else if (message.Items is { Count: > 0 } items)  
{  
    // Handle content specified in the Items property of the assistant message  
}  
    
// Or should the Items property be checked first, and Content second?  
```

The problem becomes more apparent, and shows up immediately, in the
agents space. Each agent needs logic like the above to identify the
source of the content and map it to an internal API data model. For
example, here's the code we would need to write for a ChatCompletion
agent:
   
```C#  
async Task<ChatMessageContent[]> InvokeAsync(ChatMessageContent[] messages)
{
    var chat = new ChatHistory();

    foreach (var message in messages)
    {
        if (message.Role == AuthorRole.User)
        {
            // User messages can have content in either the 'Items' property or the 'Content' property.
            // Assuming one of the two properties has content, add the message to the chat and continue.
            chat.Add(message);
            continue;
        }

        // The system, assistant and tool messages are expected to have content in the 'Content' property.
        // This expectation is specific to OpenAI chat completion services and may not be relevant to other chat   
        // completion service types, e.g., multimodality for assistant messages, where content would be expected   
        // to be provided via the 'Items' collection.  
        if (!string.IsNullOrEmpty(message.Content))
        {
            chat.Add(message);
            continue;
        }

        // Doing our best to identify content for the message and add it to the chat.
        if (message.Items is { Count: > 0 } items && items[0] is TextContent textContent)
        {
            // The problem with the clone code below is that we lose the original message type, which the
            // underlying API could have interpreted differently than the type of the clone - ChatMessageContent.
            var clone = new ChatMessageContent(
                role: message.Role,
                content: textContent.Text,
                modelId: message.ModelId,
                innerContent: message.InnerContent,
                encoding: message.Encoding,
                metadata: message.Metadata
            );

            chat.Add(clone);
            continue;
        }

        // If we get here it means that all the above conditions failed, so we add the message 
        // to let underlying API handle it.
        chat.Add(message);
    }

    var chatMessageContent = await chatCompletionService.GetChatMessageContentsAsync(chat, ...);

    return chatMessageContent.Select(m => { m.Source = this; return m; }).ToArray();
}
```  

To avoid all the unnecessary 'if/else' mapping logic needed to identify
the source of content depending on the message role, it would be
beneficial to have only one source of content - Items. Ideally, the
'Content' property should be removed, but doing so would unnecessarily
break a lot of consumer code.

As a middle-ground solution, this PR changes the purpose of the
'Content' property from being a separate source of content to a shortcut
for the first item of text content type in the 'Items' collection. This
way, the 'Content' property becomes nothing more than a convenient way
to add, update, or return the text of the first item of text content
type. The 'Items' collection, on the other hand, becomes the only source
of content that can be used polymorphically by consumer code.
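
The following is a minimal sketch of what consumer code could look like after this change. It assumes a project that references a `Microsoft.SemanticKernel` package version containing this commit. `DescribeItems` is a hypothetical helper introduced only for illustration, and the comments describe the behavior implied by the diff rather than verified output.

```C#
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

// Hypothetical helper (not part of SK): handles a message by looking only at
// the Items collection, regardless of the message role.
static List<string> DescribeItems(ChatMessageContent message)
{
    var result = new List<string>();

    foreach (var item in message.Items)
    {
        result.Add(item switch
        {
            TextContent text => $"text: {text.Text}",
            _ => $"other: {item.GetType().Name}",
        });
    }

    return result;
}

// The constructor assigns Content, which (per the diff) adds a TextContent item to Items.
var message = new ChatMessageContent(AuthorRole.Assistant, "Hello");

// Setting Content again updates the existing first TextContent instead of adding a new item.
message.Content = "Hi";

// Additional items are added through Items; Content keeps pointing at the first TextContent.
message.Items.Add(new TextContent("second"));

Console.WriteLine(message.Content);
Console.WriteLine(message.Items.OfType<TextContent>().Count());

foreach (var line in DescribeItems(message))
{
    Console.WriteLine(line);
}
```

Because the helper reads only `Items`, the same code path serves user, assistant, system, and tool messages, which removes the need for the role-specific 'if/else' checks shown earlier.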

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
SergeyMenshykh committed Feb 26, 2024
1 parent a6f66ec commit 36ce92f
Showing 6 changed files with 255 additions and 44 deletions.
@@ -11,7 +11,7 @@

namespace Examples;

public class Example85_ChatHistorySerialization : BaseTest
public class Example86_ChatHistorySerialization : BaseTest
{
/// <summary>
/// Demonstrates how to serialize and deserialize <see cref="ChatHistory"/> class
@@ -88,7 +88,7 @@ public void SerializeChatWithHistoryWithCustomContentType()
WriteLine($"Custom content: {(deserializedMessage.Items![1]! as CustomContent)!.Content}");
}

public Example85_ChatHistorySerialization(ITestOutputHelper output) : base(output)
public Example86_ChatHistorySerialization(ITestOutputHelper output) : base(output)
{
}

@@ -191,7 +191,11 @@ public async Task GetChatMessageContentsHandlesSettingsCorrectlyAsync()
var assistantMessage = messages[2];

Assert.Equal("user", userMessage.GetProperty("role").GetString());
Assert.Equal("User Message", userMessage.GetProperty("content").GetString());

var contentItems = userMessage.GetProperty("content");
Assert.Equal(1, contentItems.GetArrayLength());
Assert.Equal("User Message", contentItems[0].GetProperty("text").GetString());
Assert.Equal("text", contentItems[0].GetProperty("type").GetString());

Assert.Equal("system", systemMessage.GetProperty("role").GetString());
Assert.Equal("System Message", systemMessage.GetProperty("content").GetString());
@@ -599,8 +603,12 @@ public async Task GetChatMessageContentsUsesPromptAndSettingsCorrectlyAsync()
Assert.Equal("This is test system message", messages[0].GetProperty("content").GetString());
Assert.Equal("system", messages[0].GetProperty("role").GetString());

Assert.Equal("This is test prompt", messages[1].GetProperty("content").GetString());
Assert.Equal("user", messages[1].GetProperty("role").GetString());

var contentItems = messages[1].GetProperty("content");
Assert.Equal(1, contentItems.GetArrayLength());
Assert.Equal("This is test prompt", contentItems[0].GetProperty("text").GetString());
Assert.Equal("text", contentItems[0].GetProperty("type").GetString());
}

public void Dispose()
@@ -256,11 +256,19 @@ public async Task ItAddsSystemMessageAsync()
var actualRequestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!);
Assert.NotNull(actualRequestContent);
var optionsJson = JsonSerializer.Deserialize<JsonElement>(actualRequestContent);
Assert.Equal(2, optionsJson.GetProperty("messages").GetArrayLength());
Assert.Equal("Assistant is a large language model.", optionsJson.GetProperty("messages")[0].GetProperty("content").GetString());
Assert.Equal("system", optionsJson.GetProperty("messages")[0].GetProperty("role").GetString());
Assert.Equal("Hello", optionsJson.GetProperty("messages")[1].GetProperty("content").GetString());
Assert.Equal("user", optionsJson.GetProperty("messages")[1].GetProperty("role").GetString());

var messages = optionsJson.GetProperty("messages");
Assert.Equal(2, messages.GetArrayLength());

Assert.Equal("Assistant is a large language model.", messages[0].GetProperty("content").GetString());
Assert.Equal("system", messages[0].GetProperty("role").GetString());

Assert.Equal("user", messages[1].GetProperty("role").GetString());
var contentItems = messages[1].GetProperty("content");

Assert.Equal(1, contentItems.GetArrayLength());
Assert.Equal("Hello", contentItems[0].GetProperty("text").GetString());
Assert.Equal("text", contentItems[0].GetProperty("type").GetString());
}

public void Dispose()
@@ -1,7 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;
using Microsoft.SemanticKernel.ChatCompletion;
@@ -19,20 +21,78 @@ public class ChatMessageContent : KernelContent
public AuthorRole Role { get; set; }

/// <summary>
/// Content of the message
/// A convenience property to get or set the text of the first item in the <see cref="Items" /> collection of <see cref="TextContent"/> type.
/// </summary>
public string? Content { get; set; }
[EditorBrowsable(EditorBrowsableState.Never)]
public string? Content
{
get
{
var textContent = this.Items.OfType<TextContent>().FirstOrDefault();
return textContent?.Text;
}
set
{
if (value == null)
{
return;
}

var textContent = this.Items.OfType<TextContent>().FirstOrDefault();
if (textContent is not null)
{
textContent.Text = value;
textContent.Encoding = this.Encoding;
}
else
{
this.Items.Add(new TextContent(
text: value,
modelId: this.ModelId,
innerContent: this.InnerContent,
encoding: this.Encoding,
metadata: this.Metadata
));
}
}
}

/// <summary>
/// Chat message content items
/// </summary>
public ChatMessageContentItemCollection? Items { get; set; }
public ChatMessageContentItemCollection Items
{
get => this._items ??= new ChatMessageContentItemCollection();
set => this._items = value;
}

/// <summary>
/// The encoding of the text content.
/// </summary>
[JsonIgnore]
public Encoding Encoding { get; set; }
public Encoding Encoding
{
get
{
var textContent = this.Items.OfType<TextContent>().FirstOrDefault();
if (textContent is not null)
{
return textContent.Encoding;
}

return this._encoding;
}
set
{
this._encoding = value;

var textContent = this.Items.OfType<TextContent>().FirstOrDefault();
if (textContent is not null)
{
textContent.Encoding = value;
}
}
}

/// <summary>
/// Represents the source of the message.
@@ -65,8 +125,8 @@ public class ChatMessageContent : KernelContent
: base(innerContent, modelId, metadata)
{
this.Role = role;
this._encoding = encoding ?? Encoding.UTF8;
this.Content = content;
this.Encoding = encoding ?? Encoding.UTF8;
}

/// <summary>
@@ -88,13 +148,16 @@ public class ChatMessageContent : KernelContent
: base(innerContent, modelId, metadata)
{
this.Role = role;
this.Encoding = encoding ?? Encoding.UTF8;
this.Items = items;
this._encoding = encoding ?? Encoding.UTF8;
this._items = items;
}

/// <inheritdoc/>
public override string ToString()
{
return this.Content ?? string.Empty;
}

private ChatMessageContentItemCollection? _items;
private Encoding _encoding;
}
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Linq;
using System.Text.Json;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Xunit;

@@ -12,24 +14,7 @@ namespace SemanticKernel.UnitTests.AI.ChatCompletion;
public class ChatHistoryTests
{
[Fact]
public void ItCanBeSerialized()
{
// Arrange
var options = new JsonSerializerOptions();
var chatHistory = new ChatHistory();
chatHistory.AddMessage(AuthorRole.User, "Hello");
chatHistory.AddMessage(AuthorRole.Assistant, "Hi");

// Act
var chatHistoryJson = JsonSerializer.Serialize(chatHistory);

// Assert
Assert.NotNull(chatHistoryJson);
Assert.Equal("[{\"Role\":{\"Label\":\"user\"},\"Content\":\"Hello\",\"Items\":null,\"ModelId\":null,\"Metadata\":null},{\"Role\":{\"Label\":\"assistant\"},\"Content\":\"Hi\",\"Items\":null,\"ModelId\":null,\"Metadata\":null}]", chatHistoryJson);
}

[Fact]
public void ItCanBeDeserialized()
public void ItCanBeSerializedAndDeserialized()
{
// Arrange
var options = new JsonSerializerOptions();
@@ -48,6 +33,10 @@ public void ItCanBeDeserialized()
{
Assert.Equal(chatHistory[i].Role.Label, chatHistoryDeserialized[i].Role.Label);
Assert.Equal(chatHistory[i].Content, chatHistoryDeserialized[i].Content);
Assert.Equal(chatHistory[i].Items.Count, chatHistoryDeserialized[i].Items.Count);
Assert.Equal(
chatHistory[i].Items.OfType<TextContent>().Single().Text,
chatHistoryDeserialized[i].Items.OfType<TextContent>().Single().Text);
}
}
}
