IChatCompletion Streaming Support (#829)
### Motivation and Context
Get streaming results from chat completion models using the
IAsyncEnumerable pattern.

Closes #378

Added samples showing how to use these, as well as how to implement your own
custom chat completion model with streaming capability.
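
For illustration, consuming the new streaming API looks roughly like this (a minimal sketch; the model name and API key are placeholders, and the connector constructor arguments are an assumption rather than code from this commit):

    // Inside an async method; assumes the OpenAI connector from this repo.
    IChatCompletion chatCompletion = new OpenAIChatCompletion("gpt-3.5-turbo", "...your OpenAI API key...");

    ChatHistory chat = chatCompletion.CreateNewChat("You are a helpful assistant.");
    chat.AddMessage(ChatHistory.AuthorRoles.User, "Write a haiku about streaming.");

    // Tokens are yielded as they arrive instead of after the full completion.
    await foreach (string token in chatCompletion.GenerateMessageStreamAsync(chat))
    {
        Console.Write(token);
    }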

Co-authored-by: dedalo <dedalo@gmail.com>
Co-authored-by: Shawn Callegari <36091529+shawncal@users.noreply.github.com>
3 people committed May 8, 2023
1 parent e38d97f commit ef85b61
Showing 10 changed files with 507 additions and 86 deletions.
135 changes: 101 additions & 34 deletions dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/ClientBase.cs
@@ -44,7 +44,7 @@ public abstract class ClientBase
Verify.NotNull(requestSettings);

ValidateMaxTokens(requestSettings.MaxTokens);
var options = this.CreateCompletionsOptions(text, requestSettings);
var options = CreateCompletionsOptions(text, requestSettings);

Response<Completions>? response = await RunRequestAsync<Response<Completions>?>(
() => this.Client.GetCompletionsAsync(this.ModelId, options, cancellationToken)).ConfigureAwait(false);
@@ -72,7 +72,7 @@ public abstract class ClientBase
Verify.NotNull(requestSettings);

ValidateMaxTokens(requestSettings.MaxTokens);
var options = this.CreateCompletionsOptions(text, requestSettings);
var options = CreateCompletionsOptions(text, requestSettings);

Response<StreamingCompletions>? response = await RunRequestAsync<Response<StreamingCompletions>>(
() => this.Client.GetCompletionsStreamingAsync(this.ModelId, options, cancellationToken)).ConfigureAwait(false);
@@ -135,31 +135,35 @@ await foreach (string message in choice.GetTextStreaming(cancellationToken))
Verify.NotNull(chat);
Verify.NotNull(requestSettings);

if (requestSettings.MaxTokens < 1)
{
throw new AIException(
AIException.ErrorCodes.InvalidRequest,
$"MaxTokens {requestSettings.MaxTokens} is not valid, the value must be greater than zero");
}
ValidateMaxTokens(requestSettings.MaxTokens);
var options = CreateChatCompletionsOptions(requestSettings, chat);

var options = new ChatCompletionsOptions
{
MaxTokens = requestSettings.MaxTokens,
Temperature = (float?)requestSettings.Temperature,
NucleusSamplingFactor = (float?)requestSettings.TopP,
FrequencyPenalty = (float?)requestSettings.FrequencyPenalty,
PresencePenalty = (float?)requestSettings.PresencePenalty,
ChoicesPerPrompt = 1,
};
Response<ChatCompletions>? response = await RunRequestAsync<Response<ChatCompletions>?>(
() => this.Client.GetChatCompletionsAsync(this.ModelId, options, cancellationToken)).ConfigureAwait(false);

if (requestSettings.StopSequences is { Count: > 0 })
if (response == null || response.Value.Choices.Count < 1)
{
foreach (var s in requestSettings.StopSequences)
{
options.StopSequences.Add(s);
}
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Chat completions not found");
}

return response.Value.Choices[0].Message.Content;
}

/// <summary>
/// Generate a new chat message stream
/// </summary>
/// <param name="chat">Chat history</param>
/// <param name="requestSettings">AI request settings</param>
/// <param name="cancellationToken">Async cancellation token</param>
/// <returns>Stream of the generated chat message in string format</returns>
protected async IAsyncEnumerable<string> InternalGenerateChatMessageStreamAsync(
ChatHistory chat,
ChatRequestSettings requestSettings,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
Verify.NotNull(chat);
Verify.NotNull(requestSettings);

foreach (ChatHistory.Message message in chat.Messages)
{
var role = message.AuthorRole switch
@@ -169,27 +173,38 @@ await foreach (string message in choice.GetTextStreaming(cancellationToken))
ChatHistory.AuthorRoles.System => ChatRole.System,
_ => throw new ArgumentException($"Invalid chat message author: {message.AuthorRole:G}")
};

options.Messages.Add(new ChatMessage(role, message.Content));
}

Response<ChatCompletions>? response = await RunRequestAsync<Response<ChatCompletions>?>(
() => this.Client.GetChatCompletionsAsync(this.ModelId, options, cancellationToken)).ConfigureAwait(false);
ValidateMaxTokens(requestSettings.MaxTokens);
var options = CreateChatCompletionsOptions(requestSettings, chat);

if (response == null || response.Value.Choices.Count < 1)
Response<StreamingChatCompletions>? response = await RunRequestAsync<Response<StreamingChatCompletions>>(
() => this.Client.GetChatCompletionsStreamingAsync(this.ModelId, options, cancellationToken)).ConfigureAwait(false);

using StreamingChatCompletions streamingChatCompletions = response.Value;

if (response is null)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Chat completions not found");
}

return response.Value.Choices[0].Message.Content;
await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming(cancellationToken))
{
await foreach (ChatMessage message in choice.GetMessageStreaming(cancellationToken))
{
yield return message.Content;
}

yield return Environment.NewLine;
}
}

/// <summary>
/// Create a new empty chat instance
/// </summary>
/// <param name="instructions">Optional chat instructions for the AI service</param>
/// <returns>Chat object</returns>
protected ChatHistory InternalCreateNewChat(string instructions = "")
protected static ChatHistory InternalCreateNewChat(string instructions = "")
{
return new OpenAIChatHistory(instructions);
}
@@ -206,9 +221,26 @@ protected ChatHistory InternalCreateNewChat(string instructions = "")
CompleteRequestSettings requestSettings,
CancellationToken cancellationToken = default)
{
var chat = this.InternalCreateNewChat();
ChatHistory chat = PrepareChatHistory(text, requestSettings, out ChatRequestSettings settings);

return await this.InternalGenerateChatMessageAsync(chat, settings, cancellationToken).ConfigureAwait(false);
}

protected IAsyncEnumerable<string> InternalCompleteTextUsingChatStreamAsync(
string text,
CompleteRequestSettings requestSettings,
CancellationToken cancellationToken = default)
{
ChatHistory chat = PrepareChatHistory(text, requestSettings, out ChatRequestSettings settings);

return this.InternalGenerateChatMessageStreamAsync(chat, settings, cancellationToken);
}

private static ChatHistory PrepareChatHistory(string text, CompleteRequestSettings requestSettings, out ChatRequestSettings settings)
{
var chat = InternalCreateNewChat();
chat.AddMessage(ChatHistory.AuthorRoles.User, text);
var settings = new ChatRequestSettings
settings = new ChatRequestSettings
{
MaxTokens = requestSettings.MaxTokens,
Temperature = requestSettings.Temperature,
Expand All @@ -217,11 +249,10 @@ protected ChatHistory InternalCreateNewChat(string instructions = "")
FrequencyPenalty = requestSettings.FrequencyPenalty,
StopSequences = requestSettings.StopSequences,
};

return await this.InternalGenerateChatMessageAsync(chat, settings, cancellationToken).ConfigureAwait(false);
return chat;
}

private CompletionsOptions CreateCompletionsOptions(string text, CompleteRequestSettings requestSettings)
private static CompletionsOptions CreateCompletionsOptions(string text, CompleteRequestSettings requestSettings)
{
var options = new CompletionsOptions
{
@@ -249,6 +280,42 @@ private CompletionsOptions CreateCompletionsOptions(string text, CompleteRequestSettings requestSettings)
return options;
}

private static ChatCompletionsOptions CreateChatCompletionsOptions(ChatRequestSettings requestSettings, ChatHistory chat)
{
var options = new ChatCompletionsOptions
{
MaxTokens = requestSettings.MaxTokens,
Temperature = (float?)requestSettings.Temperature,
NucleusSamplingFactor = (float?)requestSettings.TopP,
FrequencyPenalty = (float?)requestSettings.FrequencyPenalty,
PresencePenalty = (float?)requestSettings.PresencePenalty,
ChoicesPerPrompt = 1,
};

if (requestSettings.StopSequences is { Count: > 0 })
{
foreach (var s in requestSettings.StopSequences)
{
options.StopSequences.Add(s);
}
}

foreach (ChatHistory.Message message in chat.Messages)
{
var role = message.AuthorRole switch
{
ChatHistory.AuthorRoles.User => ChatRole.User,
ChatHistory.AuthorRoles.Assistant => ChatRole.Assistant,
ChatHistory.AuthorRoles.System => ChatRole.System,
_ => throw new ArgumentException($"Invalid chat message author: {message.AuthorRole:G}")
};

options.Messages.Add(new ChatMessage(role, message.Content));
}

return options;
}

private static void ValidateMaxTokens(int maxTokens)
{
if (maxTokens < 1)
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
@@ -62,10 +61,19 @@ public sealed class AzureChatCompletion : AzureOpenAIClientBase, IChatCompletion
return this.InternalGenerateChatMessageAsync(chat, requestSettings ?? new(), cancellationToken);
}

/// <inheritdoc/>
public IAsyncEnumerable<string> GenerateMessageStreamAsync(
ChatHistory chat,
ChatRequestSettings? requestSettings = null,
CancellationToken cancellationToken = default)
{
return this.InternalGenerateChatMessageStreamAsync(chat, requestSettings ?? new(), cancellationToken);
}

/// <inheritdoc/>
public ChatHistory CreateNewChat(string instructions = "")
{
return this.InternalCreateNewChat(instructions);
return InternalCreateNewChat(instructions);
}

/// <inheritdoc/>
@@ -82,6 +90,6 @@ public ChatHistory CreateNewChat(string instructions = "")
CompleteRequestSettings requestSettings,
CancellationToken cancellationToken = default)
{
return this.InternalCompleteTextUsingChatAsync(text, requestSettings, cancellationToken).ToAsyncEnumerable();
return this.InternalCompleteTextUsingChatStreamAsync(text, requestSettings, cancellationToken);
}
}
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
@@ -42,15 +41,22 @@ public sealed class OpenAIChatCompletion : OpenAIClientBase, IChatCompletion, ITextCompletion
ChatRequestSettings? requestSettings = null,
CancellationToken cancellationToken = default)
{
requestSettings ??= new ChatRequestSettings();
return this.InternalGenerateChatMessageAsync(chat, requestSettings ?? new(), cancellationToken);
}

return this.InternalGenerateChatMessageAsync(chat, requestSettings, cancellationToken);
/// <inheritdoc/>
public IAsyncEnumerable<string> GenerateMessageStreamAsync(
ChatHistory chat,
ChatRequestSettings? requestSettings = null,
CancellationToken cancellationToken = default)
{
return this.InternalGenerateChatMessageStreamAsync(chat, requestSettings ?? new(), cancellationToken);
}

/// <inheritdoc/>
public ChatHistory CreateNewChat(string instructions = "")
{
return this.InternalCreateNewChat(instructions);
return InternalCreateNewChat(instructions);
}

/// <inheritdoc/>
@@ -68,6 +74,6 @@ public ChatHistory CreateNewChat(string instructions = "")
CompleteRequestSettings requestSettings,
CancellationToken cancellationToken = default)
{
return this.InternalCompleteTextUsingChatAsync(text, requestSettings, cancellationToken).ToAsyncEnumerable();
return this.InternalCompleteTextUsingChatStreamAsync(text, requestSettings, cancellationToken);
}
}
@@ -1,12 +1,20 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.SemanticKernel.AI.ChatCompletion;

public interface IChatCompletion
{
/// <summary>
/// Create a new empty chat instance
/// </summary>
/// <param name="instructions">Optional chat instructions for the AI service</param>
/// <returns>Chat object</returns>
public ChatHistory CreateNewChat(string instructions = "");

/// <summary>
/// Generate a new chat message
/// </summary>
@@ -20,9 +28,14 @@ public interface IChatCompletion
CancellationToken cancellationToken = default);

/// <summary>
/// Create a new empty chat instance
/// Generate a new chat message
/// </summary>
/// <param name="instructions">Optional chat instructions for the AI service</param>
/// <returns>Chat object</returns>
public ChatHistory CreateNewChat(string instructions = "");
/// <param name="chat">Chat history</param>
/// <param name="requestSettings">AI request settings</param>
/// <param name="cancellationToken">Async cancellation token</param>
/// <returns>Stream of the generated chat message in string format</returns>
public IAsyncEnumerable<string> GenerateMessageStreamAsync(
ChatHistory chat,
ChatRequestSettings? requestSettings = null,
CancellationToken cancellationToken = default);
}
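
For a custom connector, a minimal sketch of implementing the new interface follows. MyChatCompletionService, its canned reply, and the per-word delay are hypothetical stand-ins for real model logic, not code from this commit; the use of a plain ChatHistory in CreateNewChat is also an assumption (the built-in connectors return an OpenAIChatHistory).

    // Requires: System.Collections.Generic, System.Runtime.CompilerServices,
    // System.Threading, System.Threading.Tasks
    public class MyChatCompletionService : IChatCompletion
    {
        public ChatHistory CreateNewChat(string instructions = "")
        {
            // Assumes ChatHistory has a public constructor; a custom subclass works too.
            var chat = new ChatHistory();
            if (!string.IsNullOrEmpty(instructions))
            {
                chat.AddMessage(ChatHistory.AuthorRoles.System, instructions);
            }
            return chat;
        }

        public Task<string> GenerateMessageAsync(
            ChatHistory chat,
            ChatRequestSettings? requestSettings = null,
            CancellationToken cancellationToken = default)
        {
            // Your model logic here
            return Task.FromResult("...output from your custom model...");
        }

        public async IAsyncEnumerable<string> GenerateMessageStreamAsync(
            ChatHistory chat,
            ChatRequestSettings? requestSettings = null,
            [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            // Yield the reply in chunks, simulating a model that streams tokens.
            foreach (string word in "...output from your custom model...".Split(' '))
            {
                await Task.Delay(100, cancellationToken);
                yield return word + " ";
            }
        }
    }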
7 changes: 5 additions & 2 deletions samples/dotnet/kernel-syntax-examples/Example16_CustomLLM.cs
@@ -19,15 +19,18 @@
*/
public class MyTextCompletionService : ITextCompletion
{
public Task<string> CompleteAsync(
public async Task<string> CompleteAsync(
string text,
CompleteRequestSettings requestSettings,
CancellationToken cancellationToken = default)
{
// Your model logic here
var result = "...output from your custom model...";

return Task.FromResult(result);
// Forcing a 2 sec delay (Simulating custom LLM lag)
await Task.Delay(2000, cancellationToken);

return result;
}

public async IAsyncEnumerable<string> CompleteStreamAsync(
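
The diff is truncated above. For context, a streaming counterpart in the same spirit as CompleteAsync might be sketched as follows (the word splitting and delay are illustrative, not the exact sample code):

    public async IAsyncEnumerable<string> CompleteStreamAsync(
        string text,
        CompleteRequestSettings requestSettings,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // Your model logic here, yielding partial results as they are produced
        foreach (string word in "...output from your custom model...".Split(' '))
        {
            // Forcing a short delay per chunk (simulating custom LLM lag)
            await Task.Delay(200, cancellationToken);
            yield return word + " ";
        }
    }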
