.Net Issue Fix - Add FunctionResponse to ChatStreaming interfaces (#3246)
### Motivation and Context

Resolves #3198 

Today we can't call functions when streaming; this change makes that possible.

The current approach requires the streamed response to be buffered while
listening for a potential function call.

The example below shows how that is achieved:
```csharp
StringBuilder chatContent = new();

// Regular (non-function) content streaming happens here.
await foreach (var message in chatResult.GetStreamingChatMessageAsync())
{
    if (message.Content is not null)
    {
        Console.Write(message.Content);
        chatContent.Append(message.Content);
    }
}
chatHistory.AddAssistantMessage(chatContent.ToString());

// After the whole stream has been consumed, check whether a function response was present.
var functionResponse = await chatResult.GetOpenAIStreamingFunctionResponseAsync();
```

Remark: Calling `GetOpenAIStreamingFunctionResponseAsync()` buffers the
stream to capture the full function call definition, so call it only after
all of the streaming messages have been flushed out.
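
For reference, here is a minimal sketch of how the buffered response can be
consumed once the content stream has been drained, assuming the function
returned by the model is an SKFunction already registered with the kernel
(the updated sample in this change follows the same pattern):

```csharp
// Sketch only: assumes `chatResult`, `kernel` and `chatHistory` are set up as in the snippet above.
var functionResponse = await chatResult.GetOpenAIStreamingFunctionResponseAsync();
if (functionResponse is not null &&
    kernel.Functions.TryGetFunctionAndContext(functionResponse, out ISKFunction? function, out ContextVariables? context))
{
    // Invoke the function requested by the model and feed its result back into the chat history.
    var kernelResult = await kernel.RunAsync(function, context);
    chatHistory.AddAssistantMessage(kernelResult.GetValue<object>()?.ToString() ?? string.Empty);
}
```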

### Description

Allows the streaming interfaces to expose the function call result.

This is a simplified approach: the streamed function call request is
buffered until it is complete, using the `FunctionCall` data carried by the
streaming messages.
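
Conceptually, the buffering works as in the sketch below: the first streamed
message that carries a `FunctionCall` supplies the function name, and the
argument fragments from subsequent deltas are concatenated until the stream
ends. This mirrors the new `GetOpenAIStreamingFunctionResponseAsync`
extension added in this change; the local variable names are illustrative.

```csharp
// Sketch of the buffering performed by the new extension method (variable names are illustrative).
StringBuilder bufferedArguments = new();
FunctionCall? functionCall = null;

await foreach (SKChatMessage message in chatStreamingResult.GetStreamingChatMessageAsync())
{
    // The first delta that carries a function call provides the name; later deltas add argument fragments.
    functionCall ??= message.FunctionCall;
    bufferedArguments.Append(message.FunctionCall.Arguments);
}

OpenAIFunctionResponse? response = null;
if (functionCall is not null)
{
    // Stitch the buffered argument fragments back together before parsing the response.
    functionCall.Arguments = bufferedArguments.ToString();
    response = OpenAIFunctionResponse.FromFunctionCall(functionCall);
}
```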

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
RogerBarreto committed Nov 1, 2023
1 parent f3df736 commit 831ff8e
Showing 7 changed files with 192 additions and 6 deletions.
@@ -2,6 +2,7 @@

using System;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI.ChatCompletion;
@@ -36,10 +37,12 @@ public static async Task RunAsync()
// Set FunctionCall to the name of a specific function to force the model to use that function.
requestSettings.FunctionCall = "TimePlugin-Date";
await CompleteChatWithFunctionsAsync("What day is today?", chatHistory, chatCompletion, kernel, requestSettings);
await StreamingCompleteChatWithFunctionsAsync("What day is today?", chatHistory, chatCompletion, kernel, requestSettings);

// Set FunctionCall to auto to let the model choose the best function to use.
requestSettings.FunctionCall = OpenAIRequestSettings.FunctionCallAuto;
await CompleteChatWithFunctionsAsync("What computer tablets are available for under $200?", chatHistory, chatCompletion, kernel, requestSettings);
await StreamingCompleteChatWithFunctionsAsync("What computer tablets are available for under $200?", chatHistory, chatCompletion, kernel, requestSettings);
}

private static async Task<IKernel> InitializeKernelAsync()
@@ -77,7 +80,7 @@ private static async Task CompleteChatWithFunctionsAsync(string ask, ChatHistory
}

// Check for function response
OpenAIFunctionResponse? functionResponse = chatResult.GetFunctionResponse();
OpenAIFunctionResponse? functionResponse = chatResult.GetOpenAIFunctionResponse();
if (functionResponse is not null)
{
// Print function response details
@@ -121,4 +124,70 @@ private static async Task CompleteChatWithFunctionsAsync(string ask, ChatHistory
}
}
}

private static async Task StreamingCompleteChatWithFunctionsAsync(string ask, ChatHistory chatHistory, IChatCompletion chatCompletion, IKernel kernel, OpenAIRequestSettings requestSettings)
{
Console.WriteLine($"User message: {ask}");
chatHistory.AddUserMessage(ask);

// Send request
await foreach (var chatResult in chatCompletion.GetStreamingChatCompletionsAsync(chatHistory, requestSettings))
{
StringBuilder chatContent = new();
await foreach (var message in chatResult.GetStreamingChatMessageAsync())
{
if (message.Content is not null)
{
Console.Write(message.Content);
chatContent.Append(message.Content);
}
}
chatHistory.AddAssistantMessage(chatContent.ToString());

var functionResponse = await chatResult.GetOpenAIStreamingFunctionResponseAsync();

if (functionResponse is not null)
{
// Print function response details
Console.WriteLine("Function name: " + functionResponse.FunctionName);
Console.WriteLine("Plugin name: " + functionResponse.PluginName);
Console.WriteLine("Arguments: ");
foreach (var parameter in functionResponse.Parameters)
{
Console.WriteLine($"- {parameter.Key}: {parameter.Value}");
}

// If the function returned by OpenAI is an SKFunction registered with the kernel,
// you can invoke it using the following code.
if (kernel.Functions.TryGetFunctionAndContext(functionResponse, out ISKFunction? func, out ContextVariables? context))
{
var kernelResult = await kernel.RunAsync(func, context);

var result = kernelResult.GetValue<object>();

string? resultMessage = null;
if (result is RestApiOperationResponse apiResponse)
{
resultMessage = apiResponse.Content?.ToString();
}
else if (result is string str)
{
resultMessage = str;
}

if (!string.IsNullOrEmpty(resultMessage))
{
Console.WriteLine(resultMessage);

// Add the function result to chat history
chatHistory.AddAssistantMessage(resultMessage);
}
}
else
{
Console.WriteLine($"Error: Function {functionResponse.PluginName}.{functionResponse.FunctionName} not found.");
}
}
}
}
}
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using Microsoft.SemanticKernel.AI.ChatCompletion;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk;
@@ -14,7 +15,18 @@ public static class ChatResultExtensions
/// </summary>
/// <param name="chatResult"></param>
/// <returns>The <see cref="OpenAIFunctionResponse"/>, or null if no function was returned by the model.</returns>
[Obsolete("Obsoleted, please use GetOpenAIFunctionResponse instead")]
public static OpenAIFunctionResponse? GetFunctionResponse(this IChatResult chatResult)
{
return GetOpenAIFunctionResponse(chatResult);
}

/// <summary>
/// Retrieve the resulting function from the chat result.
/// </summary>
/// <param name="chatResult"></param>
/// <returns>The <see cref="OpenAIFunctionResponse"/>, or null if no function was returned by the model.</returns>
public static OpenAIFunctionResponse? GetOpenAIFunctionResponse(this IChatResult chatResult)
{
OpenAIFunctionResponse? functionResponse = null;
var functionCall = chatResult.ModelResult.GetResult<ChatModelResult>().Choice.Message.FunctionCall;
@@ -0,0 +1,41 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using Azure.AI.OpenAI;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk;

/// <summary> Represents a singular result of a chat completion.</summary>
public class ChatStreamingModelResult
{
/// <summary> A unique identifier associated with this chat completion response. </summary>
public string Id { get; }

/// <summary>
/// The first timestamp associated with generation activity for this completions response,
/// represented as seconds since the beginning of the Unix epoch of 00:00 on 1 Jan 1970.
/// </summary>
public DateTimeOffset Created { get; }

/// <summary>
/// Content filtering results for zero or more prompts in the request.
/// </summary>
public IReadOnlyList<PromptFilterResult> PromptFilterResults { get; }

/// <summary>
/// The completion choice associated with this completion result.
/// </summary>
public StreamingChatChoice Choice { get; }

/// <summary> Initializes a new instance of TextModelResult. </summary>
/// <param name="completionsData"> A completions response object to populate the fields relative the response.</param>
/// <param name="choiceData"> A choice object to populate the fields relative to the resulting choice.</param>
internal ChatStreamingModelResult(StreamingChatCompletions completionsData, StreamingChatChoice choiceData)
{
this.Id = completionsData.Id;
this.Created = completionsData.Created;
this.PromptFilterResults = completionsData.PromptFilterResults;
this.Choice = choiceData;
}
}
@@ -15,18 +15,16 @@ namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk;

internal sealed class ChatStreamingResult : IChatStreamingResult, ITextStreamingResult, IChatResult, ITextResult
{
private readonly ModelResult _modelResult;
private readonly StreamingChatChoice _choice;
public ModelResult ModelResult { get; }

public ChatStreamingResult(StreamingChatCompletions resultData, StreamingChatChoice choice)
{
Verify.NotNull(choice);
this._modelResult = new ModelResult(resultData);
this.ModelResult = new(new ChatStreamingModelResult(resultData, choice));
this._choice = choice;
}

public ModelResult ModelResult => this._modelResult;

/// <inheritdoc/>
public async Task<ChatMessageBase> GetChatMessageAsync(CancellationToken cancellationToken = default)
{
@@ -47,7 +45,7 @@ public async IAsyncEnumerable<ChatMessageBase> GetStreamingChatMessageAsync([Enu
{
await foreach (var message in this._choice.GetMessageStreaming(cancellationToken))
{
if (message.Content is { Length: > 0 })
if (message.FunctionCall is not null || message.Content is { Length: > 0 })
{
yield return new SKChatMessage(message);
}
@@ -0,0 +1,45 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Text;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Microsoft.SemanticKernel.AI.ChatCompletion;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk;

/// <summary>
/// Provides extension methods for the IChatStreamingResult interface.
/// </summary>
public static class ChatStreamingResultExtensions
{
/// <summary>
/// Retrieve the resulting function from the chat result.
/// </summary>
/// <param name="chatStreamingResult">Chat streaming result</param>
/// <returns>The <see cref="OpenAIFunctionResponse"/>, or null if no function was returned by the model.</returns>
public static async Task<OpenAIFunctionResponse?> GetOpenAIStreamingFunctionResponseAsync(this IChatStreamingResult chatStreamingResult)
{
if (chatStreamingResult is not ChatStreamingResult)
{
throw new NotSupportedException($"Chat streaming result is not OpenAI {nameof(ChatStreamingResult)} supported type");
}

StringBuilder arguments = new();
FunctionCall? functionCall = null;
await foreach (SKChatMessage message in chatStreamingResult.GetStreamingChatMessageAsync())
{
functionCall ??= message.FunctionCall;

arguments.Append(message.FunctionCall.Arguments);
}

if (functionCall is null)
{
return null;
}

functionCall.Arguments = arguments.ToString();
return OpenAIFunctionResponse.FromFunctionCall(functionCall);
}
}
@@ -1,5 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using Azure.AI.OpenAI;
using Microsoft.SemanticKernel.AI.ChatCompletion;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk;
@@ -9,13 +11,16 @@ namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.AzureSdk;
/// </summary>
public class SKChatMessage : ChatMessageBase
{
private readonly ChatMessage? _message;

/// <summary>
/// Initializes a new instance of the <see cref="SKChatMessage"/> class.
/// </summary>
/// <param name="message">OpenAI SDK chat message representation</param>
public SKChatMessage(Azure.AI.OpenAI.ChatMessage message)
: base(new AuthorRole(message.Role.ToString()), message.Content)
{
this._message = message;
}

/// <summary>
@@ -27,4 +32,10 @@ public SKChatMessage(string role, string content)
: base(new AuthorRole(role), content)
{
}

/// <summary>
/// Exposes the underlying OpenAI SDK function call chat message representation
/// </summary>
public FunctionCall FunctionCall
=> this._message?.FunctionCall ?? throw new NotSupportedException("Function call is not supported");
}
@@ -32,4 +32,14 @@ public static ChatModelResult GetOpenAIChatResult(this ModelResult resultBase)
{
return resultBase.GetResult<ChatModelResult>();
}

/// <summary>
/// Retrieves a typed <see cref="ChatStreamingModelResult"/> OpenAI / AzureOpenAI result from chat completion prompt.
/// </summary>
/// <param name="resultBase">Current context</param>
/// <returns>OpenAI / AzureOpenAI result<see cref="ChatCompletions"/></returns>
public static ChatStreamingModelResult GetOpenAIChatStreamingResult(this ModelResult resultBase)
{
return resultBase.GetResult<ChatStreamingModelResult>();
}
}
