.Net: Fix MistralAI logging (#6315)
- The logger factory wasn't being forwarded to the chat completion
service instance
- The class wasn't logging tokens like the other connectors

Also made the other connectors consistent in log verbiage, metrics namespace, etc.
stephentoub committed May 17, 2024
1 parent 51af5ee commit 0c89e0b
Showing 5 changed files with 109 additions and 44 deletions.
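For context, the practical effect of the logger-factory fix is that logging configured on the kernel's service collection now reaches the Mistral connector. Below is a minimal sketch, not part of this commit; the `AddMistralAIChatCompletion` extension name, the model id, and the console-logging setup are assumptions.

```csharp
// Sketch only: wiring a logger factory into the kernel so the MistralAI connector
// (after this fix) receives it and logs token usage.
// Assumptions: the AddMistralAIChatCompletion extension method name, the model id,
// and the Microsoft.Extensions.Logging.Console package for AddConsole().
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

var builder = Kernel.CreateBuilder();

// The ILoggerFactory registered here is what the updated extension forwards via
// serviceProvider.GetService<ILoggerFactory>() when constructing MistralAIChatCompletionService.
builder.Services.AddLogging(logging => logging.AddConsole().SetMinimumLevel(LogLevel.Information));

builder.AddMistralAIChatCompletion(modelId: "mistral-small-latest", apiKey: "<MISTRAL_API_KEY>");

Kernel kernel = builder.Build();
var chat = kernel.GetRequiredService<IChatCompletionService>();
var reply = await chat.GetChatMessageContentAsync("Hello!");

// With the fix, each request emits an Information log similar to:
//   Prompt tokens: 5. Completion tokens: 12. Total tokens: 17.
```

Previously the service was constructed without the container's ILoggerFactory, so no usage logs appeared even when logging was configured.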
GeminiChatCompletionClient.cs
@@ -29,7 +29,7 @@ internal sealed class GeminiChatCompletionClient : ClientBase
private readonly Uri _chatGenerationEndpoint;
private readonly Uri _chatStreamingEndpoint;

private static readonly string s_namespace = typeof(GeminiChatCompletionClient).Namespace!;
private static readonly string s_namespace = typeof(GoogleAIGeminiChatCompletionService).Namespace!;

/// <summary>
/// The maximum number of auto-invokes that can be in-flight at any given time as part of the current
@@ -622,7 +622,28 @@ private static void ValidateGeminiResponse(GeminiResponse geminiResponse)
}

private void LogUsage(List<GeminiChatMessageContent> chatMessageContents)
=> this.LogUsageMetadata(chatMessageContents[0].Metadata!);
{
GeminiMetadata? metadata = chatMessageContents[0].Metadata;

if (metadata is null || metadata.TotalTokenCount <= 0)
{
this.Logger.LogDebug("Token usage information unavailable.");
return;
}

if (this.Logger.IsEnabled(LogLevel.Information))
{
this.Logger.LogInformation(
"Prompt tokens: {PromptTokens}. Completion tokens: {CompletionTokens}. Total tokens: {TotalTokens}.",
metadata.PromptTokenCount,
metadata.CandidatesTokenCount,
metadata.TotalTokenCount);
}

s_promptTokensCounter.Add(metadata.PromptTokenCount);
s_completionTokensCounter.Add(metadata.CandidatesTokenCount);
s_totalTokensCounter.Add(metadata.TotalTokenCount);
}

private List<GeminiChatMessageContent> GetChatMessageContentsFromResponse(GeminiResponse geminiResponse)
=> geminiResponse.Candidates!.Select(candidate => this.GetChatMessageContentFromCandidate(geminiResponse, candidate)).ToList();
@@ -707,28 +728,6 @@ private static void ValidateAutoInvoke(bool autoInvoke, int resultsPerPrompt)
ResponseSafetyRatings = candidate.SafetyRatings?.ToList(),
};

private void LogUsageMetadata(GeminiMetadata metadata)
{
if (metadata.TotalTokenCount <= 0)
{
this.Logger.LogDebug("Gemini usage information is not available.");
return;
}

if (this.Logger.IsEnabled(LogLevel.Debug))
{
this.Logger.LogDebug(
"Gemini usage metadata: Candidates tokens: {CandidatesTokens}, Prompt tokens: {PromptTokens}, Total tokens: {TotalTokens}",
metadata.CandidatesTokenCount,
metadata.PromptTokenCount,
metadata.TotalTokenCount);
}

s_promptTokensCounter.Add(metadata.PromptTokenCount);
s_completionTokensCounter.Add(metadata.CandidatesTokenCount);
s_totalTokensCounter.Add(metadata.TotalTokenCount);
}

private sealed class ChatCompletionState
{
internal ChatHistory ChatHistory { get; set; } = null!;
HuggingFaceMessageApiClient.cs
@@ -27,7 +27,7 @@ internal sealed class HuggingFaceMessageApiClient
{
private readonly HuggingFaceClient _clientCore;

private static readonly string s_namespace = typeof(HuggingFaceMessageApiClient).Namespace!;
private static readonly string s_namespace = typeof(HuggingFaceChatCompletionService).Namespace!;

/// <summary>
/// Instance of <see cref="Meter"/> for metrics.
@@ -179,20 +179,25 @@ internal sealed class HuggingFaceMessageApiClient

private void LogChatCompletionUsage(HuggingFacePromptExecutionSettings executionSettings, ChatCompletionResponse chatCompletionResponse)
{
if (this._clientCore.Logger.IsEnabled(LogLevel.Debug))
if (chatCompletionResponse.Usage is null)
{
this._clientCore.Logger.Log(
LogLevel.Debug,
"HuggingFace chat completion usage - ModelId: {ModelId}, Prompt tokens: {PromptTokens}, Completion tokens: {CompletionTokens}, Total tokens: {TotalTokens}",
chatCompletionResponse.Model,
chatCompletionResponse.Usage!.PromptTokens,
chatCompletionResponse.Usage!.CompletionTokens,
chatCompletionResponse.Usage!.TotalTokens);
this._clientCore.Logger.LogDebug("Token usage information unavailable.");
return;
}

s_promptTokensCounter.Add(chatCompletionResponse.Usage!.PromptTokens);
s_completionTokensCounter.Add(chatCompletionResponse.Usage!.CompletionTokens);
s_totalTokensCounter.Add(chatCompletionResponse.Usage!.TotalTokens);
if (this._clientCore.Logger.IsEnabled(LogLevel.Information))
{
this._clientCore.Logger.LogInformation(
"Prompt tokens: {PromptTokens}. Completion tokens: {CompletionTokens}. Total tokens: {TotalTokens}. ModelId: {ModelId}.",
chatCompletionResponse.Usage.PromptTokens,
chatCompletionResponse.Usage.CompletionTokens,
chatCompletionResponse.Usage.TotalTokens,
chatCompletionResponse.Model);
}

s_promptTokensCounter.Add(chatCompletionResponse.Usage.PromptTokens);
s_completionTokensCounter.Add(chatCompletionResponse.Usage.CompletionTokens);
s_totalTokensCounter.Add(chatCompletionResponse.Usage.TotalTokens);
}

private static List<ChatMessageContent> GetChatMessageContentsFromResponse(ChatCompletionResponse response, string modelId)
64 changes: 62 additions & 2 deletions dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs
@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.Metrics;
using System.IO;
using System.Linq;
using System.Net.Http;
@@ -26,8 +27,6 @@ namespace Microsoft.SemanticKernel.Connectors.MistralAI.Client;
/// </summary>
internal sealed class MistralClient
{
private const string ModelProvider = "mistralai";

internal MistralClient(
string modelId,
HttpClient httpClient,
@@ -67,6 +66,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
{
using var httpRequestMessage = this.CreatePost(chatRequest, endpoint, this._apiKey, stream: false);
responseData = await this.SendRequestAsync<ChatCompletionResponse>(httpRequestMessage, cancellationToken).ConfigureAwait(false);
this.LogUsage(responseData?.Usage);
if (responseData is null || responseData.Choices is null || responseData.Choices.Count == 0)
{
throw new KernelException("Chat completions not found");
@@ -572,6 +572,9 @@ internal async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<
private readonly ILogger _logger;
private readonly StreamJsonParser _streamJsonParser;

/// <summary>Provider name used for diagnostics.</summary>
private const string ModelProvider = "mistralai";

/// <summary>
/// The maximum number of auto-invokes that can be in-flight at any given time as part of the current
/// asynchronous chain of execution.
@@ -593,6 +596,63 @@ internal async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<
/// <summary>Tracking <see cref="AsyncLocal{Int32}"/> for <see cref="MaxInflightAutoInvokes"/>.</summary>
private static readonly AsyncLocal<int> s_inflightAutoInvokes = new();

private static readonly string s_namespace = typeof(MistralAIChatCompletionService).Namespace!;

/// <summary>
/// Instance of <see cref="Meter"/> for metrics.
/// </summary>
private static readonly Meter s_meter = new(s_namespace);

/// <summary>
/// Instance of <see cref="Counter{T}"/> to keep track of the number of prompt tokens used.
/// </summary>
private static readonly Counter<int> s_promptTokensCounter =
s_meter.CreateCounter<int>(
name: $"{s_namespace}.tokens.prompt",
unit: "{token}",
description: "Number of prompt tokens used");

/// <summary>
/// Instance of <see cref="Counter{T}"/> to keep track of the number of completion tokens used.
/// </summary>
private static readonly Counter<int> s_completionTokensCounter =
s_meter.CreateCounter<int>(
name: $"{s_namespace}.tokens.completion",
unit: "{token}",
description: "Number of completion tokens used");

/// <summary>
/// Instance of <see cref="Counter{T}"/> to keep track of the total number of tokens used.
/// </summary>
private static readonly Counter<int> s_totalTokensCounter =
s_meter.CreateCounter<int>(
name: $"{s_namespace}.tokens.total",
unit: "{token}",
description: "Number of tokens used");

/// <summary>Log token usage to the logger and metrics.</summary>
private void LogUsage(MistralUsage? usage)
{
if (usage is null || usage.PromptTokens is null || usage.CompletionTokens is null || usage.TotalTokens is null)
{
this._logger.LogDebug("Usage information unavailable.");
return;
}

if (this._logger.IsEnabled(LogLevel.Information))
{
this._logger.LogInformation(
"Prompt tokens: {PromptTokens}. Completion tokens: {CompletionTokens}. Total tokens: {TotalTokens}.",
usage.PromptTokens,
usage.CompletionTokens,
usage.TotalTokens);
}

s_promptTokensCounter.Add(usage.PromptTokens.Value);
s_completionTokensCounter.Add(usage.CompletionTokens.Value);
s_totalTokensCounter.Add(usage.TotalTokens.Value);
}

/// <summary>
/// Messages are required and the first prompt role should be user or system.
/// </summary>
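The counters added above surface through System.Diagnostics.Metrics, so they can be observed with a plain MeterListener. A rough sketch follows, not part of this commit, assuming `s_namespace` resolves to "Microsoft.SemanticKernel.Connectors.MistralAI" (the namespace of `MistralAIChatCompletionService`):

```csharp
// Sketch: observing the MistralAI token counters defined in MistralClient.
// Assumes the meter name is "Microsoft.SemanticKernel.Connectors.MistralAI"
// (typeof(MistralAIChatCompletionService).Namespace) and the instrument names shown above.
using System;
using System.Diagnostics.Metrics;

using var listener = new MeterListener();
listener.InstrumentPublished = (instrument, meterListener) =>
{
    if (instrument.Meter.Name == "Microsoft.SemanticKernel.Connectors.MistralAI")
    {
        meterListener.EnableMeasurementEvents(instrument);
    }
};
listener.SetMeasurementEventCallback<int>((instrument, value, tags, state) =>
    Console.WriteLine($"{instrument.Name}: {value}"));
listener.Start();

// After a chat completion request, output resembles:
//   Microsoft.SemanticKernel.Connectors.MistralAI.tokens.prompt: 5
//   Microsoft.SemanticKernel.Connectors.MistralAI.tokens.completion: 12
//   Microsoft.SemanticKernel.Connectors.MistralAI.tokens.total: 17
```

The Gemini, HuggingFace, and OpenAI connectors expose equivalent .tokens.* instruments under their own service namespaces, which is the consistency this commit aligns.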
MistralAIKernelBuilderExtensions.cs
@@ -3,6 +3,7 @@
using System;
using System.Net.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.MistralAI;
using Microsoft.SemanticKernel.Embeddings;
@@ -38,7 +39,7 @@ public static class MistralAIKernelBuilderExtensions
Verify.NotNullOrWhiteSpace(apiKey);

builder.Services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
new MistralAIChatCompletionService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider)));
new MistralAIChatCompletionService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider), serviceProvider.GetService<ILoggerFactory>()));

return builder;
}
@@ -64,7 +65,7 @@ public static class MistralAIKernelBuilderExtensions
Verify.NotNull(builder);

builder.Services.AddKeyedSingleton<ITextEmbeddingGenerationService>(serviceId, (serviceProvider, _) =>
new MistralAITextEmbeddingGenerationService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider)));
new MistralAITextEmbeddingGenerationService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider), serviceProvider.GetService<ILoggerFactory>()));

return builder;
}
ClientCore.cs (OpenAI connector)
@@ -166,7 +166,7 @@ internal ClientCore(ILogger? logger = null)
activity?.SetCompletionResponse(responseContent, responseData.Usage.PromptTokens, responseData.Usage.CompletionTokens);
}

this.CaptureUsageDetails(responseData.Usage);
this.LogUsage(responseData.Usage);

return responseContent;
}
@@ -396,7 +396,7 @@ internal ClientCore(ILogger? logger = null)
try
{
responseData = (await RunRequestAsync(() => this.Client.GetChatCompletionsAsync(chatOptions, cancellationToken)).ConfigureAwait(false)).Value;
this.CaptureUsageDetails(responseData.Usage);
this.LogUsage(responseData.Usage);
if (responseData.Choices.Count == 0)
{
throw new KernelException("Chat completions not found");
@@ -1435,11 +1435,11 @@ private static async Task<T> RunRequestAsync<T>(Func<Task<T>> request)
/// Captures usage details, including token information.
/// </summary>
/// <param name="usage">Instance of <see cref="CompletionsUsage"/> with usage details.</param>
private void CaptureUsageDetails(CompletionsUsage usage)
private void LogUsage(CompletionsUsage usage)
{
if (usage is null)
{
this.Logger.LogDebug("Usage information is not available.");
this.Logger.LogDebug("Token usage information unavailable.");
return;
}

