Skip to content

Commit

Permalink
.Net: Use IReadOnlyDictionary for metadata dictionaries (#4260)
Browse files Browse the repository at this point in the history
And then stop making copies each time we hand one out.

Fixes #4233

---------

Co-authored-by: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
  • Loading branch information
stephentoub and RogerBarreto committed Dec 14, 2023
1 parent f356cd8 commit 2321dd5
Show file tree
Hide file tree
Showing 20 changed files with 116 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ public sealed class AzureOpenAIWithDataChatMessageContent : ChatMessageContent
/// <param name="chatChoice">Azure Chat With Data Choice</param>
/// <param name="modelId">The model ID used to generate the content</param>
/// <param name="metadata">Additional metadata</param>
internal AzureOpenAIWithDataChatMessageContent(ChatWithDataChoice chatChoice, string? modelId, IDictionary<string, object?>? metadata = null)
: base(default, string.Empty, modelId, chatChoice, System.Text.Encoding.UTF8, metadata ?? new Dictionary<string, object?>(1))
internal AzureOpenAIWithDataChatMessageContent(ChatWithDataChoice chatChoice, string? modelId, IReadOnlyDictionary<string, object?>? metadata = null)
: base(default, string.Empty, modelId, chatChoice, System.Text.Encoding.UTF8, CreateMetadataDictionary(metadata))
{
// An assistant message content must be present, otherwise the chat is not valid.
var chatMessage = chatChoice.Messages.First(m => string.Equals(m.Role, AuthorRole.Assistant.Label, StringComparison.OrdinalIgnoreCase));
Expand All @@ -36,6 +36,32 @@ internal AzureOpenAIWithDataChatMessageContent(ChatWithDataChoice chatChoice, st
this.Role = new AuthorRole(chatMessage.Role);

this.ToolContent = chatChoice.Messages.FirstOrDefault(message => message.Role.Equals(AuthorRole.Tool.Label, StringComparison.OrdinalIgnoreCase))?.Content;
this.Metadata!.Add(nameof(this.ToolContent), this.ToolContent);
((Dictionary<string, object?>)this.Metadata!).Add(nameof(this.ToolContent), this.ToolContent);
}

private static Dictionary<string, object?> CreateMetadataDictionary(IReadOnlyDictionary<string, object?>? metadata)
{
Dictionary<string, object?> newDictionary;
if (metadata is null)
{
// There's no existing metadata to clone; just allocate a new dictionary.
newDictionary = new Dictionary<string, object?>(1);
}
else if (metadata is IDictionary<string, object?> origMutable)
{
// Efficiently clone the old dictionary to a new one.
newDictionary = new Dictionary<string, object?>(origMutable);
}
else
{
// There's metadata to clone but we have to do so one item at a time.
newDictionary = new Dictionary<string, object?>(metadata.Count + 1);
foreach (var kvp in metadata)
{
newDictionary[kvp.Key] = kvp.Value;
}
}

return newDictionary;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ public sealed class AzureOpenAIWithDataStreamingChatMessageContent : StreamingCh
/// <param name="choiceIndex">Index of the choice</param>
/// <param name="modelId">The model ID used to generate the content</param>
/// <param name="metadata">Additional metadata</param>
internal AzureOpenAIWithDataStreamingChatMessageContent(ChatWithDataStreamingChoice choice, int choiceIndex, string modelId, IDictionary<string, object?>? metadata = null) : base(AuthorRole.Assistant, null, choice, choiceIndex, modelId, Encoding.UTF8, metadata)
internal AzureOpenAIWithDataStreamingChatMessageContent(ChatWithDataStreamingChoice choice, int choiceIndex, string modelId, IReadOnlyDictionary<string, object?>? metadata = null) :
base(AuthorRole.Assistant, null, choice, choiceIndex, modelId, Encoding.UTF8, metadata)
{
var message = choice.Messages.FirstOrDefault(this.IsValidMessage);
var messageContent = message?.Delta?.Content;
Expand Down
16 changes: 8 additions & 8 deletions dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ internal async Task<IReadOnlyList<TextContent>> GetTextResultsAsync(
}

this.CaptureUsageDetails(responseData.Usage);
var metadata = GetResponseMetadata(responseData);
return responseData.Choices.Select(choice => new TextContent(choice.Text, this.DeploymentOrModelName, choice, Encoding.UTF8, new Dictionary<string, object?>(metadata))).ToList();
IReadOnlyDictionary<string, object?> metadata = GetResponseMetadata(responseData);
return responseData.Choices.Select(choice => new TextContent(choice.Text, this.DeploymentOrModelName, choice, Encoding.UTF8, metadata)).ToList();
}

internal async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAsync(
Expand All @@ -132,13 +132,13 @@ internal async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAs

StreamingResponse<Completions>? response = await RunRequestAsync(() => this.Client.GetCompletionsStreamingAsync(options, cancellationToken)).ConfigureAwait(false);

Dictionary<string, object?>? metadata = null;
IReadOnlyDictionary<string, object?>? metadata = null;
await foreach (Completions completions in response)
{
metadata ??= GetResponseMetadata(completions);
foreach (Choice choice in completions.Choices)
{
yield return new OpenAIStreamingTextContent(choice.Text, choice.Index, this.DeploymentOrModelName, choice, new(metadata));
yield return new OpenAIStreamingTextContent(choice.Text, choice.Index, this.DeploymentOrModelName, choice, metadata);
}
}
}
Expand Down Expand Up @@ -240,13 +240,13 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
throw new KernelException("Chat completions not found");
}

var metadata = GetResponseMetadata(responseData);
IReadOnlyDictionary<string, object?> metadata = GetResponseMetadata(responseData);

// If we don't want to attempt to invoke any functions, just return the result.
// Or if we are auto-invoking but we somehow end up with other than 1 choice even though only 1 was requested, similarly bail.
if (!autoInvoke || responseData.Choices.Count != 1)
{
return responseData.Choices.Select(chatChoice => new OpenAIChatMessageContent(chatChoice.Message, this.DeploymentOrModelName, new(metadata))).ToList();
return responseData.Choices.Select(chatChoice => new OpenAIChatMessageContent(chatChoice.Message, this.DeploymentOrModelName, metadata)).ToList();
}

Debug.Assert(kernel is not null);
Expand Down Expand Up @@ -374,7 +374,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
// Stream the response.
contentBuilder?.Clear();
List<ChatCompletionsFunctionToolCall>? functionCallResponses = null;
Dictionary<string, object?>? metadata = null;
IReadOnlyDictionary<string, object?>? metadata = null;
ChatRole? streamedRole = default;
CompletionsFinishReason finishReason = default;
await foreach (StreamingChatCompletionsUpdate update in response.ConfigureAwait(false))
Expand Down Expand Up @@ -410,7 +410,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
}
}

yield return new OpenAIStreamingChatMessageContent(update, update.ChoiceIndex ?? 0, this.DeploymentOrModelName, new(metadata));
yield return new OpenAIStreamingChatMessageContent(update, update.ChoiceIndex ?? 0, this.DeploymentOrModelName, metadata);
}

// If we don't have a function call to invoke, we're done.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,28 @@ public sealed class OpenAIChatMessageContent : ChatMessageContent
/// <param name="chatMessage">Azure SDK chat message</param>
/// <param name="modelId">The model ID used to generate the content</param>
/// <param name="metadata">Additional metadata</param>
internal OpenAIChatMessageContent(ChatResponseMessage chatMessage, string modelId, Dictionary<string, object?>? metadata = null)
: base(new AuthorRole(chatMessage.Role.ToString()), chatMessage.Content, modelId, chatMessage, System.Text.Encoding.UTF8, metadata ?? new Dictionary<string, object?>(1))
internal OpenAIChatMessageContent(ChatResponseMessage chatMessage, string modelId, IReadOnlyDictionary<string, object?>? metadata = null)
: base(new AuthorRole(chatMessage.Role.ToString()), chatMessage.Content, modelId, chatMessage, System.Text.Encoding.UTF8, CreateMetadataDictionary(chatMessage.ToolCalls, metadata))
{
this.ToolCalls = chatMessage.ToolCalls;
this.Metadata!.Add(ToolCallsProperty, chatMessage.ToolCalls);
}

/// <summary>
/// Initializes a new instance of the <see cref="OpenAIChatMessageContent"/> class.
/// </summary>
internal OpenAIChatMessageContent(ChatRole role, string? content, string modelId, IReadOnlyList<ChatCompletionsToolCall> toolCalls, Dictionary<string, object?>? metadata = null)
: base(new AuthorRole(role.ToString()), content, modelId, content, System.Text.Encoding.UTF8, metadata ?? new Dictionary<string, object?>(1))
internal OpenAIChatMessageContent(ChatRole role, string? content, string modelId, IReadOnlyList<ChatCompletionsToolCall> toolCalls, IReadOnlyDictionary<string, object?>? metadata = null)
: base(new AuthorRole(role.ToString()), content, modelId, content, System.Text.Encoding.UTF8, CreateMetadataDictionary(toolCalls, metadata))
{
this.ToolCalls = toolCalls;
this.Metadata![ToolCallsProperty] = toolCalls;
}

/// <summary>
/// Initializes a new instance of the <see cref="OpenAIChatMessageContent"/> class.
/// </summary>
internal OpenAIChatMessageContent(AuthorRole role, string? content, string modelId, IReadOnlyList<ChatCompletionsToolCall> toolCalls, Dictionary<string, object?>? metadata = null)
: base(role, content, modelId, content, System.Text.Encoding.UTF8, metadata ?? new Dictionary<string, object?>(1))
internal OpenAIChatMessageContent(AuthorRole role, string? content, string modelId, IReadOnlyList<ChatCompletionsToolCall> toolCalls, IReadOnlyDictionary<string, object?>? metadata = null)
: base(role, content, modelId, content, System.Text.Encoding.UTF8, CreateMetadataDictionary(toolCalls, metadata))
{
this.ToolCalls = toolCalls;
this.Metadata![ToolCallsProperty] = toolCalls;
}

/// <summary>
Expand All @@ -66,24 +63,58 @@ internal OpenAIChatMessageContent(AuthorRole role, string? content, string model
/// <returns>The <see cref="OpenAIFunctionToolCall"/>, or null if no function was returned by the model.</returns>
public IReadOnlyList<OpenAIFunctionToolCall> GetOpenAIFunctionToolCalls()
{
if (this.ToolCalls is not null)
List<OpenAIFunctionToolCall>? functionToolCallList = null;

foreach (var toolCall in this.ToolCalls)
{
if (toolCall is ChatCompletionsFunctionToolCall functionToolCall)
{
(functionToolCallList ??= new List<OpenAIFunctionToolCall>()).Add(new OpenAIFunctionToolCall(functionToolCall));
}
}

if (functionToolCallList is not null)
{
List<OpenAIFunctionToolCall>? list = null;
return functionToolCallList;
}

return Array.Empty<OpenAIFunctionToolCall>();
}

for (int i = 0; i < this.ToolCalls.Count; i++)
private static IReadOnlyDictionary<string, object?>? CreateMetadataDictionary(
IReadOnlyList<ChatCompletionsToolCall> toolCalls,
IReadOnlyDictionary<string, object?>? original)
{
// We only need to augment the metadata if there are any tool calls.
if (toolCalls.Count > 0)
{
Dictionary<string, object?> newDictionary;
if (original is null)
{
// There's no existing metadata to clone; just allocate a new dictionary.
newDictionary = new Dictionary<string, object?>(1);
}
else if (original is IDictionary<string, object?> origIDictionary)
{
// Efficiently clone the old dictionary to a new one.
newDictionary = new Dictionary<string, object?>(origIDictionary);
}
else
{
if (this.ToolCalls[i] is ChatCompletionsFunctionToolCall ftc)
// There's metadata to clone but we have to do so one item at a time.
newDictionary = new Dictionary<string, object?>(original.Count + 1);
foreach (var kvp in original)
{
(list ??= new List<OpenAIFunctionToolCall>()).Add(new OpenAIFunctionToolCall(ftc));
newDictionary[kvp.Key] = kvp.Value;
}
}

if (list is not null)
{
return list;
}
// Add the additional entry.
newDictionary.Add(ToolCallsProperty, toolCalls);

return newDictionary;
}

return Array.Empty<OpenAIFunctionToolCall>();
return original;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ internal OpenAIStreamingChatMessageContent(
StreamingChatCompletionsUpdate chatUpdate,
int choiceIndex,
string modelId,
Dictionary<string, object?>? metadata = null)
IReadOnlyDictionary<string, object?>? metadata = null)
: base(
chatUpdate.Role.HasValue ? new AuthorRole(chatUpdate.Role.Value.ToString()) : null,
chatUpdate.ContentUpdate,
Expand Down Expand Up @@ -62,7 +62,7 @@ internal OpenAIStreamingChatMessageContent(
CompletionsFinishReason? completionsFinishReason = null,
int choiceIndex = 0,
string? modelId = null,
Dictionary<string, object?>? metadata = null)
IReadOnlyDictionary<string, object?>? metadata = null)
: base(
authorRole,
content,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ internal OpenAIStreamingTextContent(
int choiceIndex,
string modelId,
object? innerContentObject = null,
Dictionary<string, object?>? metadata = null)
IReadOnlyDictionary<string, object?>? metadata = null)
: base(
text,
choiceIndex,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ private async Task<IReadOnlyList<ChatMessageContent>> InternalGetChatMessageCont
var body = await response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false);

var chatWithDataResponse = this.DeserializeResponse<ChatWithDataResponse>(body);
var metadata = GetResponseMetadata(chatWithDataResponse);
IReadOnlyDictionary<string, object?> metadata = GetResponseMetadata(chatWithDataResponse);

return chatWithDataResponse.Choices.Select(choice => new AzureOpenAIWithDataChatMessageContent(choice, this.GetModelId(), metadata)).ToList();
}
Expand Down Expand Up @@ -195,11 +195,11 @@ private async IAsyncEnumerable<AzureOpenAIWithDataStreamingChatMessageContent> I
}

var chatWithDataResponse = this.DeserializeResponse<ChatWithDataStreamingResponse>(body);
var metadata = GetResponseMetadata(chatWithDataResponse);
IReadOnlyDictionary<string, object?> metadata = GetResponseMetadata(chatWithDataResponse);

foreach (var choice in chatWithDataResponse.Choices)
{
yield return new AzureOpenAIWithDataStreamingChatMessageContent(choice, choice.Index, this.GetModelId()!, new Dictionary<string, object?>(metadata));
yield return new AzureOpenAIWithDataStreamingChatMessageContent(choice, choice.Index, this.GetModelId()!, metadata);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void AddMessage(ChatMessageContent chatMessageContent)
/// <param name="encoding">Encoding of the message content</param>
/// <param name="metadata">Dictionary for any additional metadata</param>
/// </summary>
public void AddMessage(AuthorRole authorRole, string content, Encoding? encoding = null, IDictionary<string, object?>? metadata = null) =>
public void AddMessage(AuthorRole authorRole, string content, Encoding? encoding = null, IReadOnlyDictionary<string, object?>? metadata = null) =>
this.Add(new ChatMessageContent(authorRole, content, null, null, encoding, metadata));

/// <summary>
Expand All @@ -75,7 +75,7 @@ public void AddMessage(AuthorRole authorRole, string content, Encoding? encoding
/// <param name="encoding">Encoding of the message content</param>
/// <param name="metadata">Dictionary for any additional metadata</param>
/// </summary>
public void AddMessage(AuthorRole authorRole, ChatMessageContentItemCollection items, Encoding? encoding = null, IDictionary<string, object?>? metadata = null) =>
public void AddMessage(AuthorRole authorRole, ChatMessageContentItemCollection items, Encoding? encoding = null, IReadOnlyDictionary<string, object?>? metadata = null) =>
this.Add(new ChatMessageContent(authorRole, items, null, null, encoding, metadata));

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public ChatMessageContent(
string? modelId = null,
object? innerContent = null,
Encoding? encoding = null,
IDictionary<string, object?>? metadata = null)
IReadOnlyDictionary<string, object?>? metadata = null)
: base(innerContent, modelId, metadata)
{
this.Role = role;
Expand All @@ -72,7 +72,7 @@ public ChatMessageContent(
string? modelId = null,
object? innerContent = null,
Encoding? encoding = null,
IDictionary<string, object?>? metadata = null)
IReadOnlyDictionary<string, object?>? metadata = null)
: base(innerContent, modelId, metadata)
{
this.Role = role;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public ImageContent(
string? modelId = null,
object? innerContent = null,
Encoding? encoding = null,
IDictionary<string, object?>? metadata = null)
IReadOnlyDictionary<string, object?>? metadata = null)
: base(innerContent, modelId, metadata)
{
this.Uri = uri;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,18 @@ public abstract class KernelContent
/// <summary>
/// The metadata associated with the content.
/// </summary>
public IDictionary<string, object?>? Metadata { get; }
public IReadOnlyDictionary<string, object?>? Metadata { get; }

/// <summary>
/// Initializes a new instance of the <see cref="KernelContent"/> class.
/// </summary>
/// <param name="innerContent">The inner content representation</param>
/// <param name="modelId">The model ID used to generate the content</param>
/// <param name="metadata">Metadata associated with the content</param>
protected KernelContent(object? innerContent, string? modelId = null, IDictionary<string, object?>? metadata = null)
protected KernelContent(object? innerContent, string? modelId = null, IReadOnlyDictionary<string, object?>? metadata = null)
{
this.ModelId = modelId;
this.InnerContent = innerContent;
if (metadata is not null)
{
this.Metadata = new Dictionary<string, object?>(metadata);
}
this.Metadata = metadata;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public class StreamingChatMessageContent : StreamingKernelContent
/// <param name="encoding">Encoding of the chat</param>
/// <param name="metadata">Additional metadata</param>
[JsonConstructor]
public StreamingChatMessageContent(AuthorRole? role, string? content, object? innerContent = null, int choiceIndex = 0, string? modelId = null, Encoding? encoding = null, IDictionary<string, object?>? metadata = null) : base(innerContent, choiceIndex, modelId, metadata)
public StreamingChatMessageContent(AuthorRole? role, string? content, object? innerContent = null, int choiceIndex = 0, string? modelId = null, Encoding? encoding = null, IReadOnlyDictionary<string, object?>? metadata = null) : base(innerContent, choiceIndex, modelId, metadata)
{
this.Role = role;
this.Content = content;
Expand Down
Loading

0 comments on commit 2321dd5

Please sign in to comment.