From 0c89e0bd4314b4f4c913563258ffefadedab1afe Mon Sep 17 00:00:00 2001
From: Stephen Toub <stoub@microsoft.com>
Date: Fri, 17 May 2024 12:48:40 -0400
Subject: [PATCH] .Net: Fix MistralAI logging (#6315)

- The logger factory wasn't being forwarded to the chat completion service
  instance
- The class wasn't logging tokens like the other connectors

Also made the other connectors consistent in verbiage, metrics namespace,
etc.
---
 .../Clients/GeminiChatCompletionClient.cs     | 47 +++++++-------
 .../Core/HuggingFaceMessageApiClient.cs       | 29 +++++----
 .../Client/MistralClient.cs                   | 64 ++++++++++++++++++-
 .../MistralAIKernelBuilderExtensions.cs       |  5 +-
 .../Connectors.OpenAI/AzureSdk/ClientCore.cs  |  8 +--
 5 files changed, 109 insertions(+), 44 deletions(-)

diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs
index a44ebc87b1df..087a1c2bf2f8 100644
--- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs
+++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs
@@ -29,7 +29,7 @@ internal sealed class GeminiChatCompletionClient : ClientBase
     private readonly Uri _chatGenerationEndpoint;
     private readonly Uri _chatStreamingEndpoint;
 
-    private static readonly string s_namespace = typeof(GeminiChatCompletionClient).Namespace!;
+    private static readonly string s_namespace = typeof(GoogleAIGeminiChatCompletionService).Namespace!;
 
     /// <summary>
     /// The maximum number of auto-invokes that can be in-flight at any given time as part of the current
@@ -622,7 +622,28 @@ private static void ValidateGeminiResponse(GeminiResponse geminiResponse)
     }
 
     private void LogUsage(List<GeminiChatMessageContent> chatMessageContents)
-        => this.LogUsageMetadata(chatMessageContents[0].Metadata!);
+    {
+        GeminiMetadata? metadata = chatMessageContents[0].Metadata;
+
+        if (metadata is null || metadata.TotalTokenCount <= 0)
+        {
+            this.Logger.LogDebug("Token usage information unavailable.");
+            return;
+        }
+
+        if (this.Logger.IsEnabled(LogLevel.Information))
+        {
+            this.Logger.LogInformation(
+                "Prompt tokens: {PromptTokens}. Completion tokens: {CompletionTokens}. Total tokens: {TotalTokens}.",
+                metadata.PromptTokenCount,
+                metadata.CandidatesTokenCount,
+                metadata.TotalTokenCount);
+        }
+
+        s_promptTokensCounter.Add(metadata.PromptTokenCount);
+        s_completionTokensCounter.Add(metadata.CandidatesTokenCount);
+        s_totalTokensCounter.Add(metadata.TotalTokenCount);
+    }
 
     private List<GeminiChatMessageContent> GetChatMessageContentsFromResponse(GeminiResponse geminiResponse)
         => geminiResponse.Candidates!.Select(candidate => this.GetChatMessageContentFromCandidate(geminiResponse, candidate)).ToList();
@@ -707,28 +728,6 @@ private static void ValidateAutoInvoke(bool autoInvoke, int resultsPerPrompt)
             ResponseSafetyRatings = candidate.SafetyRatings?.ToList(),
         };
 
-    private void LogUsageMetadata(GeminiMetadata metadata)
-    {
-        if (metadata.TotalTokenCount <= 0)
-        {
-            this.Logger.LogDebug("Gemini usage information is not available.");
-            return;
-        }
-
-        if (this.Logger.IsEnabled(LogLevel.Debug))
-        {
-            this.Logger.LogDebug(
-                "Gemini usage metadata: Candidates tokens: {CandidatesTokens}, Prompt tokens: {PromptTokens}, Total tokens: {TotalTokens}",
-                metadata.CandidatesTokenCount,
-                metadata.PromptTokenCount,
-                metadata.TotalTokenCount);
-        }
-
-        s_promptTokensCounter.Add(metadata.PromptTokenCount);
-        s_completionTokensCounter.Add(metadata.CandidatesTokenCount);
-        s_totalTokensCounter.Add(metadata.TotalTokenCount);
-    }
-
     private sealed class ChatCompletionState
     {
         internal ChatHistory ChatHistory { get; set; } = null!;
diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs
index 80c7563eb555..66bd8cdbf365 100644
--- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs
+++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs
@@ -27,7 +27,7 @@ internal sealed class HuggingFaceMessageApiClient
 {
     private readonly HuggingFaceClient _clientCore;
 
-    private static readonly string s_namespace = typeof(HuggingFaceMessageApiClient).Namespace!;
+    private static readonly string s_namespace = typeof(HuggingFaceChatCompletionService).Namespace!;
 
     /// <summary>
     /// Instance of <see cref="Meter"/> for metrics.
@@ -179,20 +179,25 @@ internal sealed class HuggingFaceMessageApiClient
 
     private void LogChatCompletionUsage(HuggingFacePromptExecutionSettings executionSettings, ChatCompletionResponse chatCompletionResponse)
     {
-        if (this._clientCore.Logger.IsEnabled(LogLevel.Debug))
+        if (chatCompletionResponse.Usage is null)
         {
-            this._clientCore.Logger.Log(
-                LogLevel.Debug,
-                "HuggingFace chat completion usage - ModelId: {ModelId}, Prompt tokens: {PromptTokens}, Completion tokens: {CompletionTokens}, Total tokens: {TotalTokens}",
-                chatCompletionResponse.Model,
-                chatCompletionResponse.Usage!.PromptTokens,
-                chatCompletionResponse.Usage!.CompletionTokens,
-                chatCompletionResponse.Usage!.TotalTokens);
+            this._clientCore.Logger.LogDebug("Token usage information unavailable.");
+            return;
         }
 
-        s_promptTokensCounter.Add(chatCompletionResponse.Usage!.PromptTokens);
-        s_completionTokensCounter.Add(chatCompletionResponse.Usage!.CompletionTokens);
-        s_totalTokensCounter.Add(chatCompletionResponse.Usage!.TotalTokens);
+        if (this._clientCore.Logger.IsEnabled(LogLevel.Information))
+        {
+            this._clientCore.Logger.LogInformation(
+                "Prompt tokens: {PromptTokens}. Completion tokens: {CompletionTokens}. Total tokens: {TotalTokens}. ModelId: {ModelId}.",
ModelId: {ModelId}.", + chatCompletionResponse.Usage.PromptTokens, + chatCompletionResponse.Usage.CompletionTokens, + chatCompletionResponse.Usage.TotalTokens, + chatCompletionResponse.Model); + } + + s_promptTokensCounter.Add(chatCompletionResponse.Usage.PromptTokens); + s_completionTokensCounter.Add(chatCompletionResponse.Usage.CompletionTokens); + s_totalTokensCounter.Add(chatCompletionResponse.Usage.TotalTokens); } private static List GetChatMessageContentsFromResponse(ChatCompletionResponse response, string modelId) diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs index 2b179dca872a..78c9e6dce33f 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Diagnostics.Metrics; using System.IO; using System.Linq; using System.Net.Http; @@ -26,8 +27,6 @@ namespace Microsoft.SemanticKernel.Connectors.MistralAI.Client; /// internal sealed class MistralClient { - private const string ModelProvider = "mistralai"; - internal MistralClient( string modelId, HttpClient httpClient, @@ -67,6 +66,7 @@ internal async Task> GetChatMessageContentsAsy { using var httpRequestMessage = this.CreatePost(chatRequest, endpoint, this._apiKey, stream: false); responseData = await this.SendRequestAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); + this.LogUsage(responseData?.Usage); if (responseData is null || responseData.Choices is null || responseData.Choices.Count == 0) { throw new KernelException("Chat completions not found"); @@ -572,6 +572,9 @@ internal async Task>> GenerateEmbeddingsAsync(IList< private readonly ILogger _logger; private readonly StreamJsonParser _streamJsonParser; + /// Provider name used for diagnostics. + private const string ModelProvider = "mistralai"; + /// /// The maximum number of auto-invokes that can be in-flight at any given time as part of the current /// asynchronous chain of execution. @@ -593,6 +596,63 @@ internal async Task>> GenerateEmbeddingsAsync(IList< /// Tracking for . private static readonly AsyncLocal s_inflightAutoInvokes = new(); + private static readonly string s_namespace = typeof(MistralAIChatCompletionService).Namespace!; + + /// + /// Instance of for metrics. + /// + private static readonly Meter s_meter = new(s_namespace); + + /// + /// Instance of to keep track of the number of prompt tokens used. + /// + private static readonly Counter s_promptTokensCounter = + s_meter.CreateCounter( + name: $"{s_namespace}.tokens.prompt", + unit: "{token}", + description: "Number of prompt tokens used"); + + /// + /// Instance of to keep track of the number of completion tokens used. + /// + private static readonly Counter s_completionTokensCounter = + s_meter.CreateCounter( + name: $"{s_namespace}.tokens.completion", + unit: "{token}", + description: "Number of completion tokens used"); + + /// + /// Instance of to keep track of the total number of tokens used. + /// + private static readonly Counter s_totalTokensCounter = + s_meter.CreateCounter( + name: $"{s_namespace}.tokens.total", + unit: "{token}", + description: "Number of tokens used"); + + /// Log token usage to the logger and metrics. + private void LogUsage(MistralUsage? 
usage) + { + if (usage is null || usage.PromptTokens is null || usage.CompletionTokens is null || usage.TotalTokens is null) + { + this._logger.LogDebug("Usage information unavailable."); + return; + } + + if (this._logger.IsEnabled(LogLevel.Information)) + { + this._logger.LogInformation( + "Prompt tokens: {PromptTokens}. Completion tokens: {CompletionTokens}. Total tokens: {TotalTokens}.", + usage.PromptTokens, + usage.CompletionTokens, + usage.TotalTokens); + } + + s_promptTokensCounter.Add(usage.PromptTokens.Value); + s_completionTokensCounter.Add(usage.CompletionTokens.Value); + s_totalTokensCounter.Add(usage.TotalTokens.Value); + } + /// /// Messages are required and the first prompt role should be user or system. /// diff --git a/dotnet/src/Connectors/Connectors.MistralAI/MistralAIKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.MistralAI/MistralAIKernelBuilderExtensions.cs index 92e1fd3098a7..90e7e762d3c3 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/MistralAIKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/MistralAIKernelBuilderExtensions.cs @@ -3,6 +3,7 @@ using System; using System.Net.Http; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.MistralAI; using Microsoft.SemanticKernel.Embeddings; @@ -38,7 +39,7 @@ public static class MistralAIKernelBuilderExtensions Verify.NotNullOrWhiteSpace(apiKey); builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => - new MistralAIChatCompletionService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider))); + new MistralAIChatCompletionService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider), serviceProvider.GetService())); return builder; } @@ -64,7 +65,7 @@ public static class MistralAIKernelBuilderExtensions Verify.NotNull(builder); builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => - new MistralAITextEmbeddingGenerationService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider))); + new MistralAITextEmbeddingGenerationService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider), serviceProvider.GetService())); return builder; } diff --git a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs index 47da5614adf2..c51c74667525 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs @@ -166,7 +166,7 @@ internal ClientCore(ILogger? logger = null) activity?.SetCompletionResponse(responseContent, responseData.Usage.PromptTokens, responseData.Usage.CompletionTokens); } - this.CaptureUsageDetails(responseData.Usage); + this.LogUsage(responseData.Usage); return responseContent; } @@ -396,7 +396,7 @@ internal ClientCore(ILogger? logger = null) try { responseData = (await RunRequestAsync(() => this.Client.GetChatCompletionsAsync(chatOptions, cancellationToken)).ConfigureAwait(false)).Value; - this.CaptureUsageDetails(responseData.Usage); + this.LogUsage(responseData.Usage); if (responseData.Choices.Count == 0) { throw new KernelException("Chat completions not found"); @@ -1435,11 +1435,11 @@ private static async Task RunRequestAsync(Func> request) /// Captures usage details, including token information. 
     /// </summary>
     /// <param name="usage">Instance of <see cref="CompletionsUsage"/> with usage details.</param>
-    private void CaptureUsageDetails(CompletionsUsage usage)
+    private void LogUsage(CompletionsUsage usage)
     {
         if (usage is null)
         {
-            this.Logger.LogDebug("Usage information is not available.");
+            this.Logger.LogDebug("Token usage information unavailable.");
             return;
         }
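
To see the fix end to end, a snippet like the following should now surface
MistralAI token usage in the logs. This is a minimal sketch, not part of the
patch: the AddMistralChatCompletion extension name, the model id, and the
console logging package (Microsoft.Extensions.Logging.Console) are assumptions
about the surrounding Semantic Kernel APIs.

    using System;
    using Microsoft.Extensions.DependencyInjection;
    using Microsoft.Extensions.Logging;
    using Microsoft.SemanticKernel;

    var builder = Kernel.CreateBuilder();

    // Register an ILoggerFactory in DI; with this patch, the MistralAI builder
    // extension forwards it to MistralAIChatCompletionService.
    builder.Services.AddSingleton(
        LoggerFactory.Create(logging => logging
            .SetMinimumLevel(LogLevel.Information)
            .AddConsole()));

    builder.AddMistralChatCompletion(modelId: "mistral-small-latest", apiKey: "<api-key>");

    var kernel = builder.Build();
    Console.WriteLine(await kernel.InvokePromptAsync("Hello!"));

    // Expected log line from MistralClient.LogUsage:
    // Prompt tokens: {PromptTokens}. Completion tokens: {CompletionTokens}. Total tokens: {TotalTokens}.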