Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Fix a few issues with ChatHistoryExtensions #3508

Merged
merged 1 commit into from
Nov 15, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

using System.Linq;
using Microsoft.SemanticKernel.AI.ChatCompletion;
using static Microsoft.SemanticKernel.Text.TextChunker;
using Microsoft.SemanticKernel.Text;

#pragma warning disable IDE0130
// ReSharper disable once CheckNamespace - Using NS of Plan
Expand All @@ -12,7 +12,7 @@ namespace Microsoft.SemanticKernel.Planners;
/// <summary>
/// Extension methods for <see cref="ChatHistory"/> class.
/// </summary>
public static class ChatHistoryExtensions
internal static class ChatHistoryExtensions
{
/// <summary>
/// Returns the number of tokens in the chat history.
Expand All @@ -22,23 +22,52 @@ public static class ChatHistoryExtensions
// <param name="skipStart">The index to start skipping messages.</param>
// <param name="skipCount">The number of messages to skip.</param>
// <param name="tokenCounter">The token counter to use.</param>
internal static int GetTokenCount(this ChatHistory chatHistory, string? additionalMessage = null, int skipStart = 0, int skipCount = 0, TokenCounter? tokenCounter = null)
internal static int GetTokenCount(this ChatHistory chatHistory, string? additionalMessage = null, int skipStart = 0, int skipCount = 0, TextChunker.TokenCounter? tokenCounter = null)
{
tokenCounter ??= DefaultTokenCounter;
return tokenCounter is null ?
Default(chatHistory, additionalMessage, skipStart, skipCount) :
Custom(chatHistory, additionalMessage, skipStart, skipCount, tokenCounter);

var messages = string.Join("\n", chatHistory.Where((m, i) => i < skipStart || i >= skipStart + skipCount).Select(m => m.Content));

if (!string.IsNullOrEmpty(additionalMessage))
static int Default(ChatHistory chatHistory, string? additionalMessage, int skipStart, int skipCount)
{
messages = $"{messages}\n{additionalMessage}";
int chars = 0;
bool prevMsg = false;
for (int i = 0; i < chatHistory.Count; i++)
{
if (i >= skipStart && i < skipStart + skipCount)
{
continue;
}

chars += chatHistory[i].Content?.Length ?? 0;

// +1 for "\n" if there was a previous message
if (prevMsg)
{
chars++;
}
prevMsg = true;
}

if (additionalMessage is not null)
{
chars += 1 + additionalMessage.Length; // +1 for "\n"
}

return chars / 4; // same as TextChunker's default token counter
}

var tokenCount = tokenCounter(messages);
return tokenCount;
}
static int Custom(ChatHistory chatHistory, string? additionalMessage, int skipStart, int skipCount, TextChunker.TokenCounter tokenCounter)
{
var messages = string.Join("\n", chatHistory.Where((m, i) => i < skipStart || i >= skipStart + skipCount).Select(m => m.Content));

private static int DefaultTokenCounter(string input)
{
return input.Length / 4;
if (!string.IsNullOrEmpty(additionalMessage))
{
messages = $"{messages}\n{additionalMessage}";
}

var tokenCount = tokenCounter(messages);
return tokenCount;
}
}
}
Loading