IChatCompletion Streaming Support #829

Merged
merged 19 commits into main from features/ichatcompletion-stream on May 8, 2023
Changes from 17 commits

Commits (19)
0ff9616
Added chat message stream and its implementation to both Azure and Op…
RogerBarreto May 5, 2023
6a61607
Improved the ChatGPT Examples
RogerBarreto May 5, 2023
13b5e8d
Splitting Custom Chat samples from the ChatGPT (Azure/OpenAI)
RogerBarreto May 5, 2023
88572b1
Enabling Example 33 in Program.cs
RogerBarreto May 5, 2023
c32fb98
Changing description and pattern to match existing for TextCompletion …
RogerBarreto May 5, 2023
e7c0379
Added chat message stream and its implementation to both Azure and Op…
RogerBarreto May 5, 2023
4ff0c7a
Improved the ChatGPT Examples
RogerBarreto May 5, 2023
6dc6dd8
Splitting Custom Chat samples from the ChatGPT (Azure/OpenAI)
RogerBarreto May 5, 2023
4fb856c
Enabling Example 33 in Program.cs
RogerBarreto May 5, 2023
0de719e
Changing description and pattern to match existing for TextCompletion …
RogerBarreto May 5, 2023
f421293
Merge from main
May 6, 2023
f073aed
Merge branch 'main' into features/ichatcompletion-stream
shawncal May 7, 2023
bb98a2f
Fixing conflicts and chat completion stream
RogerBarreto May 8, 2023
58a57f9
Improved ChatGPT sample logic with streaming
RogerBarreto May 8, 2023
5a7b0d9
Merge branch 'main' into features/ichatcompletion-stream
RogerBarreto May 8, 2023
454f341
Better improvements to the samples logic
RogerBarreto May 8, 2023
4bfe4a8
Removing unnecessary ConfigureAwait from sample
RogerBarreto May 8, 2023
a05e1e4
Addressed PR feedback
RogerBarreto May 8, 2023
6b49f48
Merge branch 'main' into features/ichatcompletion-stream
RogerBarreto May 8, 2023
135 changes: 101 additions & 34 deletions dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/ClientBase.cs
@@ -44,7 +44,7 @@ public abstract class ClientBase
Verify.NotNull(requestSettings);

ValidateMaxTokens(requestSettings.MaxTokens);
var options = this.CreateCompletionsOptions(text, requestSettings);
var options = CreateCompletionsOptions(text, requestSettings);

Response<Completions>? response = await RunRequestAsync<Response<Completions>?>(
() => this.Client.GetCompletionsAsync(this.ModelId, options, cancellationToken)).ConfigureAwait(false);
@@ -72,7 +72,7 @@ public abstract class ClientBase
Verify.NotNull(requestSettings);

ValidateMaxTokens(requestSettings.MaxTokens);
var options = this.CreateCompletionsOptions(text, requestSettings);
var options = CreateCompletionsOptions(text, requestSettings);

Response<StreamingCompletions>? response = await RunRequestAsync<Response<StreamingCompletions>>(
() => this.Client.GetCompletionsStreamingAsync(this.ModelId, options, cancellationToken)).ConfigureAwait(false);
@@ -135,31 +135,35 @@ await foreach (string message in choice.GetTextStreaming(cancellationToken))
Verify.NotNull(chat);
Verify.NotNull(requestSettings);

if (requestSettings.MaxTokens < 1)
{
throw new AIException(
AIException.ErrorCodes.InvalidRequest,
$"MaxTokens {requestSettings.MaxTokens} is not valid, the value must be greater than zero");
}
ValidateMaxTokens(requestSettings.MaxTokens);
var options = CreateChatCompletionsOptions(requestSettings, chat);

var options = new ChatCompletionsOptions
{
MaxTokens = requestSettings.MaxTokens,
Temperature = (float?)requestSettings.Temperature,
NucleusSamplingFactor = (float?)requestSettings.TopP,
FrequencyPenalty = (float?)requestSettings.FrequencyPenalty,
PresencePenalty = (float?)requestSettings.PresencePenalty,
ChoicesPerPrompt = 1,
};
Response<ChatCompletions>? response = await RunRequestAsync<Response<ChatCompletions>?>(
() => this.Client.GetChatCompletionsAsync(this.ModelId, options, cancellationToken)).ConfigureAwait(false);

if (requestSettings.StopSequences is { Count: > 0 })
if (response == null || response.Value.Choices.Count < 1)
{
foreach (var s in requestSettings.StopSequences)
{
options.StopSequences.Add(s);
}
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Chat completions not found");
}

return response.Value.Choices[0].Message.Content;
}

/// <summary>
/// Generate a new chat message stream
/// </summary>
/// <param name="chat">Chat history</param>
/// <param name="requestSettings">AI request settings</param>
/// <param name="cancellationToken">Async cancellation token</param>
/// <returns>Stream of the generated chat message in string format</returns>
protected async IAsyncEnumerable<string> InternalGenerateChatMessageStreamAsync(
ChatHistory chat,
ChatRequestSettings requestSettings,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
Verify.NotNull(chat);
Verify.NotNull(requestSettings);

foreach (ChatHistory.Message message in chat.Messages)
{
var role = message.AuthorRole switch
@@ -169,27 +173,38 @@ await foreach (string message in choice.GetTextStreaming(cancellationToken))
ChatHistory.AuthorRoles.System => ChatRole.System,
_ => throw new ArgumentException($"Invalid chat message author: {message.AuthorRole:G}")
};

options.Messages.Add(new ChatMessage(role, message.Content));
}

Response<ChatCompletions>? response = await RunRequestAsync<Response<ChatCompletions>?>(
() => this.Client.GetChatCompletionsAsync(this.ModelId, options, cancellationToken)).ConfigureAwait(false);
ValidateMaxTokens(requestSettings.MaxTokens);
var options = CreateChatCompletionsOptions(requestSettings, chat);

if (response == null || response.Value.Choices.Count < 1)
Response<StreamingChatCompletions>? response = await RunRequestAsync<Response<StreamingChatCompletions>>(
() => this.Client.GetChatCompletionsStreamingAsync(this.ModelId, options, cancellationToken)).ConfigureAwait(false);

if (response is null)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Chat completions not found");
}

using StreamingChatCompletions streamingChatCompletions = response.Value;

return response.Value.Choices[0].Message.Content;
await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming(cancellationToken))
{
await foreach (ChatMessage message in choice.GetMessageStreaming(cancellationToken))
{
yield return message.Content;
}

yield return Environment.NewLine;
}
}

/// <summary>
/// Create a new empty chat instance
/// </summary>
/// <param name="instructions">Optional chat instructions for the AI service</param>
/// <returns>Chat object</returns>
protected ChatHistory InternalCreateNewChat(string instructions = "")
protected static ChatHistory InternalCreateNewChat(string instructions = "")
{
return new OpenAIChatHistory(instructions);
}
@@ -206,9 +221,26 @@ protected ChatHistory InternalCreateNewChat(string instructions = "")
CompleteRequestSettings requestSettings,
CancellationToken cancellationToken = default)
{
var chat = this.InternalCreateNewChat();
ChatHistory chat = PrepareChatHistory(text, requestSettings, out ChatRequestSettings settings);

return await this.InternalGenerateChatMessageAsync(chat, settings, cancellationToken).ConfigureAwait(false);
}

protected IAsyncEnumerable<string> InternalCompleteTextUsingChatStreamAsync(
string text,
CompleteRequestSettings requestSettings,
CancellationToken cancellationToken = default)
{
ChatHistory chat = PrepareChatHistory(text, requestSettings, out ChatRequestSettings settings);

return this.InternalGenerateChatMessageStreamAsync(chat, settings, cancellationToken);
}

private static ChatHistory PrepareChatHistory(string text, CompleteRequestSettings requestSettings, out ChatRequestSettings settings)
{
var chat = InternalCreateNewChat();
chat.AddMessage(ChatHistory.AuthorRoles.User, text);
var settings = new ChatRequestSettings
settings = new ChatRequestSettings
{
MaxTokens = requestSettings.MaxTokens,
Temperature = requestSettings.Temperature,
@@ -217,11 +249,10 @@ protected ChatHistory InternalCreateNewChat(string instructions = "")
FrequencyPenalty = requestSettings.FrequencyPenalty,
StopSequences = requestSettings.StopSequences,
};

return await this.InternalGenerateChatMessageAsync(chat, settings, cancellationToken).ConfigureAwait(false);
return chat;
}

private CompletionsOptions CreateCompletionsOptions(string text, CompleteRequestSettings requestSettings)
private static CompletionsOptions CreateCompletionsOptions(string text, CompleteRequestSettings requestSettings)
{
var options = new CompletionsOptions
{
@@ -249,6 +280,42 @@ private CompletionsOptions CreateCompletionsOptions(string text, CompleteRequest
return options;
}

private static ChatCompletionsOptions CreateChatCompletionsOptions(ChatRequestSettings requestSettings, ChatHistory chat)
{
var options = new ChatCompletionsOptions
{
MaxTokens = requestSettings.MaxTokens,
Temperature = (float?)requestSettings.Temperature,
NucleusSamplingFactor = (float?)requestSettings.TopP,
FrequencyPenalty = (float?)requestSettings.FrequencyPenalty,
PresencePenalty = (float?)requestSettings.PresencePenalty,
ChoicesPerPrompt = 1,
};

if (requestSettings.StopSequences is { Count: > 0 })
{
foreach (var s in requestSettings.StopSequences)
{
options.StopSequences.Add(s);
}
}

foreach (ChatHistory.Message message in chat.Messages)
{
var role = message.AuthorRole switch
{
ChatHistory.AuthorRoles.User => ChatRole.User,
ChatHistory.AuthorRoles.Assistant => ChatRole.Assistant,
ChatHistory.AuthorRoles.System => ChatRole.System,
_ => throw new ArgumentException($"Invalid chat message author: {message.AuthorRole:G}")
};

options.Messages.Add(new ChatMessage(role, message.Content));
}

return options;
}

private static void ValidateMaxTokens(int maxTokens)
{
if (maxTokens < 1)
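The heart of this change is the async-iterator pattern in `InternalGenerateChatMessageStreamAsync`: request a streaming response, then `yield return` each content chunk as the SDK surfaces it, emitting a newline between choices. The sketch below isolates that pattern in a self-contained program; `FakeMessageChunks` is a hypothetical stand-in for the Azure SDK's `StreamingChatChoice.GetMessageStreaming`, so read it as the shape of the flow rather than the SDK calls themselves.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

public static class StreamingPatternSketch
{
    // Hypothetical stand-in for StreamingChatChoice.GetMessageStreaming:
    // emits partial content with an artificial delay, the way a model
    // streams tokens.
    private static async IAsyncEnumerable<string> FakeMessageChunks(
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        foreach (string chunk in new[] { "Hello", ", ", "streaming", " world", "!" })
        {
            await Task.Delay(100, cancellationToken);
            yield return chunk;
        }
    }

    // Mirrors the shape of InternalGenerateChatMessageStreamAsync: forward
    // each chunk as it arrives, then a newline to close out the choice.
    public static async IAsyncEnumerable<string> GenerateMessageStreamAsync(
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        await foreach (string chunk in FakeMessageChunks(cancellationToken))
        {
            yield return chunk;
        }

        yield return Environment.NewLine;
    }

    public static async Task Main()
    {
        // The caller sees output build up incrementally instead of waiting
        // for the whole completion to finish.
        await foreach (string piece in GenerateMessageStreamAsync())
        {
            Console.Write(piece);
        }
    }
}
```

Because the method returns `IAsyncEnumerable<string>`, cancellation flows through the `[EnumeratorCancellation]` token and nothing is buffered between the SDK and the caller.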
AzureChatCompletion.cs
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
@@ -62,10 +61,19 @@ public sealed class AzureChatCompletion : AzureOpenAIClientBase, IChatCompletion
return this.InternalGenerateChatMessageAsync(chat, requestSettings ?? new(), cancellationToken);
}

/// <inheritdoc/>
public IAsyncEnumerable<string> GenerateMessageStreamAsync(
ChatHistory chat,
ChatRequestSettings? requestSettings = null,
CancellationToken cancellationToken = default)
{
return this.InternalGenerateChatMessageStreamAsync(chat, requestSettings ?? new(), cancellationToken);
}

/// <inheritdoc/>
public ChatHistory CreateNewChat(string instructions = "")
{
return this.InternalCreateNewChat(instructions);
return InternalCreateNewChat(instructions);
}

/// <inheritdoc/>
@@ -82,6 +90,6 @@ public ChatHistory CreateNewChat(string instructions = "")
CompleteRequestSettings requestSettings,
CancellationToken cancellationToken = default)
{
return this.InternalCompleteTextUsingChatAsync(text, requestSettings, cancellationToken).ToAsyncEnumerable();
return this.InternalCompleteTextUsingChatStreamAsync(text, requestSettings, cancellationToken);
}
}
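With the connector in place, callers opt into streaming by switching from `GenerateMessageAsync` to `GenerateMessageStreamAsync` and consuming the result with `await foreach`. A hedged usage sketch follows: `CreateNewChat`, `AddMessage`, and `GenerateMessageStreamAsync` come straight from the diff above, while the `AzureChatCompletion` constructor arguments and the connector namespace are assumptions for illustration. The `OpenAIChatCompletion` connector in the next file exposes the identical method.

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.AI.ChatCompletion;
// using Microsoft.SemanticKernel.Connectors.AI.OpenAI.ChatCompletion; // assumed connector namespace

public static class AzureStreamingUsage
{
    public static async Task RunAsync()
    {
        // Constructor arguments (deployment name, endpoint, API key) are
        // assumptions for illustration.
        IChatCompletion chat = new AzureChatCompletion(
            "my-deployment",
            "https://myresource.openai.azure.com/",
            "...api-key...");

        ChatHistory history = chat.CreateNewChat("You are a helpful assistant.");
        history.AddMessage(ChatHistory.AuthorRoles.User, "Write a haiku about streaming.");

        await foreach (string token in chat.GenerateMessageStreamAsync(history))
        {
            Console.Write(token); // text renders as it is generated
        }
    }
}
```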
OpenAIChatCompletion.cs
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
@@ -42,15 +41,22 @@ public sealed class OpenAIChatCompletion : OpenAIClientBase, IChatCompletion, IT
ChatRequestSettings? requestSettings = null,
CancellationToken cancellationToken = default)
{
requestSettings ??= new ChatRequestSettings();
return this.InternalGenerateChatMessageAsync(chat, requestSettings ?? new(), cancellationToken);
}

return this.InternalGenerateChatMessageAsync(chat, requestSettings, cancellationToken);
/// <inheritdoc/>
public IAsyncEnumerable<string> GenerateMessageStreamAsync(
ChatHistory chat,
ChatRequestSettings? requestSettings = null,
CancellationToken cancellationToken = default)
{
return this.InternalGenerateChatMessageStreamAsync(chat, requestSettings ?? new(), cancellationToken);
}

/// <inheritdoc/>
public ChatHistory CreateNewChat(string instructions = "")
{
return this.InternalCreateNewChat(instructions);
return InternalCreateNewChat(instructions);
}

/// <inheritdoc/>
@@ -68,6 +74,6 @@ public ChatHistory CreateNewChat(string instructions = "")
CompleteRequestSettings requestSettings,
CancellationToken cancellationToken = default)
{
return this.InternalCompleteTextUsingChatAsync(text, requestSettings, cancellationToken).ToAsyncEnumerable();
return this.InternalCompleteTextUsingChatStreamAsync(text, requestSettings, cancellationToken);
}
}
IChatCompletion.cs
@@ -1,12 +1,20 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.SemanticKernel.AI.ChatCompletion;

public interface IChatCompletion
{
/// <summary>
/// Create a new empty chat instance
/// </summary>
/// <param name="instructions">Optional chat instructions for the AI service</param>
/// <returns>Chat object</returns>
public ChatHistory CreateNewChat(string instructions = "");

/// <summary>
/// Generate a new chat message
/// </summary>
@@ -20,9 +28,14 @@ public interface IChatCompletion
CancellationToken cancellationToken = default);

/// <summary>
/// Create a new empty chat instance
/// Generate a new chat message
/// </summary>
/// <param name="instructions">Optional chat instructions for the AI service</param>
/// <returns>Chat object</returns>
public ChatHistory CreateNewChat(string instructions = "");
/// <param name="chat">Chat history</param>
/// <param name="requestSettings">AI request settings</param>
/// <param name="cancellationToken">Async cancellation token</param>
/// <returns>Stream of the generated chat message in string format</returns>
public IAsyncEnumerable<string> GenerateMessageStreamAsync(
ChatHistory chat,
ChatRequestSettings? requestSettings = null,
CancellationToken cancellationToken = default);
}
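The interface now pairs the blocking `GenerateMessageAsync` with a streaming counterpart, so every implementation has to supply both. Below is a minimal fake implementation of the kind a unit test might use; it assumes `ChatHistory` has a public parameterless constructor and that the non-streaming method carries the `GenerateMessageAsync` name and signature shown earlier in this diff, so treat those details as assumptions rather than guarantees.

```csharp
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.AI.ChatCompletion;

// A minimal fake IChatCompletion, useful for tests: returns a canned reply.
public sealed class EchoChatCompletion : IChatCompletion
{
    public ChatHistory CreateNewChat(string instructions = "")
    {
        var chat = new ChatHistory(); // assumes a parameterless constructor
        if (!string.IsNullOrEmpty(instructions))
        {
            chat.AddMessage(ChatHistory.AuthorRoles.System, instructions);
        }

        return chat;
    }

    public Task<string> GenerateMessageAsync(
        ChatHistory chat,
        ChatRequestSettings? requestSettings = null,
        CancellationToken cancellationToken = default)
    {
        return Task.FromResult("...canned reply...");
    }

    public async IAsyncEnumerable<string> GenerateMessageStreamAsync(
        ChatHistory chat,
        ChatRequestSettings? requestSettings = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        foreach (string word in new[] { "...canned", " reply..." })
        {
            await Task.Delay(10, cancellationToken); // simulate model latency
            yield return word;
        }
    }
}
```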
7 changes: 5 additions & 2 deletions samples/dotnet/kernel-syntax-examples/Example16_CustomLLM.cs
@@ -19,15 +19,18 @@
*/
public class MyTextCompletionService : ITextCompletion
{
public Task<string> CompleteAsync(
public async Task<string> CompleteAsync(
string text,
CompleteRequestSettings requestSettings,
CancellationToken cancellationToken = default)
{
// Your model logic here
var result = "...output from your custom model...";

return Task.FromResult(result);
// Forcing a 2 sec delay (Simulating custom LLM lag)
await Task.Delay(2000, cancellationToken);

return result;
}

public async IAsyncEnumerable<string> CompleteStreamAsync(
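The diff is truncated at `CompleteStreamAsync`, but the non-streaming method above suggests its shape: emit the canned output in pieces, with an artificial delay standing in for model latency. A plausible body for `MyTextCompletionService` under that assumption (not necessarily the PR's exact code, and requiring `using System.Runtime.CompilerServices;` for the attribute):

```csharp
public async IAsyncEnumerable<string> CompleteStreamAsync(
    string text,
    CompleteRequestSettings requestSettings,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Your model logic here, emitted chunk by chunk
    foreach (string chunk in new[] { "...output ", "from your ", "custom model..." })
    {
        // Forcing a short delay per chunk (simulating custom LLM lag)
        await Task.Delay(500, cancellationToken);
        yield return chunk;
    }
}
```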