diff --git a/dotnet/src/Planners/Planners.Core/Stepwise/ChatHistoryExtensions.cs b/dotnet/src/Planners/Planners.Core/Stepwise/ChatHistoryExtensions.cs index 60e810f105b5..9aabf9be2751 100644 --- a/dotnet/src/Planners/Planners.Core/Stepwise/ChatHistoryExtensions.cs +++ b/dotnet/src/Planners/Planners.Core/Stepwise/ChatHistoryExtensions.cs @@ -2,7 +2,7 @@ using System.Linq; using Microsoft.SemanticKernel.AI.ChatCompletion; -using static Microsoft.SemanticKernel.Text.TextChunker; +using Microsoft.SemanticKernel.Text; #pragma warning disable IDE0130 // ReSharper disable once CheckNamespace - Using NS of Plan @@ -12,7 +12,7 @@ namespace Microsoft.SemanticKernel.Planners; /// /// Extension methods for class. /// -public static class ChatHistoryExtensions +internal static class ChatHistoryExtensions { /// /// Returns the number of tokens in the chat history. @@ -22,23 +22,52 @@ public static class ChatHistoryExtensions // The index to start skipping messages. // The number of messages to skip. // The token counter to use. - internal static int GetTokenCount(this ChatHistory chatHistory, string? additionalMessage = null, int skipStart = 0, int skipCount = 0, TokenCounter? tokenCounter = null) + internal static int GetTokenCount(this ChatHistory chatHistory, string? additionalMessage = null, int skipStart = 0, int skipCount = 0, TextChunker.TokenCounter? tokenCounter = null) { - tokenCounter ??= DefaultTokenCounter; + return tokenCounter is null ? + Default(chatHistory, additionalMessage, skipStart, skipCount) : + Custom(chatHistory, additionalMessage, skipStart, skipCount, tokenCounter); - var messages = string.Join("\n", chatHistory.Where((m, i) => i < skipStart || i >= skipStart + skipCount).Select(m => m.Content)); - - if (!string.IsNullOrEmpty(additionalMessage)) + static int Default(ChatHistory chatHistory, string? additionalMessage, int skipStart, int skipCount) { - messages = $"{messages}\n{additionalMessage}"; + int chars = 0; + bool prevMsg = false; + for (int i = 0; i < chatHistory.Count; i++) + { + if (i >= skipStart && i < skipStart + skipCount) + { + continue; + } + + chars += chatHistory[i].Content?.Length ?? 0; + + // +1 for "\n" if there was a previous message + if (prevMsg) + { + chars++; + } + prevMsg = true; + } + + if (additionalMessage is not null) + { + chars += 1 + additionalMessage.Length; // +1 for "\n" + } + + return chars / 4; // same as TextChunker's default token counter } - var tokenCount = tokenCounter(messages); - return tokenCount; - } + static int Custom(ChatHistory chatHistory, string? additionalMessage, int skipStart, int skipCount, TextChunker.TokenCounter tokenCounter) + { + var messages = string.Join("\n", chatHistory.Where((m, i) => i < skipStart || i >= skipStart + skipCount).Select(m => m.Content)); - private static int DefaultTokenCounter(string input) - { - return input.Length / 4; + if (!string.IsNullOrEmpty(additionalMessage)) + { + messages = $"{messages}\n{additionalMessage}"; + } + + var tokenCount = tokenCounter(messages); + return tokenCount; + } } }