.Net: Add activities to MistralClient (#6297)
Replicates the ModelDiagnostics stuff to the MistralAI chat completion
service implementation.

I still need to test it. Best I can say now is it compiles :)

cc: @markwallace-microsoft, @TaoChenOSU
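For context, the shape being replicated is: start a diagnostics activity around the model call, tag the activity with the error (plus any partial response metadata) on failure, and record the response and token usage on success. Below is a condensed, self-contained sketch of that shape using the plain System.Diagnostics API — the FakeResponse type and the demo.* tag names are illustrative stand-ins, not Semantic Kernel's ModelDiagnostics helpers or its actual tag conventions; the diffs below show the real implementation.

```csharp
using System;
using System.Diagnostics;
using System.Threading.Tasks;

// Illustrative stand-in for a deserialized completion response.
internal sealed record FakeResponse(string? Id, int PromptTokens, int CompletionTokens);

internal static class CompletionTelemetryDemo
{
    private static readonly ActivitySource Source = new("Demo.ModelClient");

    internal static async Task<FakeResponse> GetCompletionAsync(Func<Task<FakeResponse>> send)
    {
        using Activity? activity = Source.StartActivity("chat.completions");
        FakeResponse? response = null;
        try
        {
            response = await send().ConfigureAwait(false);
            if (response.Id is null)
            {
                // Mirrors the "Chat completions not found" validation in the diff:
                // response is non-null here, so the catch below can still record metadata.
                throw new InvalidOperationException("Completion not found");
            }
        }
        catch (Exception ex) when (activity is not null)
        {
            activity.SetStatus(ActivityStatusCode.Error, ex.Message);

            // Capture whatever metadata is available even though the call failed.
            if (response is not null)
            {
                activity.SetTag("demo.usage.prompt_tokens", response.PromptTokens);
                activity.SetTag("demo.usage.completion_tokens", response.CompletionTokens);
            }

            throw;
        }

        // Success: record response id and token usage before the activity ends.
        activity?.SetTag("demo.response.id", response!.Id);
        activity?.SetTag("demo.usage.prompt_tokens", response!.PromptTokens);
        activity?.SetTag("demo.usage.completion_tokens", response!.CompletionTokens);
        return response!;
    }
}
```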
stephentoub committed May 16, 2024 · 1 parent a136cd4 · commit aa98754
Showing 5 changed files with 136 additions and 55 deletions.
@@ -175,9 +175,9 @@ internal sealed class GeminiChatCompletionClient : ClientBase
             .ConfigureAwait(false);
         chatResponses = this.ProcessChatResponse(geminiResponse);
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }

@@ -259,9 +259,9 @@ internal sealed class GeminiChatCompletionClient : ClientBase
                 break;
             }
         }
-        catch (Exception ex)
+        catch (Exception ex) when (activity is not null)
         {
-            activity?.SetError(ex);
+            activity.SetError(ex);
             throw;
         }
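A note on the recurring change above (and in the files that follow): moving the null check into an exception filter — catch (Exception ex) when (activity is not null) — means the runtime evaluates the filter before unwinding and skips the handler entirely when diagnostics are off, so the exception propagates exactly as if the try/catch were not there. And when the handler does run, activity is provably non-null, so the ?. in the old activity?.SetError(ex) is no longer needed. A minimal self-contained demonstration using plain System.Diagnostics rather than the Semantic Kernel helpers:

```csharp
using System;
using System.Diagnostics;

internal static class ExceptionFilterDemo
{
    private static readonly ActivitySource Source = new("Demo.ModelClient");

    internal static void Run()
    {
        // With no ActivityListener registered, StartActivity returns null, the
        // filter below evaluates to false, and the handler is skipped entirely:
        // the exception propagates as if no try/catch were present.
        using Activity? activity = Source.StartActivity("chat.completions");
        try
        {
            throw new InvalidOperationException("simulated request failure");
        }
        catch (Exception ex) when (activity is not null)
        {
            // Entered only when diagnostics are enabled; no '?.' needed here.
            activity.SetStatus(ActivityStatusCode.Error, ex.Message);
            throw;
        }
    }
}
```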
@@ -147,9 +147,9 @@ internal HttpRequestMessage CreatePost(object requestData, Uri endpoint, string?

         response = DeserializeResponse<TextGenerationResponse>(body);
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }

@@ -204,9 +204,9 @@ internal HttpRequestMessage CreatePost(object requestData, Uri endpoint, string?
                 break;
             }
         }
-        catch (Exception ex)
+        catch (Exception ex) when (activity is not null)
         {
-            activity?.SetError(ex);
+            activity.SetError(ex);
             throw;
         }
@@ -120,9 +120,9 @@ internal sealed class HuggingFaceMessageApiClient
                 break;
             }
         }
-        catch (Exception ex)
+        catch (Exception ex) when (activity is not null)
         {
-            activity?.SetError(ex);
+            activity.SetError(ex);
             throw;
         }

@@ -162,9 +162,9 @@ internal sealed class HuggingFaceMessageApiClient

         response = HuggingFaceClient.DeserializeResponse<ChatCompletionResponse>(body);
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }
139 changes: 110 additions & 29 deletions dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs
@@ -15,6 +15,7 @@
 using Microsoft.Extensions.Logging;
 using Microsoft.Extensions.Logging.Abstractions;
 using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Diagnostics;
 using Microsoft.SemanticKernel.Http;
 using Microsoft.SemanticKernel.Text;

@@ -25,6 +26,8 @@ namespace Microsoft.SemanticKernel.Connectors.MistralAI.Client;
 /// </summary>
 internal sealed class MistralClient
 {
+    private const string ModelProvider = "mistralai";
+
     internal MistralClient(
         string modelId,
         HttpClient httpClient,

@@ -56,18 +59,56 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
     for (int requestIndex = 1; ; requestIndex++)
     {
-        using var httpRequestMessage = this.CreatePost(chatRequest, endpoint, this._apiKey, stream: false);
-        var responseData = await this.SendRequestAsync<ChatCompletionResponse>(httpRequestMessage, cancellationToken).ConfigureAwait(false);
-        if (responseData is null || responseData.Choices is null || responseData.Choices.Count == 0)
-        {
-            throw new KernelException("Chat completions not found");
-        }
+        ChatCompletionResponse? responseData = null;
+        List<ChatMessageContent> responseContent;
+        using (var activity = ModelDiagnostics.StartCompletionActivity(this._endpoint, this._modelId, ModelProvider, chatHistory, mistralExecutionSettings))
+        {
+            try
+            {
+                using var httpRequestMessage = this.CreatePost(chatRequest, endpoint, this._apiKey, stream: false);
+                responseData = await this.SendRequestAsync<ChatCompletionResponse>(httpRequestMessage, cancellationToken).ConfigureAwait(false);
+                if (responseData is null || responseData.Choices is null || responseData.Choices.Count == 0)
+                {
+                    throw new KernelException("Chat completions not found");
+                }
+            }
+            catch (Exception ex) when (activity is not null)
+            {
+                activity.SetError(ex);
+
+                // Capture available metadata even if the operation failed.
+                if (responseData is not null)
+                {
+                    if (responseData.Id is string id)
+                    {
+                        activity.SetResponseId(id);
+                    }
+
+                    if (responseData.Usage is MistralUsage usage)
+                    {
+                        if (usage.PromptTokens is int promptTokens)
+                        {
+                            activity.SetPromptTokenUsage(promptTokens);
+                        }
+                        if (usage.CompletionTokens is int completionTokens)
+                        {
+                            activity.SetCompletionTokenUsage(completionTokens);
+                        }
+                    }
+                }
+
+                throw;
+            }
+
+            responseContent = this.ToChatMessageContent(modelId, responseData);
+            activity?.SetCompletionResponse(responseContent, responseData.Usage?.PromptTokens, responseData.Usage?.CompletionTokens);
+        }

         // If we don't want to attempt to invoke any functions, just return the result.
         // Or if we are auto-invoking but we somehow end up with other than 1 choice even though only 1 was requested, similarly bail.
         if (!autoInvoke || responseData.Choices.Count != 1)
         {
-            return this.ToChatMessageContent(modelId, responseData);
+            return responseContent;
         }

@@ -78,7 +119,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
     // Get our single result and extract the function call information. If this isn't a function call, or if it is
     MistralChatChoice chatChoice = responseData.Choices[0]; // TODO Handle multiple choices
     if (!chatChoice.IsToolCall)
     {
-        return this.ToChatMessageContent(modelId, responseData);
+        return responseContent;
     }

     if (this._logger.IsEnabled(LogLevel.Debug))

@@ -237,35 +278,75 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMes
     toolCalls?.Clear();

     // Stream the responses
-    var response = this.StreamChatMessageContentsAsync(chatHistory, mistralExecutionSettings, chatRequest, modelId, cancellationToken);
-    string? streamedRole = null;
-    await foreach (var update in response.ConfigureAwait(false))
+    using (var activity = ModelDiagnostics.StartCompletionActivity(this._endpoint, this._modelId, ModelProvider, chatHistory, mistralExecutionSettings))
     {
-        // If we're intending to invoke function calls, we need to consume that function call information.
-        if (autoInvoke)
+        // Make the request.
+        IAsyncEnumerable<StreamingChatMessageContent> response;
+        try
         {
-            if (update.InnerContent is not MistralChatCompletionChunk completionChunk || completionChunk.Choices is null || completionChunk.Choices?.Count == 0)
-            {
-                continue;
-            }
+            response = this.StreamChatMessageContentsAsync(chatHistory, mistralExecutionSettings, chatRequest, modelId, cancellationToken);
+        }
+        catch (Exception e) when (activity is not null)
+        {
+            activity.SetError(e);
+            throw;
+        }

-            MistralChatCompletionChoice chatChoice = completionChunk!.Choices![0]; // TODO Handle multiple choices
-            streamedRole ??= chatChoice.Delta!.Role;
-            if (chatChoice.IsToolCall)
+        var responseEnumerator = response.ConfigureAwait(false).GetAsyncEnumerator();
+        List<StreamingKernelContent>? streamedContents = activity is not null ? [] : null;
+        string? streamedRole = null;
+        try
+        {
+            while (true)
             {
-                // Create a copy of the tool calls to avoid modifying the original list
-                toolCalls = new List<MistralToolCall>(chatChoice.ToolCalls!);
-
-                // Add the original assistant message to the chatRequest; this is required for the service
-                // to understand the tool call responses. Also add the result message to the caller's chat
-                // history: if they don't want it, they can remove it, but this makes the data available,
-                // including metadata like usage.
-                chatRequest.AddMessage(new MistralChatMessage(streamedRole, completionChunk.GetContent(0)) { ToolCalls = chatChoice.ToolCalls });
-                chatHistory.Add(this.ToChatMessageContent(modelId, streamedRole!, completionChunk, chatChoice));
+                try
+                {
+                    if (!await responseEnumerator.MoveNextAsync())
+                    {
+                        break;
+                    }
+                }
+                catch (Exception ex) when (activity is not null)
+                {
+                    activity.SetError(ex);
+                    throw;
+                }
+
+                StreamingChatMessageContent update = responseEnumerator.Current;
+
+                // If we're intending to invoke function calls, we need to consume that function call information.
+                if (autoInvoke)
+                {
+                    if (update.InnerContent is not MistralChatCompletionChunk completionChunk || completionChunk.Choices is null || completionChunk.Choices?.Count == 0)
+                    {
+                        continue;
+                    }
+
+                    MistralChatCompletionChoice chatChoice = completionChunk!.Choices![0]; // TODO Handle multiple choices
+                    streamedRole ??= chatChoice.Delta!.Role;
+                    if (chatChoice.IsToolCall)
+                    {
+                        // Create a copy of the tool calls to avoid modifying the original list
+                        toolCalls = new List<MistralToolCall>(chatChoice.ToolCalls!);
+
+                        // Add the original assistant message to the chatRequest; this is required for the service
+                        // to understand the tool call responses. Also add the result message to the caller's chat
+                        // history: if they don't want it, they can remove it, but this makes the data available,
+                        // including metadata like usage.
+                        chatRequest.AddMessage(new MistralChatMessage(streamedRole, completionChunk.GetContent(0)) { ToolCalls = chatChoice.ToolCalls });
+                        chatHistory.Add(this.ToChatMessageContent(modelId, streamedRole!, completionChunk, chatChoice));
+                    }
+                }
+
+                streamedContents?.Add(update);
+                yield return update;
             }
         }
-
-        yield return update;
+        finally
+        {
+            activity?.EndStreaming(streamedContents);
+            await responseEnumerator.DisposeAsync();
+        }
    }

     // If we don't have a function to invoke, we're done.
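The streaming rewrite above is shaped by a C# constraint: yield return may not appear inside a try block that has a catch clause. That is why the await foreach is replaced with a hand-driven enumerator — each MoveNextAsync gets its own filtered catch, every update is both collected and yielded, and the activity is closed in a finally (via EndStreaming) even if the consumer abandons the stream early. A self-contained sketch of the same shape; the demo.* tag name stands in for SK's EndStreaming helper:

```csharp
using System;
using System.Collections.Generic;
using System.Diagnostics;

internal static class StreamingTelemetryDemo
{
    private static readonly ActivitySource Source = new("Demo.ModelClient");

    internal static async IAsyncEnumerable<string> StreamWithTelemetryAsync(
        IAsyncEnumerable<string> inner)
    {
        using Activity? activity = Source.StartActivity("chat.streaming_completions");

        // Drive the enumerator by hand: 'yield return' cannot sit inside a try
        // block that has a catch clause, so each MoveNextAsync gets its own
        // filtered catch instead.
        IAsyncEnumerator<string> enumerator = inner.GetAsyncEnumerator();
        List<string>? streamed = activity is not null ? [] : null;
        try
        {
            while (true)
            {
                try
                {
                    if (!await enumerator.MoveNextAsync().ConfigureAwait(false))
                    {
                        break;
                    }
                }
                catch (Exception ex) when (activity is not null)
                {
                    activity.SetStatus(ActivityStatusCode.Error, ex.Message);
                    throw;
                }

                streamed?.Add(enumerator.Current);
                yield return enumerator.Current;
            }
        }
        finally
        {
            // Runs even if the consumer stops iterating early; attach what was
            // streamed to the activity before it ends.
            activity?.SetTag("demo.streamed_update_count", streamed?.Count ?? 0);
            await enumerator.DisposeAsync();
        }
    }
}
```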
28 changes: 14 additions & 14 deletions dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs
@@ -148,13 +148,13 @@ internal ClientCore(ILogger? logger = null)
             throw new KernelException("Text completions not found");
         }
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         if (responseData != null)
         {
             // Capture available metadata even if the operation failed.
-            activity?
+            activity
                 .SetResponseId(responseData.Id)
                 .SetPromptTokenUsage(responseData.Usage.PromptTokens)
                 .SetCompletionTokenUsage(responseData.Usage.CompletionTokens);

@@ -190,9 +190,9 @@ internal ClientCore(ILogger? logger = null)
     {
         response = await RunRequestAsync(() => this.Client.GetCompletionsStreamingAsync(options, cancellationToken)).ConfigureAwait(false);
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }

@@ -209,9 +209,9 @@ internal ClientCore(ILogger? logger = null)
                 break;
             }
         }
-        catch (Exception ex)
+        catch (Exception ex) when (activity is not null)
         {
-            activity?.SetError(ex);
+            activity.SetError(ex);
             throw;
         }

@@ -402,13 +402,13 @@ internal ClientCore(ILogger? logger = null)
             throw new KernelException("Chat completions not found");
         }
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         if (responseData != null)
         {
             // Capture available metadata even if the operation failed.
-            activity?
+            activity
                 .SetResponseId(responseData.Id)
                 .SetPromptTokenUsage(responseData.Usage.PromptTokens)
                 .SetCompletionTokenUsage(responseData.Usage.CompletionTokens);

@@ -671,9 +671,9 @@ static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory c
     {
         response = await RunRequestAsync(() => this.Client.GetChatCompletionsStreamingAsync(chatOptions, cancellationToken)).ConfigureAwait(false);
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }

@@ -690,9 +690,9 @@ static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory c
                 break;
             }
         }
-        catch (Exception ex)
+        catch (Exception ex) when (activity is not null)
         {
-            activity?.SetError(ex);
+            activity.SetError(ex);
             throw;
         }
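None of these activities are recorded unless something subscribes to the corresponding ActivitySource. A minimal way to see the spans from a console app is a plain ActivityListener; note that the source-name prefix below and the AppContext switch gating Semantic Kernel's experimental model diagnostics are assumptions to verify against the SK version you use:

```csharp
using System;
using System.Diagnostics;

internal static class DiagnosticsListenerDemo
{
    internal static void Enable()
    {
        // Assumed switch name for SK's experimental GenAI diagnostics; confirm
        // against the Semantic Kernel docs for your version.
        AppContext.SetSwitch("Microsoft.SemanticKernel.Experimental.GenAI.EnableOTelDiagnostics", true);

        var listener = new ActivityListener
        {
            // Assumed prefix: SK's diagnostic sources start with "Microsoft.SemanticKernel".
            ShouldListenTo = source => source.Name.StartsWith("Microsoft.SemanticKernel", StringComparison.Ordinal),
            Sample = (ref ActivityCreationOptions<ActivityContext> _) => ActivitySamplingResult.AllData,
            ActivityStopped = activity =>
            {
                Console.WriteLine($"{activity.DisplayName}: {activity.Status}");
                foreach (var tag in activity.Tags)
                {
                    Console.WriteLine($"  {tag.Key} = {tag.Value}");
                }
            },
        };
        ActivitySource.AddActivityListener(listener);
    }
}
```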
