From f79fb20428576dd6344169e9bb10490850a215db Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 13 May 2024 23:00:56 -0700 Subject: [PATCH 01/11] OTel model diagnostics: streaming APIs --- .../Demos/TelemetryWithAppInsights/Program.cs | 14 ++- .../Clients/GeminiChatCompletionClient.cs | 34 +++-- .../Core/HuggingFaceClient.cs | 27 ++-- .../Core/HuggingFaceMessageApiClient.cs | 27 ++-- .../Connectors.OpenAI/AzureSdk/ClientCore.cs | 72 ++++++++--- .../src/Diagnostics/ModelDiagnostics.cs | 118 +++++++++++++++++- 6 files changed, 245 insertions(+), 47 deletions(-) diff --git a/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs b/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs index 7fc1093c4d9d..b85f35f84cb3 100644 --- a/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs +++ b/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs @@ -158,20 +158,26 @@ private static async Task RunHuggingFaceChatAsync(Kernel kernel) private static async Task RunChatAsync(Kernel kernel) { + // Using non-streaming to get the poem. var poem = await kernel.InvokeAsync( "WriterPlugin", "ShortPoem", new KernelArguments { ["input"] = "Write a poem about John Doe." }); - var translatedPoem = await kernel.InvokeAsync( + Console.WriteLine($"Poem:\n{poem}\n\n"); + + // Use streaming to translate the poem. + Console.WriteLine("Translated Poem:"); + await foreach (var update in kernel.InvokeStreamingAsync( "WriterPlugin", "Translate", new KernelArguments { ["input"] = poem, ["language"] = "Italian" - }); - - Console.WriteLine($"Poem:\n{poem}\n\nTranslated Poem:\n{translatedPoem}"); + })) + { + Console.Write(update); + } } private static Kernel GetKernel(ILoggerFactory loggerFactory) diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs index 8e19ddb09144..79936276f0d6 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs @@ -226,15 +226,33 @@ internal sealed class GeminiChatCompletionClient : ClientBase for (state.Iteration = 1; ; state.Iteration++) { - using var httpRequestMessage = await this.CreateHttpRequestAsync(state.GeminiRequest, this._chatStreamingEndpoint).ConfigureAwait(false); - using var response = await this.SendRequestAndGetResponseImmediatelyAfterHeadersReadAsync(httpRequestMessage, cancellationToken) - .ConfigureAwait(false); - using var responseStream = await response.Content.ReadAsStreamAndTranslateExceptionAsync() - .ConfigureAwait(false); - - await foreach (var messageContent in this.GetStreamingChatMessageContentsOrPopulateStateForToolCallingAsync(state, responseStream, cancellationToken).ConfigureAwait(false)) + using (var activity = ModelDiagnostics.StartCompletionActivity( + this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, executionSettings)) { - yield return messageContent; + HttpResponseMessage httpResponseMessage; + Stream responseStream; + try + { + using var httpRequestMessage = await this.CreateHttpRequestAsync(state.GeminiRequest, this._chatStreamingEndpoint).ConfigureAwait(false); + // We cannot dispose these two objects leaving the try-catch block because we need them to read the response stream + httpResponseMessage = await this.SendRequestAndGetResponseImmediatelyAfterHeadersReadAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); + responseStream = await httpResponseMessage.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false); + } + catch (Exception ex) + { + activity?.SetError(ex); + throw; + } + + await foreach (var messageContent in this.GetStreamingChatMessageContentsOrPopulateStateForToolCallingAsync(state, responseStream, cancellationToken).ConfigureAwait(false)) + { + activity?.AddStreamingContent(messageContent); + yield return messageContent; + } + + activity?.EndStreaming(); + httpResponseMessage.Dispose(); + responseStream.Dispose(); } if (!state.AutoInvoke) diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs index f93903094fad..e4187068ba2d 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs @@ -169,18 +169,31 @@ internal HttpRequestMessage CreatePost(object requestData, Uri endpoint, string? var request = this.CreateTextRequest(prompt, executionSettings); request.Stream = true; - using var httpRequestMessage = this.CreatePost(request, endpoint, this.ApiKey); - - using var response = await this.SendRequestAndGetResponseImmediatelyAfterHeadersReadAsync(httpRequestMessage, cancellationToken) - .ConfigureAwait(false); - - using var responseStream = await response.Content.ReadAsStreamAndTranslateExceptionAsync() - .ConfigureAwait(false); + using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this.ModelProvider, prompt, executionSettings); + HttpResponseMessage httpResponseMessage; + Stream responseStream; + try + { + using var httpRequestMessage = this.CreatePost(request, endpoint, this.ApiKey); + // We cannot dispose these two objects leaving the try-catch block because we need them to read the response stream + httpResponseMessage = await this.SendRequestAndGetResponseImmediatelyAfterHeadersReadAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); + responseStream = await httpResponseMessage.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false); + } + catch (Exception ex) + { + activity?.SetError(ex); + throw; + } await foreach (var streamingTextContent in this.ProcessTextResponseStreamAsync(responseStream, modelId, cancellationToken).ConfigureAwait(false)) { + activity?.AddStreamingContent(streamingTextContent); yield return streamingTextContent; } + + activity?.EndStreaming(); + httpResponseMessage.Dispose(); + responseStream.Dispose(); } private async IAsyncEnumerable ProcessTextResponseStreamAsync(Stream stream, string modelId, [EnumeratorCancellation] CancellationToken cancellationToken) diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs index 10b587788719..91199dcb40f1 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs @@ -85,18 +85,31 @@ internal sealed class HuggingFaceMessageApiClient var request = this.CreateChatRequest(chatHistory, executionSettings); request.Stream = true; - using var httpRequestMessage = this._clientCore.CreatePost(request, endpoint, this._clientCore.ApiKey); - - using var response = await this._clientCore.SendRequestAndGetResponseImmediatelyAfterHeadersReadAsync(httpRequestMessage, cancellationToken) - .ConfigureAwait(false); - - using var responseStream = await response.Content.ReadAsStreamAndTranslateExceptionAsync() - .ConfigureAwait(false); + using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this._clientCore.ModelProvider, chatHistory, executionSettings); + HttpResponseMessage httpResponseMessage; + Stream responseStream; + try + { + using var httpRequestMessage = this._clientCore.CreatePost(request, endpoint, this._clientCore.ApiKey); + // We cannot dispose these two objects leaving the try-catch block because we need them to read the response stream + httpResponseMessage = await this._clientCore.SendRequestAndGetResponseImmediatelyAfterHeadersReadAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); + responseStream = await httpResponseMessage.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false); + } + catch (Exception ex) + { + activity?.SetError(ex); + throw; + } await foreach (var streamingChatContent in this.ProcessChatResponseStreamAsync(responseStream, modelId, cancellationToken).ConfigureAwait(false)) { + activity?.AddStreamingContent(streamingChatContent); yield return streamingChatContent; } + + activity?.EndStreaming(); + httpResponseMessage.Dispose(); + responseStream.Dispose(); } internal async Task> CompleteChatMessageAsync( diff --git a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs index aa2bb962ae6e..2a9067b6ade4 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs @@ -119,13 +119,13 @@ internal ClientCore(ILogger? logger = null) /// /// Creates completions for the prompt and settings. /// - /// The prompt to complete. + /// The prompt to complete. /// Execution settings for the completion API. /// The containing services, plugins, and other state for use throughout the operation. /// The to monitor for cancellation requests. The default is . /// Completions generated by the remote model internal async Task> GetTextResultsAsync( - string text, + string prompt, PromptExecutionSettings? executionSettings, Kernel? kernel, CancellationToken cancellationToken = default) @@ -134,11 +134,11 @@ internal ClientCore(ILogger? logger = null) ValidateMaxTokens(textExecutionSettings.MaxTokens); - var options = CreateCompletionsOptions(text, textExecutionSettings, this.DeploymentOrModelName); + var options = CreateCompletionsOptions(prompt, textExecutionSettings, this.DeploymentOrModelName); Completions? responseData = null; List responseContent; - using (var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, text, executionSettings)) + using (var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, prompt, executionSettings)) { try { @@ -183,15 +183,30 @@ internal ClientCore(ILogger? logger = null) var options = CreateCompletionsOptions(prompt, textExecutionSettings, this.DeploymentOrModelName); - StreamingResponse? response = await RunRequestAsync(() => this.Client.GetCompletionsStreamingAsync(options, cancellationToken)).ConfigureAwait(false); + using var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, prompt, executionSettings); + + StreamingResponse? response; + try + { + response = await RunRequestAsync(() => this.Client.GetCompletionsStreamingAsync(options, cancellationToken)).ConfigureAwait(false); + } + catch (Exception ex) + { + activity?.SetError(ex); + throw; + } await foreach (Completions completions in response.ConfigureAwait(false)) { foreach (Choice choice in completions.Choices) { - yield return new OpenAIStreamingTextContent(choice.Text, choice.Index, this.DeploymentOrModelName, choice, GetTextChoiceMetadata(completions, choice)); + var openAIStreamingTextContent = new OpenAIStreamingTextContent(choice.Text, choice.Index, this.DeploymentOrModelName, choice, GetTextChoiceMetadata(completions, choice)); + activity?.AddStreamingContent(openAIStreamingTextContent); + yield return openAIStreamingTextContent; } } + + activity?.EndStreaming(); } private static Dictionary GetTextChoiceMetadata(Completions completions, Choice choice) @@ -613,9 +628,6 @@ static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory c for (int requestIndex = 1; ; requestIndex++) { - // Make the request. - var response = await RunRequestAsync(() => this.Client.GetChatCompletionsStreamingAsync(chatOptions, cancellationToken)).ConfigureAwait(false); - // Reset state contentBuilder?.Clear(); toolCallIdsByIndex?.Clear(); @@ -627,25 +639,45 @@ static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory c string? streamedName = null; ChatRole? streamedRole = default; CompletionsFinishReason finishReason = default; - await foreach (StreamingChatCompletionsUpdate update in response.ConfigureAwait(false)) + + // Make the request. + using (var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, chat, executionSettings)) { - metadata = GetResponseMetadata(update); - streamedRole ??= update.Role; - streamedName ??= update.AuthorName; - finishReason = update.FinishReason ?? default; + StreamingResponse? response; + try + { + response = await RunRequestAsync(() => this.Client.GetChatCompletionsStreamingAsync(chatOptions, cancellationToken)).ConfigureAwait(false); + } + catch (Exception ex) + { + activity?.SetError(ex); + throw; + } - // If we're intending to invoke function calls, we need to consume that function call information. - if (autoInvoke) + await foreach (StreamingChatCompletionsUpdate update in response.ConfigureAwait(false)) { - if (update.ContentUpdate is { Length: > 0 } contentUpdate) + metadata = GetResponseMetadata(update); + streamedRole ??= update.Role; + streamedName ??= update.AuthorName; + finishReason = update.FinishReason ?? default; + + // If we're intending to invoke function calls, we need to consume that function call information. + if (autoInvoke) { - (contentBuilder ??= new()).Append(contentUpdate); + if (update.ContentUpdate is { Length: > 0 } contentUpdate) + { + (contentBuilder ??= new()).Append(contentUpdate); + } + + OpenAIFunctionToolCall.TrackStreamingToolingUpdate(update.ToolCallUpdate, ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex); } - OpenAIFunctionToolCall.TrackStreamingToolingUpdate(update.ToolCallUpdate, ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex); + var openAIStreamingChatMessageContent = new OpenAIStreamingChatMessageContent(update, update.ChoiceIndex ?? 0, this.DeploymentOrModelName, metadata) { AuthorName = streamedName }; + activity?.AddStreamingContent(openAIStreamingChatMessageContent); + yield return openAIStreamingChatMessageContent; } - yield return new OpenAIStreamingChatMessageContent(update, update.ChoiceIndex ?? 0, this.DeploymentOrModelName, metadata) { AuthorName = streamedName }; + activity?.EndStreaming(); } // If we don't have a function to invoke, we're done. diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs b/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs index 6ae98bb6e8e6..18b935762024 100644 --- a/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs +++ b/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs @@ -26,6 +26,7 @@ internal static class ModelDiagnostics { private static readonly string s_namespace = typeof(ModelDiagnostics).Namespace!; private static readonly ActivitySource s_activitySource = new(s_namespace); + private static readonly ActivityListener s_activityListener = new(); private const string EnableDiagnosticsSwitch = "Microsoft.SemanticKernel.Experimental.GenAI.EnableOTelDiagnostics"; private const string EnableSensitiveEventsSwitch = "Microsoft.SemanticKernel.Experimental.GenAI.EnableOTelDiagnosticsSensitive"; @@ -35,6 +36,26 @@ internal static class ModelDiagnostics private static readonly bool s_enableDiagnostics = AppContextSwitchHelper.GetConfigValue(EnableDiagnosticsSwitch, EnableDiagnosticsEnvVar); private static readonly bool s_enableSensitiveEvents = AppContextSwitchHelper.GetConfigValue(EnableSensitiveEventsSwitch, EnableSensitiveEventsEnvVar); + /// + /// Stores streaming text content for activities of streaming completions. + /// + private static readonly Dictionary> s_streamingContents = []; + + static ModelDiagnostics() + { + s_activityListener.ShouldListenTo = activitySource => activitySource.Name == s_namespace; + s_activityListener.ActivityStopped = activity => + { + // Called when an activity is stopped. Clean up the streaming content in case `EndStreaming` is not called. + // This action needs to be idempotent as the event may be fired multiple times. + if (activity.Id is not null) + { + s_streamingContents.Remove(activity.Id); + } + }; + ActivitySource.AddActivityListener(s_activityListener); + } + /// /// Start a text completion activity for a given model. /// The activity will be tagged with the a set of attributes specified by the semantic conventions. @@ -63,6 +84,43 @@ public static void SetCompletionResponse(this Activity activity, IEnumerable completions, int? promptTokens = null, int? completionTokens = null) => SetCompletionResponse(activity, completions, promptTokens, completionTokens, ToOpenAIFormat); + /// + /// Add streaming content to the activity. + /// + /// The activity to add the streaming content + /// The streaming content + public static void AddStreamingContent(this Activity activity, T content) where T : StreamingKernelContent + { + if (IsModelDiagnosticsEnabled() && activity.Id is not null) + { + if (!s_streamingContents.TryGetValue(activity.Id, out var contents)) + { + contents = []; + s_streamingContents[activity.Id] = contents; + } + + contents.Add(content); + } + } + + /// + /// Notify the end of streaming for a given activity. + /// + public static void EndStreaming(this Activity activity, int? promptTokens = null, int? completionTokens = null) + { + if (activity.Id is not null && IsModelDiagnosticsEnabled()) + { + if (s_streamingContents.TryGetValue(activity.Id, out var contents)) + { + var choices = OrganizeStreamingContent(contents); + SetCompletionResponse(activity, choices, promptTokens, completionTokens); + } + + // Remove the streaming content after it's processed + s_streamingContents.Remove(activity.Id); + } + } + /// /// Set the response id for a given activity. /// @@ -87,7 +145,7 @@ public static void SetCompletionResponse(this Activity activity, IEnumerableThe activity with the completion token usage set for chaining public static Activity SetCompletionTokenUsage(this Activity activity, int completionTokens) => activity.SetTag(ModelDiagnosticsTags.CompletionToken, completionTokens); - # region Private + #region Private /// /// Check if model diagnostics is enabled /// Model diagnostics is enabled if either EnableModelDiagnostics or EnableSensitiveEvents is set to true and there are listeners. @@ -238,6 +296,44 @@ private static string ToOpenAIFormat(IEnumerable chatHistory } } + /// + /// Set the streaming completion response for a given activity. + /// + private static void SetCompletionResponse( + Activity activity, + Dictionary> choices, + int? promptTokens, + int? completionTokens) + { + if (!IsModelDiagnosticsEnabled()) + { + return; + } + + // Assuming all metadata is in the last chunk of the choice + switch (choices.FirstOrDefault().Value.FirstOrDefault()) + { + case StreamingTextContent: + var textCompletions = choices.Select(choiceContents => + { + var lastContent = (StreamingTextContent)choiceContents.Value.Last(); + var text = choiceContents.Value.Select(c => c.ToString()).Aggregate((a, b) => a + b); + return new TextContent(text, metadata: lastContent.Metadata); + }).ToList(); + SetCompletionResponse(activity, textCompletions, promptTokens, completionTokens, completions => $"[{string.Join(", ", completions)}"); + break; + case StreamingChatMessageContent: + var chatCompletions = choices.Select(choiceContents => + { + var lastContent = (StreamingChatMessageContent)choiceContents.Value.Last(); + var chatMessage = choiceContents.Value.Select(c => c.ToString()).Aggregate((a, b) => a + b); + return new ChatMessageContent(lastContent.Role ?? AuthorRole.Assistant, chatMessage, metadata: lastContent.Metadata); + }).ToList(); + SetCompletionResponse(activity, chatCompletions, promptTokens, completionTokens, ToOpenAIFormat); + break; + }; + } + // Returns an activity for chaining private static Activity SetFinishReasons(this Activity activity, IEnumerable completions) { @@ -270,6 +366,26 @@ private static Activity SetResponseId(this Activity activity, KernelContent? com return activity; } + /// + /// Organize streaming content by choice index + /// + private static Dictionary> OrganizeStreamingContent(IEnumerable contents) + { + Dictionary> choices = []; + foreach (var content in contents) + { + if (!choices.TryGetValue(content.ChoiceIndex, out var choiceContents)) + { + choiceContents = []; + choices[content.ChoiceIndex] = choiceContents; + } + + choiceContents.Add(content); + } + + return choices; + } + /// /// Tags used in model diagnostics /// From e4876ec05f0e9c1fa514f51bbb1e4f57e4101cf0 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 14 May 2024 14:08:11 -0700 Subject: [PATCH 02/11] Parse tool calls in chat message content --- .../Demos/TelemetryWithAppInsights/Program.cs | 124 +++++++++++++++--- .../src/Diagnostics/ModelDiagnostics.cs | 38 +++++- 2 files changed, 141 insertions(+), 21 deletions(-) diff --git a/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs b/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs index b85f35f84cb3..2e7ec44a037d 100644 --- a/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs +++ b/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs @@ -77,11 +77,26 @@ public static async Task Main() Console.WriteLine(); Console.WriteLine("Write a poem about John Doe and translate it to Italian."); - await RunAzureOpenAIChatAsync(kernel); + using (var _ = s_activitySource.StartActivity("Chat")) + { + await RunAzureOpenAIChatAsync(kernel); + Console.WriteLine(); + await RunGoogleAIChatAsync(kernel); + Console.WriteLine(); + await RunHuggingFaceChatAsync(kernel); + } + Console.WriteLine(); - await RunGoogleAIChatAsync(kernel); Console.WriteLine(); - await RunHuggingFaceChatAsync(kernel); + + Console.WriteLine("Get weather."); + using (var _ = s_activitySource.StartActivity("ToolCalls")) + { + await RunAzureOpenAIToolCallsAsync(kernel); + Console.WriteLine(); + await RunGoogleAIToolCallAsync(kernel); + // HuggingFace does not support tool calls yet. + } } #region Private @@ -99,16 +114,17 @@ public static async Task Main() /// private static readonly ActivitySource s_activitySource = new("Telemetry.Example"); - private const string AzureOpenAIChatServiceKey = "AzureOpenAIChat"; - private const string GoogleAIGeminiChatServiceKey = "GoogleAIGeminiChat"; - private const string HuggingFaceChatServiceKey = "HuggingFaceChat"; + private const string AzureOpenAIServiceKey = "AzureOpenAI"; + private const string GoogleAIGeminiServiceKey = "GoogleAIGemini"; + private const string HuggingFaceServiceKey = "HuggingFace"; + #region chat completion private static async Task RunAzureOpenAIChatAsync(Kernel kernel) { Console.WriteLine("============= Azure OpenAI Chat Completion ============="); - using var activity = s_activitySource.StartActivity(AzureOpenAIChatServiceKey); - SetTargetService(kernel, AzureOpenAIChatServiceKey); + using var activity = s_activitySource.StartActivity(AzureOpenAIServiceKey); + SetTargetService(kernel, AzureOpenAIServiceKey); try { await RunChatAsync(kernel); @@ -124,8 +140,8 @@ private static async Task RunGoogleAIChatAsync(Kernel kernel) { Console.WriteLine("============= Google Gemini Chat Completion ============="); - using var activity = s_activitySource.StartActivity(GoogleAIGeminiChatServiceKey); - SetTargetService(kernel, GoogleAIGeminiChatServiceKey); + using var activity = s_activitySource.StartActivity(GoogleAIGeminiServiceKey); + SetTargetService(kernel, GoogleAIGeminiServiceKey); try { @@ -142,8 +158,8 @@ private static async Task RunHuggingFaceChatAsync(Kernel kernel) { Console.WriteLine("============= HuggingFace Chat Completion ============="); - using var activity = s_activitySource.StartActivity(HuggingFaceChatServiceKey); - SetTargetService(kernel, HuggingFaceChatServiceKey); + using var activity = s_activitySource.StartActivity(HuggingFaceServiceKey); + SetTargetService(kernel, HuggingFaceServiceKey); try { @@ -163,7 +179,7 @@ private static async Task RunChatAsync(Kernel kernel) "WriterPlugin", "ShortPoem", new KernelArguments { ["input"] = "Write a poem about John Doe." }); - Console.WriteLine($"Poem:\n{poem}\n\n"); + Console.WriteLine($"Poem:\n{poem}\n"); // Use streaming to translate the poem. Console.WriteLine("Translated Poem:"); @@ -179,6 +195,50 @@ private static async Task RunChatAsync(Kernel kernel) Console.Write(update); } } + #endregion + + #region tool calls + private static async Task RunAzureOpenAIToolCallsAsync(Kernel kernel) + { + Console.WriteLine("============= Azure OpenAI ToolCalls ============="); + + using var activity = s_activitySource.StartActivity(AzureOpenAIServiceKey); + SetTargetService(kernel, AzureOpenAIServiceKey); + try + { + await RunAutoToolCallAsync(kernel); + } + catch (Exception ex) + { + activity?.SetStatus(ActivityStatusCode.Error, ex.Message); + Console.WriteLine($"Error: {ex.Message}"); + } + } + + private static async Task RunGoogleAIToolCallAsync(Kernel kernel) + { + Console.WriteLine("============= Google Gemini ToolCalls ============="); + + using var activity = s_activitySource.StartActivity(GoogleAIGeminiServiceKey); + SetTargetService(kernel, GoogleAIGeminiServiceKey); + try + { + await RunAutoToolCallAsync(kernel); + } + catch (Exception ex) + { + activity?.SetStatus(ActivityStatusCode.Error, ex.Message); + Console.WriteLine($"Error: {ex.Message}"); + } + } + + private static async Task RunAutoToolCallAsync(Kernel kernel) + { + var result = await kernel.InvokePromptAsync("What is the weather like in Seattle?"); + + Console.WriteLine(result); + } + #endregion private static Kernel GetKernel(ILoggerFactory loggerFactory) { @@ -193,19 +253,21 @@ private static Kernel GetKernel(ILoggerFactory loggerFactory) modelId: TestConfiguration.AzureOpenAI.ChatModelId, endpoint: TestConfiguration.AzureOpenAI.Endpoint, apiKey: TestConfiguration.AzureOpenAI.ApiKey, - serviceId: AzureOpenAIChatServiceKey) + serviceId: AzureOpenAIServiceKey) .AddGoogleAIGeminiChatCompletion( modelId: TestConfiguration.GoogleAI.Gemini.ModelId, apiKey: TestConfiguration.GoogleAI.ApiKey, - serviceId: GoogleAIGeminiChatServiceKey) + serviceId: GoogleAIGeminiServiceKey) .AddHuggingFaceChatCompletion( model: TestConfiguration.HuggingFace.ModelId, endpoint: new Uri("https://api-inference.huggingface.co"), apiKey: TestConfiguration.HuggingFace.ApiKey, - serviceId: HuggingFaceChatServiceKey); + serviceId: HuggingFaceServiceKey); builder.Services.AddSingleton(new AIServiceSelector()); builder.Plugins.AddFromPromptDirectory(Path.Combine(folder, "WriterPlugin")); + builder.Plugins.AddFromType(); + builder.Plugins.AddFromType(); return builder.Build(); } @@ -246,9 +308,17 @@ private sealed class AIServiceSelector : IAIServiceSelector service = targetService; serviceSettings = targetServiceKey switch { - AzureOpenAIChatServiceKey => new OpenAIPromptExecutionSettings(), - GoogleAIGeminiChatServiceKey => new GeminiPromptExecutionSettings(), - HuggingFaceChatServiceKey => new HuggingFacePromptExecutionSettings(), + AzureOpenAIServiceKey => new OpenAIPromptExecutionSettings() + { + Temperature = 0, + ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + }, + GoogleAIGeminiServiceKey => new GeminiPromptExecutionSettings() + { + Temperature = 0, + ToolCallBehavior = GeminiToolCallBehavior.AutoInvokeKernelFunctions + }, + HuggingFaceServiceKey => new HuggingFacePromptExecutionSettings(), _ => null, }; @@ -262,4 +332,20 @@ private sealed class AIServiceSelector : IAIServiceSelector } } #endregion + + #region Plugins + + public sealed class WeatherPlugin + { + [KernelFunction] + public string GetWeather(string location) => $"Weather in {location} is 70°F."; + } + + public sealed class LocationPlugin + { + [KernelFunction] + public string GetCurrentLocation() => "Seattle"; + } + + #endregion } diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs b/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs index 18b935762024..f5c4a2d6926d 100644 --- a/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs +++ b/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs @@ -194,9 +194,43 @@ private static string ToOpenAIFormat(IEnumerable chatHistory sb.Append("{\"role\": \""); sb.Append(message.Role); - sb.Append("\", \"content\": \""); + sb.Append("\", \"content\": "); sb.Append(JsonSerializer.Serialize(message.Content)); - sb.Append("\"}"); + sb.Append(", \"tool_calls\": "); + sb.Append(ToOpenAIFormat(message.Items)); + sb.Append('}'); + + isFirst = false; + } + sb.Append(']'); + + return sb.ToString(); + } + + /// + /// Convert tool calls to a string aligned with the OpenAI format + /// + private static string ToOpenAIFormat(ChatMessageContentItemCollection chatMessageContentItems) + { + var sb = new StringBuilder(); + sb.Append('['); + var isFirst = true; + foreach (var functionCall in chatMessageContentItems.OfType()) + { + if (!isFirst) + { + // Append a comma and a newline to separate the elements after the previous one. + // This can avoid adding an unnecessary comma after the last element. + sb.Append(", \n"); + } + + sb.Append("{\"id\": \""); + sb.Append(functionCall.Id); + sb.Append("\", \"function\": {\"arguments\": "); + sb.Append(JsonSerializer.Serialize(functionCall.Arguments)); + sb.Append(", \"name\": \""); + sb.Append(functionCall.FunctionName); + sb.Append("\"}, \"type\": \"function\"}"); isFirst = false; } From 3259664bf3650471e9d3143cbe2447dfe7a11bab Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 14 May 2024 14:42:30 -0700 Subject: [PATCH 03/11] Dispose resources in catch block --- .../Core/Gemini/Clients/GeminiChatCompletionClient.cs | 10 ++++++---- .../Connectors.HuggingFace/Core/HuggingFaceClient.cs | 10 ++++++---- .../Core/HuggingFaceMessageApiClient.cs | 10 ++++++---- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs index 79936276f0d6..6f5d7dbaa337 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs @@ -229,8 +229,8 @@ internal sealed class GeminiChatCompletionClient : ClientBase using (var activity = ModelDiagnostics.StartCompletionActivity( this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, executionSettings)) { - HttpResponseMessage httpResponseMessage; - Stream responseStream; + HttpResponseMessage? httpResponseMessage = null; + Stream? responseStream = null; try { using var httpRequestMessage = await this.CreateHttpRequestAsync(state.GeminiRequest, this._chatStreamingEndpoint).ConfigureAwait(false); @@ -241,6 +241,8 @@ internal sealed class GeminiChatCompletionClient : ClientBase catch (Exception ex) { activity?.SetError(ex); + httpResponseMessage?.Dispose(); + responseStream?.Dispose(); throw; } @@ -251,8 +253,8 @@ await foreach (var messageContent in this.GetStreamingChatMessageContentsOrPopul } activity?.EndStreaming(); - httpResponseMessage.Dispose(); - responseStream.Dispose(); + httpResponseMessage?.Dispose(); + responseStream?.Dispose(); } if (!state.AutoInvoke) diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs index e4187068ba2d..0bd70279134a 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs @@ -170,8 +170,8 @@ internal HttpRequestMessage CreatePost(object requestData, Uri endpoint, string? request.Stream = true; using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this.ModelProvider, prompt, executionSettings); - HttpResponseMessage httpResponseMessage; - Stream responseStream; + HttpResponseMessage? httpResponseMessage = null; + Stream? responseStream = null; try { using var httpRequestMessage = this.CreatePost(request, endpoint, this.ApiKey); @@ -182,6 +182,8 @@ internal HttpRequestMessage CreatePost(object requestData, Uri endpoint, string? catch (Exception ex) { activity?.SetError(ex); + httpResponseMessage?.Dispose(); + responseStream?.Dispose(); throw; } @@ -192,8 +194,8 @@ await foreach (var streamingTextContent in this.ProcessTextResponseStreamAsync(r } activity?.EndStreaming(); - httpResponseMessage.Dispose(); - responseStream.Dispose(); + httpResponseMessage?.Dispose(); + responseStream?.Dispose(); } private async IAsyncEnumerable ProcessTextResponseStreamAsync(Stream stream, string modelId, [EnumeratorCancellation] CancellationToken cancellationToken) diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs index 91199dcb40f1..5680e16e1d86 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs @@ -86,8 +86,8 @@ internal sealed class HuggingFaceMessageApiClient request.Stream = true; using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this._clientCore.ModelProvider, chatHistory, executionSettings); - HttpResponseMessage httpResponseMessage; - Stream responseStream; + HttpResponseMessage? httpResponseMessage = null; + Stream? responseStream = null; try { using var httpRequestMessage = this._clientCore.CreatePost(request, endpoint, this._clientCore.ApiKey); @@ -98,6 +98,8 @@ internal sealed class HuggingFaceMessageApiClient catch (Exception ex) { activity?.SetError(ex); + httpResponseMessage?.Dispose(); + responseStream?.Dispose(); throw; } @@ -108,8 +110,8 @@ await foreach (var streamingChatContent in this.ProcessChatResponseStreamAsync(r } activity?.EndStreaming(); - httpResponseMessage.Dispose(); - responseStream.Dispose(); + httpResponseMessage?.Dispose(); + responseStream?.Dispose(); } internal async Task> CompleteChatMessageAsync( From 9884b55ce4752e024f2df8ea678f8a84fc4c4749 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 14 May 2024 14:47:25 -0700 Subject: [PATCH 04/11] Dispose resources in a finally block --- .../Clients/GeminiChatCompletionClient.cs | 19 ++++++++++++------- .../Core/HuggingFaceClient.cs | 19 ++++++++++++------- .../Core/HuggingFaceMessageApiClient.cs | 19 ++++++++++++------- 3 files changed, 36 insertions(+), 21 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs index 6f5d7dbaa337..92f71301d0c6 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs @@ -246,15 +246,20 @@ internal sealed class GeminiChatCompletionClient : ClientBase throw; } - await foreach (var messageContent in this.GetStreamingChatMessageContentsOrPopulateStateForToolCallingAsync(state, responseStream, cancellationToken).ConfigureAwait(false)) + try { - activity?.AddStreamingContent(messageContent); - yield return messageContent; + await foreach (var messageContent in this.GetStreamingChatMessageContentsOrPopulateStateForToolCallingAsync(state, responseStream, cancellationToken).ConfigureAwait(false)) + { + activity?.AddStreamingContent(messageContent); + yield return messageContent; + } + } + finally + { + activity?.EndStreaming(); + httpResponseMessage?.Dispose(); + responseStream?.Dispose(); } - - activity?.EndStreaming(); - httpResponseMessage?.Dispose(); - responseStream?.Dispose(); } if (!state.AutoInvoke) diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs index 0bd70279134a..811985e554ec 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs @@ -187,15 +187,20 @@ internal HttpRequestMessage CreatePost(object requestData, Uri endpoint, string? throw; } - await foreach (var streamingTextContent in this.ProcessTextResponseStreamAsync(responseStream, modelId, cancellationToken).ConfigureAwait(false)) + try { - activity?.AddStreamingContent(streamingTextContent); - yield return streamingTextContent; + await foreach (var streamingTextContent in this.ProcessTextResponseStreamAsync(responseStream, modelId, cancellationToken).ConfigureAwait(false)) + { + activity?.AddStreamingContent(streamingTextContent); + yield return streamingTextContent; + } + } + finally + { + activity?.EndStreaming(); + httpResponseMessage?.Dispose(); + responseStream?.Dispose(); } - - activity?.EndStreaming(); - httpResponseMessage?.Dispose(); - responseStream?.Dispose(); } private async IAsyncEnumerable ProcessTextResponseStreamAsync(Stream stream, string modelId, [EnumeratorCancellation] CancellationToken cancellationToken) diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs index 5680e16e1d86..b26b31da4084 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs @@ -103,15 +103,20 @@ internal sealed class HuggingFaceMessageApiClient throw; } - await foreach (var streamingChatContent in this.ProcessChatResponseStreamAsync(responseStream, modelId, cancellationToken).ConfigureAwait(false)) + try { - activity?.AddStreamingContent(streamingChatContent); - yield return streamingChatContent; + await foreach (var streamingChatContent in this.ProcessChatResponseStreamAsync(responseStream, modelId, cancellationToken).ConfigureAwait(false)) + { + activity?.AddStreamingContent(streamingChatContent); + yield return streamingChatContent; + } + } + finally + { + activity?.EndStreaming(); + httpResponseMessage?.Dispose(); + responseStream?.Dispose(); } - - activity?.EndStreaming(); - httpResponseMessage?.Dispose(); - responseStream?.Dispose(); } internal async Task> CompleteChatMessageAsync( From 05b918f0177415c266f2862a7fe1d3dd87e15715 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 14 May 2024 15:05:46 -0700 Subject: [PATCH 05/11] Remove LocationPlugin --- dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs b/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs index 2e7ec44a037d..cc691040e236 100644 --- a/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs +++ b/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs @@ -267,7 +267,6 @@ private static Kernel GetKernel(ILoggerFactory loggerFactory) builder.Services.AddSingleton(new AIServiceSelector()); builder.Plugins.AddFromPromptDirectory(Path.Combine(folder, "WriterPlugin")); builder.Plugins.AddFromType(); - builder.Plugins.AddFromType(); return builder.Build(); } @@ -341,11 +340,5 @@ public sealed class WeatherPlugin public string GetWeather(string location) => $"Weather in {location} is 70°F."; } - public sealed class LocationPlugin - { - [KernelFunction] - public string GetCurrentLocation() => "Seattle"; - } - #endregion } From 6f277529a513d0f5aa987c2c1e6f4325503f245c Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 14 May 2024 15:48:52 -0700 Subject: [PATCH 06/11] Catch exceptions in enumerating the stream --- .../Clients/GeminiChatCompletionClient.cs | 21 ++++++++++++++++--- .../Core/HuggingFaceClient.cs | 21 ++++++++++++++++--- .../Core/HuggingFaceMessageApiClient.cs | 21 ++++++++++++++++--- 3 files changed, 54 insertions(+), 9 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs index 92f71301d0c6..f37252a9bb69 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs @@ -246,13 +246,28 @@ internal sealed class GeminiChatCompletionClient : ClientBase throw; } + var responseEnumerator = this.GetStreamingChatMessageContentsOrPopulateStateForToolCallingAsync(state, responseStream, cancellationToken) + .ConfigureAwait(false) + .GetAsyncEnumerator(); try { - await foreach (var messageContent in this.GetStreamingChatMessageContentsOrPopulateStateForToolCallingAsync(state, responseStream, cancellationToken).ConfigureAwait(false)) + while (true) { - activity?.AddStreamingContent(messageContent); - yield return messageContent; + try + { + if (!await responseEnumerator.MoveNextAsync()) + { + break; + } + } + catch (Exception ex) + { + activity?.SetError(ex); + throw; + } } + + yield return responseEnumerator.Current; } finally { diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs index 811985e554ec..ef23f6d16988 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs @@ -187,12 +187,27 @@ internal HttpRequestMessage CreatePost(object requestData, Uri endpoint, string? throw; } + var responseEnumerator = this.ProcessTextResponseStreamAsync(responseStream, modelId, cancellationToken) + .ConfigureAwait(false) + .GetAsyncEnumerator(); try { - await foreach (var streamingTextContent in this.ProcessTextResponseStreamAsync(responseStream, modelId, cancellationToken).ConfigureAwait(false)) + while (true) { - activity?.AddStreamingContent(streamingTextContent); - yield return streamingTextContent; + try + { + if (!await responseEnumerator.MoveNextAsync()) + { + break; + } + } + catch (Exception ex) + { + activity?.SetError(ex); + throw; + } + + yield return responseEnumerator.Current; } } finally diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs index b26b31da4084..fae47830f0c2 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs @@ -103,12 +103,27 @@ internal sealed class HuggingFaceMessageApiClient throw; } + var responseEnumerator = this.ProcessChatResponseStreamAsync(responseStream, modelId, cancellationToken) + .ConfigureAwait(false) + .GetAsyncEnumerator(); try { - await foreach (var streamingChatContent in this.ProcessChatResponseStreamAsync(responseStream, modelId, cancellationToken).ConfigureAwait(false)) + while (true) { - activity?.AddStreamingContent(streamingChatContent); - yield return streamingChatContent; + try + { + if (!await responseEnumerator.MoveNextAsync()) + { + break; + } + } + catch (Exception ex) + { + activity?.SetError(ex); + throw; + } + + yield return responseEnumerator.Current; } } finally From 0fc26a5bd5a7d7d479d86ce903de5f32fc706418 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 14 May 2024 16:08:51 -0700 Subject: [PATCH 07/11] Fix: add back AddStreamingContent --- .../Core/Gemini/Clients/GeminiChatCompletionClient.cs | 1 + .../Connectors.HuggingFace/Core/HuggingFaceClient.cs | 1 + .../Core/HuggingFaceMessageApiClient.cs | 1 + .../src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs | 4 ++-- 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs index f37252a9bb69..0e02e1dc045a 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs @@ -267,6 +267,7 @@ internal sealed class GeminiChatCompletionClient : ClientBase } } + activity?.AddStreamingContent(responseEnumerator.Current); yield return responseEnumerator.Current; } finally diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs index ef23f6d16988..94aa6897e22f 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs @@ -207,6 +207,7 @@ internal HttpRequestMessage CreatePost(object requestData, Uri endpoint, string? throw; } + activity?.AddStreamingContent(responseEnumerator.Current); yield return responseEnumerator.Current; } } diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs index fae47830f0c2..784db77ee4ff 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs @@ -123,6 +123,7 @@ internal sealed class HuggingFaceMessageApiClient throw; } + activity?.AddStreamingContent(responseEnumerator.Current); yield return responseEnumerator.Current; } } diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs b/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs index f5c4a2d6926d..8b8fe32a0b84 100644 --- a/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs +++ b/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs @@ -145,16 +145,16 @@ public static void EndStreaming(this Activity activity, int? promptTokens = null /// The activity with the completion token usage set for chaining public static Activity SetCompletionTokenUsage(this Activity activity, int completionTokens) => activity.SetTag(ModelDiagnosticsTags.CompletionToken, completionTokens); - #region Private /// /// Check if model diagnostics is enabled /// Model diagnostics is enabled if either EnableModelDiagnostics or EnableSensitiveEvents is set to true and there are listeners. /// - private static bool IsModelDiagnosticsEnabled() + public static bool IsModelDiagnosticsEnabled() { return (s_enableDiagnostics || s_enableSensitiveEvents) && s_activitySource.HasListeners(); } + #region Private private static void AddOptionalTags(Activity? activity, PromptExecutionSettings? executionSettings) { if (activity is null || executionSettings?.ExtensionData is null) From da1744a7356a69c513025de2f5d56f71a7f4c37e Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 14 May 2024 20:19:09 -0700 Subject: [PATCH 08/11] Remove unsafe dictionary in ModelDiagnostics --- .../Clients/GeminiChatCompletionClient.cs | 9 +- .../Core/HuggingFaceClient.cs | 5 +- .../Core/HuggingFaceMessageApiClient.cs | 5 +- .../Connectors.OpenAI/AzureSdk/ClientCore.cs | 91 ++++++++++++++----- .../src/Diagnostics/ModelDiagnostics.cs | 52 +---------- 5 files changed, 81 insertions(+), 81 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs index 0e02e1dc045a..433dfeb62acc 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs @@ -249,6 +249,7 @@ internal sealed class GeminiChatCompletionClient : ClientBase var responseEnumerator = this.GetStreamingChatMessageContentsOrPopulateStateForToolCallingAsync(state, responseStream, cancellationToken) .ConfigureAwait(false) .GetAsyncEnumerator(); + List streamedContents = []; try { while (true) @@ -265,14 +266,14 @@ internal sealed class GeminiChatCompletionClient : ClientBase activity?.SetError(ex); throw; } - } - activity?.AddStreamingContent(responseEnumerator.Current); - yield return responseEnumerator.Current; + streamedContents.Add(responseEnumerator.Current); + yield return responseEnumerator.Current; + } } finally { - activity?.EndStreaming(); + activity?.EndStreaming(streamedContents); httpResponseMessage?.Dispose(); responseStream?.Dispose(); } diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs index 94aa6897e22f..7900560ca457 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs @@ -190,6 +190,7 @@ internal HttpRequestMessage CreatePost(object requestData, Uri endpoint, string? var responseEnumerator = this.ProcessTextResponseStreamAsync(responseStream, modelId, cancellationToken) .ConfigureAwait(false) .GetAsyncEnumerator(); + List streamedContents = []; try { while (true) @@ -207,13 +208,13 @@ internal HttpRequestMessage CreatePost(object requestData, Uri endpoint, string? throw; } - activity?.AddStreamingContent(responseEnumerator.Current); + streamedContents.Add(responseEnumerator.Current); yield return responseEnumerator.Current; } } finally { - activity?.EndStreaming(); + activity?.EndStreaming(streamedContents); httpResponseMessage?.Dispose(); responseStream?.Dispose(); } diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs index 784db77ee4ff..3169c35741d3 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs @@ -106,6 +106,7 @@ internal sealed class HuggingFaceMessageApiClient var responseEnumerator = this.ProcessChatResponseStreamAsync(responseStream, modelId, cancellationToken) .ConfigureAwait(false) .GetAsyncEnumerator(); + List streamedContents = []; try { while (true) @@ -123,13 +124,13 @@ internal sealed class HuggingFaceMessageApiClient throw; } - activity?.AddStreamingContent(responseEnumerator.Current); + streamedContents.Add(responseEnumerator.Current); yield return responseEnumerator.Current; } } finally { - activity?.EndStreaming(); + activity?.EndStreaming(streamedContents); httpResponseMessage?.Dispose(); responseStream?.Dispose(); } diff --git a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs index 2a9067b6ade4..4c01c68873af 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs @@ -196,17 +196,39 @@ internal ClientCore(ILogger? logger = null) throw; } - await foreach (Completions completions in response.ConfigureAwait(false)) + var responseEnumerator = response.ConfigureAwait(false).GetAsyncEnumerator(); + List streamedContents = []; + try { - foreach (Choice choice in completions.Choices) + while (true) { - var openAIStreamingTextContent = new OpenAIStreamingTextContent(choice.Text, choice.Index, this.DeploymentOrModelName, choice, GetTextChoiceMetadata(completions, choice)); - activity?.AddStreamingContent(openAIStreamingTextContent); - yield return openAIStreamingTextContent; + try + { + if (!await responseEnumerator.MoveNextAsync()) + { + break; + } + } + catch (Exception ex) + { + activity?.SetError(ex); + throw; + } + + Completions completions = responseEnumerator.Current; + foreach (Choice choice in completions.Choices) + { + var openAIStreamingTextContent = new OpenAIStreamingTextContent( + choice.Text, choice.Index, this.DeploymentOrModelName, choice, GetTextChoiceMetadata(completions, choice)); + streamedContents.Add(openAIStreamingTextContent); + yield return openAIStreamingTextContent; + } } } - - activity?.EndStreaming(); + finally + { + activity?.EndStreaming(streamedContents); + } } private static Dictionary GetTextChoiceMetadata(Completions completions, Choice choice) @@ -654,30 +676,51 @@ static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory c throw; } - await foreach (StreamingChatCompletionsUpdate update in response.ConfigureAwait(false)) + var responseEnumerator = response.ConfigureAwait(false).GetAsyncEnumerator(); + List streamedContents = []; + try { - metadata = GetResponseMetadata(update); - streamedRole ??= update.Role; - streamedName ??= update.AuthorName; - finishReason = update.FinishReason ?? default; - - // If we're intending to invoke function calls, we need to consume that function call information. - if (autoInvoke) + while (true) { - if (update.ContentUpdate is { Length: > 0 } contentUpdate) + try { - (contentBuilder ??= new()).Append(contentUpdate); + if (!await responseEnumerator.MoveNextAsync()) + { + break; + } + } + catch (Exception ex) + { + activity?.SetError(ex); + throw; } - OpenAIFunctionToolCall.TrackStreamingToolingUpdate(update.ToolCallUpdate, ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex); - } + StreamingChatCompletionsUpdate update = responseEnumerator.Current; + metadata = GetResponseMetadata(update); + streamedRole ??= update.Role; + streamedName ??= update.AuthorName; + finishReason = update.FinishReason ?? default; - var openAIStreamingChatMessageContent = new OpenAIStreamingChatMessageContent(update, update.ChoiceIndex ?? 0, this.DeploymentOrModelName, metadata) { AuthorName = streamedName }; - activity?.AddStreamingContent(openAIStreamingChatMessageContent); - yield return openAIStreamingChatMessageContent; - } + // If we're intending to invoke function calls, we need to consume that function call information. + if (autoInvoke) + { + if (update.ContentUpdate is { Length: > 0 } contentUpdate) + { + (contentBuilder ??= new()).Append(contentUpdate); + } + + OpenAIFunctionToolCall.TrackStreamingToolingUpdate(update.ToolCallUpdate, ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex); + } - activity?.EndStreaming(); + var openAIStreamingChatMessageContent = new OpenAIStreamingChatMessageContent(update, update.ChoiceIndex ?? 0, this.DeploymentOrModelName, metadata) { AuthorName = streamedName }; + streamedContents.Add(openAIStreamingChatMessageContent); + yield return openAIStreamingChatMessageContent; + } + } + finally + { + activity?.EndStreaming(streamedContents); + } } // If we don't have a function to invoke, we're done. diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs b/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs index 8b8fe32a0b84..02e97abf261c 100644 --- a/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs +++ b/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs @@ -26,7 +26,6 @@ internal static class ModelDiagnostics { private static readonly string s_namespace = typeof(ModelDiagnostics).Namespace!; private static readonly ActivitySource s_activitySource = new(s_namespace); - private static readonly ActivityListener s_activityListener = new(); private const string EnableDiagnosticsSwitch = "Microsoft.SemanticKernel.Experimental.GenAI.EnableOTelDiagnostics"; private const string EnableSensitiveEventsSwitch = "Microsoft.SemanticKernel.Experimental.GenAI.EnableOTelDiagnosticsSensitive"; @@ -36,26 +35,6 @@ internal static class ModelDiagnostics private static readonly bool s_enableDiagnostics = AppContextSwitchHelper.GetConfigValue(EnableDiagnosticsSwitch, EnableDiagnosticsEnvVar); private static readonly bool s_enableSensitiveEvents = AppContextSwitchHelper.GetConfigValue(EnableSensitiveEventsSwitch, EnableSensitiveEventsEnvVar); - /// - /// Stores streaming text content for activities of streaming completions. - /// - private static readonly Dictionary> s_streamingContents = []; - - static ModelDiagnostics() - { - s_activityListener.ShouldListenTo = activitySource => activitySource.Name == s_namespace; - s_activityListener.ActivityStopped = activity => - { - // Called when an activity is stopped. Clean up the streaming content in case `EndStreaming` is not called. - // This action needs to be idempotent as the event may be fired multiple times. - if (activity.Id is not null) - { - s_streamingContents.Remove(activity.Id); - } - }; - ActivitySource.AddActivityListener(s_activityListener); - } - /// /// Start a text completion activity for a given model. /// The activity will be tagged with the a set of attributes specified by the semantic conventions. @@ -84,40 +63,15 @@ public static void SetCompletionResponse(this Activity activity, IEnumerable completions, int? promptTokens = null, int? completionTokens = null) => SetCompletionResponse(activity, completions, promptTokens, completionTokens, ToOpenAIFormat); - /// - /// Add streaming content to the activity. - /// - /// The activity to add the streaming content - /// The streaming content - public static void AddStreamingContent(this Activity activity, T content) where T : StreamingKernelContent - { - if (IsModelDiagnosticsEnabled() && activity.Id is not null) - { - if (!s_streamingContents.TryGetValue(activity.Id, out var contents)) - { - contents = []; - s_streamingContents[activity.Id] = contents; - } - - contents.Add(content); - } - } - /// /// Notify the end of streaming for a given activity. /// - public static void EndStreaming(this Activity activity, int? promptTokens = null, int? completionTokens = null) + public static void EndStreaming(this Activity activity, IEnumerable contents, int? promptTokens = null, int? completionTokens = null) { if (activity.Id is not null && IsModelDiagnosticsEnabled()) { - if (s_streamingContents.TryGetValue(activity.Id, out var contents)) - { - var choices = OrganizeStreamingContent(contents); - SetCompletionResponse(activity, choices, promptTokens, completionTokens); - } - - // Remove the streaming content after it's processed - s_streamingContents.Remove(activity.Id); + var choices = OrganizeStreamingContent(contents); + SetCompletionResponse(activity, choices, promptTokens, completionTokens); } } From 416600f7ab80149c282577c007892baa8e024eac Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 14 May 2024 20:28:03 -0700 Subject: [PATCH 09/11] small fixes --- .../src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs | 6 +++--- .../InternalUtilities/src/Diagnostics/ModelDiagnostics.cs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs index 4c01c68873af..e1799afd4669 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs @@ -185,7 +185,7 @@ internal ClientCore(ILogger? logger = null) using var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, prompt, executionSettings); - StreamingResponse? response; + StreamingResponse response; try { response = await RunRequestAsync(() => this.Client.GetCompletionsStreamingAsync(options, cancellationToken)).ConfigureAwait(false); @@ -662,10 +662,10 @@ static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory c ChatRole? streamedRole = default; CompletionsFinishReason finishReason = default; - // Make the request. using (var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, chat, executionSettings)) { - StreamingResponse? response; + // Make the request. + StreamingResponse response; try { response = await RunRequestAsync(() => this.Client.GetChatCompletionsStreamingAsync(chatOptions, cancellationToken)).ConfigureAwait(false); diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs b/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs index 02e97abf261c..2ff5c9c1250b 100644 --- a/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs +++ b/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs @@ -68,7 +68,7 @@ public static void SetCompletionResponse(this Activity activity, IEnumerable public static void EndStreaming(this Activity activity, IEnumerable contents, int? promptTokens = null, int? completionTokens = null) { - if (activity.Id is not null && IsModelDiagnosticsEnabled()) + if (IsModelDiagnosticsEnabled()) { var choices = OrganizeStreamingContent(contents); SetCompletionResponse(activity, choices, promptTokens, completionTokens); From 5beba9f8de5481e2d78db1fb33edab9c22129dea Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 15 May 2024 11:38:19 -0700 Subject: [PATCH 10/11] Addressed comments --- .../Demos/TelemetryWithAppInsights/Program.cs | 37 ++++++------------- .../TelemetryWithAppInsights.csproj | 2 +- .../Clients/GeminiChatCompletionClient.cs | 11 +++--- .../Core/HuggingFaceClient.cs | 11 +++--- .../Core/HuggingFaceMessageApiClient.cs | 11 +++--- .../Connectors.OpenAI/AzureSdk/ClientCore.cs | 8 ++-- .../src/Diagnostics/ModelDiagnostics.cs | 18 +++++---- 7 files changed, 42 insertions(+), 56 deletions(-) diff --git a/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs b/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs index cc691040e236..dc1009bb74b3 100644 --- a/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs +++ b/dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs @@ -94,8 +94,6 @@ public static async Task Main() { await RunAzureOpenAIToolCallsAsync(kernel); Console.WriteLine(); - await RunGoogleAIToolCallAsync(kernel); - // HuggingFace does not support tool calls yet. } } @@ -215,26 +213,9 @@ private static async Task RunAzureOpenAIToolCallsAsync(Kernel kernel) } } - private static async Task RunGoogleAIToolCallAsync(Kernel kernel) - { - Console.WriteLine("============= Google Gemini ToolCalls ============="); - - using var activity = s_activitySource.StartActivity(GoogleAIGeminiServiceKey); - SetTargetService(kernel, GoogleAIGeminiServiceKey); - try - { - await RunAutoToolCallAsync(kernel); - } - catch (Exception ex) - { - activity?.SetStatus(ActivityStatusCode.Error, ex.Message); - Console.WriteLine($"Error: {ex.Message}"); - } - } - private static async Task RunAutoToolCallAsync(Kernel kernel) { - var result = await kernel.InvokePromptAsync("What is the weather like in Seattle?"); + var result = await kernel.InvokePromptAsync("What is the weather like in my location?"); Console.WriteLine(result); } @@ -267,6 +248,7 @@ private static Kernel GetKernel(ILoggerFactory loggerFactory) builder.Services.AddSingleton(new AIServiceSelector()); builder.Plugins.AddFromPromptDirectory(Path.Combine(folder, "WriterPlugin")); builder.Plugins.AddFromType(); + builder.Plugins.AddFromType(); return builder.Build(); } @@ -312,11 +294,7 @@ private sealed class AIServiceSelector : IAIServiceSelector Temperature = 0, ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }, - GoogleAIGeminiServiceKey => new GeminiPromptExecutionSettings() - { - Temperature = 0, - ToolCallBehavior = GeminiToolCallBehavior.AutoInvokeKernelFunctions - }, + GoogleAIGeminiServiceKey => new GeminiPromptExecutionSettings(), HuggingFaceServiceKey => new HuggingFacePromptExecutionSettings(), _ => null, }; @@ -340,5 +318,14 @@ public sealed class WeatherPlugin public string GetWeather(string location) => $"Weather in {location} is 70°F."; } + public sealed class LocationPlugin + { + [KernelFunction] + public string GetCurrentLocation() + { + return "Seattle"; + } + } + #endregion } diff --git a/dotnet/samples/Demos/TelemetryWithAppInsights/TelemetryWithAppInsights.csproj b/dotnet/samples/Demos/TelemetryWithAppInsights/TelemetryWithAppInsights.csproj index 713b4043f3f3..26775e3a2402 100644 --- a/dotnet/samples/Demos/TelemetryWithAppInsights/TelemetryWithAppInsights.csproj +++ b/dotnet/samples/Demos/TelemetryWithAppInsights/TelemetryWithAppInsights.csproj @@ -7,7 +7,7 @@ disable false - $(NoWarn);CA1050;CA1707;CA2007;CS1591;VSTHRD111,SKEXP0050,SKEXP0060,SKEXP0070 + $(NoWarn);CA1024;CA1050;CA1707;CA2007;CS1591;VSTHRD111,SKEXP0050,SKEXP0060,SKEXP0070 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs index 433dfeb62acc..79b9089da5cb 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs @@ -234,7 +234,6 @@ internal sealed class GeminiChatCompletionClient : ClientBase try { using var httpRequestMessage = await this.CreateHttpRequestAsync(state.GeminiRequest, this._chatStreamingEndpoint).ConfigureAwait(false); - // We cannot dispose these two objects leaving the try-catch block because we need them to read the response stream httpResponseMessage = await this.SendRequestAndGetResponseImmediatelyAfterHeadersReadAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); responseStream = await httpResponseMessage.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false); } @@ -247,16 +246,15 @@ internal sealed class GeminiChatCompletionClient : ClientBase } var responseEnumerator = this.GetStreamingChatMessageContentsOrPopulateStateForToolCallingAsync(state, responseStream, cancellationToken) - .ConfigureAwait(false) - .GetAsyncEnumerator(); - List streamedContents = []; + .GetAsyncEnumerator(cancellationToken); + List? streamedContents = activity is not null ? [] : null; try { while (true) { try { - if (!await responseEnumerator.MoveNextAsync()) + if (!await responseEnumerator.MoveNextAsync().ConfigureAwait(false)) { break; } @@ -267,7 +265,7 @@ internal sealed class GeminiChatCompletionClient : ClientBase throw; } - streamedContents.Add(responseEnumerator.Current); + streamedContents?.Add(responseEnumerator.Current); yield return responseEnumerator.Current; } } @@ -276,6 +274,7 @@ internal sealed class GeminiChatCompletionClient : ClientBase activity?.EndStreaming(streamedContents); httpResponseMessage?.Dispose(); responseStream?.Dispose(); + await responseEnumerator.DisposeAsync().ConfigureAwait(false); } } diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs index 7900560ca457..a6c095738f1b 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs @@ -175,7 +175,6 @@ internal HttpRequestMessage CreatePost(object requestData, Uri endpoint, string? try { using var httpRequestMessage = this.CreatePost(request, endpoint, this.ApiKey); - // We cannot dispose these two objects leaving the try-catch block because we need them to read the response stream httpResponseMessage = await this.SendRequestAndGetResponseImmediatelyAfterHeadersReadAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); responseStream = await httpResponseMessage.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false); } @@ -188,16 +187,15 @@ internal HttpRequestMessage CreatePost(object requestData, Uri endpoint, string? } var responseEnumerator = this.ProcessTextResponseStreamAsync(responseStream, modelId, cancellationToken) - .ConfigureAwait(false) - .GetAsyncEnumerator(); - List streamedContents = []; + .GetAsyncEnumerator(cancellationToken); + List? streamedContents = activity is not null ? [] : null; try { while (true) { try { - if (!await responseEnumerator.MoveNextAsync()) + if (!await responseEnumerator.MoveNextAsync().ConfigureAwait(false)) { break; } @@ -208,7 +206,7 @@ internal HttpRequestMessage CreatePost(object requestData, Uri endpoint, string? throw; } - streamedContents.Add(responseEnumerator.Current); + streamedContents?.Add(responseEnumerator.Current); yield return responseEnumerator.Current; } } @@ -217,6 +215,7 @@ internal HttpRequestMessage CreatePost(object requestData, Uri endpoint, string? activity?.EndStreaming(streamedContents); httpResponseMessage?.Dispose(); responseStream?.Dispose(); + await responseEnumerator.DisposeAsync().ConfigureAwait(false); } } diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs index 3169c35741d3..7ae142fb9cdd 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs @@ -91,7 +91,6 @@ internal sealed class HuggingFaceMessageApiClient try { using var httpRequestMessage = this._clientCore.CreatePost(request, endpoint, this._clientCore.ApiKey); - // We cannot dispose these two objects leaving the try-catch block because we need them to read the response stream httpResponseMessage = await this._clientCore.SendRequestAndGetResponseImmediatelyAfterHeadersReadAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); responseStream = await httpResponseMessage.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false); } @@ -104,16 +103,15 @@ internal sealed class HuggingFaceMessageApiClient } var responseEnumerator = this.ProcessChatResponseStreamAsync(responseStream, modelId, cancellationToken) - .ConfigureAwait(false) - .GetAsyncEnumerator(); - List streamedContents = []; + .GetAsyncEnumerator(cancellationToken); + List? streamedContents = activity is not null ? [] : null; try { while (true) { try { - if (!await responseEnumerator.MoveNextAsync()) + if (!await responseEnumerator.MoveNextAsync().ConfigureAwait(false)) { break; } @@ -124,7 +122,7 @@ internal sealed class HuggingFaceMessageApiClient throw; } - streamedContents.Add(responseEnumerator.Current); + streamedContents?.Add(responseEnumerator.Current); yield return responseEnumerator.Current; } } @@ -133,6 +131,7 @@ internal sealed class HuggingFaceMessageApiClient activity?.EndStreaming(streamedContents); httpResponseMessage?.Dispose(); responseStream?.Dispose(); + await responseEnumerator.DisposeAsync().ConfigureAwait(false); } } diff --git a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs index e1799afd4669..0a2cf1f43ba9 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs @@ -197,7 +197,7 @@ internal ClientCore(ILogger? logger = null) } var responseEnumerator = response.ConfigureAwait(false).GetAsyncEnumerator(); - List streamedContents = []; + List? streamedContents = activity is not null ? [] : null; try { while (true) @@ -220,7 +220,7 @@ internal ClientCore(ILogger? logger = null) { var openAIStreamingTextContent = new OpenAIStreamingTextContent( choice.Text, choice.Index, this.DeploymentOrModelName, choice, GetTextChoiceMetadata(completions, choice)); - streamedContents.Add(openAIStreamingTextContent); + streamedContents?.Add(openAIStreamingTextContent); yield return openAIStreamingTextContent; } } @@ -677,7 +677,7 @@ static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory c } var responseEnumerator = response.ConfigureAwait(false).GetAsyncEnumerator(); - List streamedContents = []; + List? streamedContents = activity is not null ? [] : null; try { while (true) @@ -713,7 +713,7 @@ static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory c } var openAIStreamingChatMessageContent = new OpenAIStreamingChatMessageContent(update, update.ChoiceIndex ?? 0, this.DeploymentOrModelName, metadata) { AuthorName = streamedName }; - streamedContents.Add(openAIStreamingChatMessageContent); + streamedContents?.Add(openAIStreamingChatMessageContent); yield return openAIStreamingChatMessageContent; } } diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs b/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs index 2ff5c9c1250b..5522e0f73330 100644 --- a/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs +++ b/dotnet/src/InternalUtilities/src/Diagnostics/ModelDiagnostics.cs @@ -66,7 +66,7 @@ public static void SetCompletionResponse(this Activity activity, IEnumerable /// Notify the end of streaming for a given activity. /// - public static void EndStreaming(this Activity activity, IEnumerable contents, int? promptTokens = null, int? completionTokens = null) + public static void EndStreaming(this Activity activity, IEnumerable? contents, int? promptTokens = null, int? completionTokens = null) { if (IsModelDiagnosticsEnabled()) { @@ -151,7 +151,7 @@ private static string ToOpenAIFormat(IEnumerable chatHistory sb.Append("\", \"content\": "); sb.Append(JsonSerializer.Serialize(message.Content)); sb.Append(", \"tool_calls\": "); - sb.Append(ToOpenAIFormat(message.Items)); + ToOpenAIFormat(sb, message.Items); sb.Append('}'); isFirst = false; @@ -162,11 +162,10 @@ private static string ToOpenAIFormat(IEnumerable chatHistory } /// - /// Convert tool calls to a string aligned with the OpenAI format + /// Helper method to convert tool calls to a string aligned with the OpenAI format /// - private static string ToOpenAIFormat(ChatMessageContentItemCollection chatMessageContentItems) + private static void ToOpenAIFormat(StringBuilder sb, ChatMessageContentItemCollection chatMessageContentItems) { - var sb = new StringBuilder(); sb.Append('['); var isFirst = true; foreach (var functionCall in chatMessageContentItems.OfType()) @@ -189,8 +188,6 @@ private static string ToOpenAIFormat(ChatMessageContentItemCollection chatMessag isFirst = false; } sb.Append(']'); - - return sb.ToString(); } /// @@ -357,9 +354,14 @@ private static Activity SetResponseId(this Activity activity, KernelContent? com /// /// Organize streaming content by choice index /// - private static Dictionary> OrganizeStreamingContent(IEnumerable contents) + private static Dictionary> OrganizeStreamingContent(IEnumerable? contents) { Dictionary> choices = []; + if (contents is null) + { + return choices; + } + foreach (var content in contents) { if (!choices.TryGetValue(content.ChoiceIndex, out var choiceContents)) From 85dc0f0e003b4a63a822c252aaa37406e933b9b2 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 15 May 2024 11:43:44 -0700 Subject: [PATCH 11/11] Address comments 2 --- dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs index 0a2cf1f43ba9..fac60f53903e 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs @@ -228,6 +228,7 @@ internal ClientCore(ILogger? logger = null) finally { activity?.EndStreaming(streamedContents); + await responseEnumerator.DisposeAsync(); } } @@ -720,6 +721,7 @@ static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory c finally { activity?.EndStreaming(streamedContents); + await responseEnumerator.DisposeAsync(); } }