.Net Stream Json Parser as Utility for Connectors. (#5574)
### Motivation and Context

The stream JSON parser that @Krzysztof318 created for the Gemini connector has proven useful, so this change promotes it to a shared utility that can be used internally across connectors for streaming deserialization of SSE and non-SSE responses. A short usage sketch follows the list below.

- Hugging Face connector updated to use the utility
- Unit tests for the parser moved to SemanticKernel.UnitTests
- StreamJsonParser tests removed from HuggingFace UnitTests
- Added an extra example using streaming with the HuggingFace Zephyr model.
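
As an illustration of the intended usage, here is a minimal sketch of how a connector (inside the Semantic Kernel codebase, since the parser is an internal utility) could consume it. It assumes the `StreamJsonParser.ParseAsync(stream, cancellationToken: ...)` shape used in the HuggingFaceClient changes below, which yields each complete JSON object from the stream as a string. The `MyStreamChunk` type, the `StreamingSketch` class, and the `endpoint` parameter are placeholders for this example only, not part of the commit:

```csharp
// Sketch only: MyStreamChunk, StreamingSketch, and endpoint are hypothetical placeholders.
using System.Collections.Generic;
using System.IO;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using Microsoft.SemanticKernel.Text;

internal sealed class MyStreamChunk
{
    public string? Text { get; set; }
}

internal static class StreamingSketch
{
    // Streams typed chunks from an HTTP response body using the shared StreamJsonParser.
    public static async IAsyncEnumerable<MyStreamChunk> StreamChunksAsync(
        HttpClient httpClient,
        string endpoint,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        var parser = new StreamJsonParser();

        using var response = await httpClient.GetAsync(endpoint, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
        using Stream body = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);

        // Per this PR's description, the parser handles both SSE and non-SSE streaming
        // responses; ParseAsync yields one complete JSON object per iteration as a string.
        await foreach (var json in parser.ParseAsync(body, cancellationToken: cancellationToken))
        {
            yield return JsonSerializer.Deserialize<MyStreamChunk>(json)!;
        }
    }
}
```

The updated `HuggingFaceClient` below follows the same pattern, deserializing each yielded JSON string into a `TextGenerationStreamResponse`.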
RogerBarreto committed Apr 2, 2024
1 parent 87ead74 commit e32ab5e
Showing 8 changed files with 532 additions and 367 deletions.
22 changes: 22 additions & 0 deletions dotnet/samples/KernelSyntaxExamples/Example20_HuggingFace.cs
@@ -55,6 +55,28 @@ public async Task RunInferenceApiEmbeddingAsync()
this.WriteLine($"Generated {embeddings.Count} embeddings for the provided text");
}

[RetryFact(typeof(HttpOperationException))]
public async Task RunStreamingExampleAsync()
{
WriteLine("\n======== HuggingFace zephyr-7b-beta streaming example ========\n");

const string Model = "HuggingFaceH4/zephyr-7b-beta";

Kernel kernel = Kernel.CreateBuilder()
.AddHuggingFaceTextGeneration(
model: Model,
//endpoint: Endpoint,
apiKey: TestConfiguration.HuggingFace.ApiKey)
.Build();

var questionAnswerFunction = kernel.CreateFunctionFromPrompt("Question: {{$input}}; Answer:");

await foreach (string text in kernel.InvokeStreamingAsync<string>(questionAnswerFunction, new() { ["input"] = "What is New York?" }))
{
this.Write(text);
}
}

/// <summary>
/// This example uses HuggingFace Llama 2 model and local HTTP server from Semantic Kernel repository.
/// How to setup local HTTP server: <see href="https://github.com/microsoft/semantic-kernel/blob/main/samples/apps/hugging-face-http-server/README.md"/>.

This file was deleted.

@@ -3,17 +3,19 @@
using System.Collections.Generic;
using System.IO;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Connectors.HuggingFace.Client;
using Microsoft.SemanticKernel.Text;
using Xunit;

namespace SemanticKernel.Connectors.HuggingFace.UnitTests.TextGeneration;
public class TextGenerationStreamResponseTests
{
[Fact]
- public void SerializationShouldPopulateAllProperties()
+ public async Task SerializationShouldPopulateAllPropertiesAsync()
{
// Arrange
- var parser = new TextGenerationStreamJsonParser();
+ var parser = new StreamJsonParser();
var stream = new MemoryStream();
var huggingFaceStreamExample = """
{
@@ -44,7 +46,7 @@ public void SerializationShouldPopulateAllProperties()

// Act
var chunks = new List<TextGenerationStreamResponse>();
- foreach (var chunk in parser.Parse(stream))
+ await foreach (var chunk in parser.ParseAsync(stream))
{
chunks.Add(JsonSerializer.Deserialize<TextGenerationStreamResponse>(chunk)!);
}
@@ -14,12 +14,13 @@
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel.Connectors.HuggingFace.TextGeneration;
using Microsoft.SemanticKernel.Http;
using Microsoft.SemanticKernel.Text;

namespace Microsoft.SemanticKernel.Connectors.HuggingFace.Client;

internal sealed class HuggingFaceClient
{
- private readonly IStreamJsonParser _streamJsonParser;
+ private readonly StreamJsonParser _streamJsonParser;
private readonly string _modelId;
private readonly string? _apiKey;
private readonly Uri? _endpoint;
@@ -32,7 +33,7 @@ internal sealed class HuggingFaceClient
HttpClient httpClient,
Uri? endpoint = null,
string? apiKey = null,
- IStreamJsonParser? streamJsonParser = null,
+ StreamJsonParser? streamJsonParser = null,
ILogger? logger = null)
{
Verify.NotNullOrWhiteSpace(modelId);
@@ -45,7 +46,7 @@ internal sealed class HuggingFaceClient
this._apiKey = apiKey;
this._httpClient = httpClient;
this._logger = logger ?? NullLogger.Instance;
- this._streamJsonParser = streamJsonParser ?? new TextGenerationStreamJsonParser();
+ this._streamJsonParser = streamJsonParser ?? new StreamJsonParser();
}

public async Task<IReadOnlyList<TextContent>> GenerateTextAsync(
@@ -87,7 +88,7 @@ internal sealed class HuggingFaceClient
using var responseStream = await response.Content.ReadAsStreamAndTranslateExceptionAsync()
.ConfigureAwait(false);

- foreach (var streamingTextContent in this.ProcessTextResponseStream(responseStream, modelId))
+ await foreach (var streamingTextContent in this.ProcessTextResponseStreamAsync(responseStream, modelId, cancellationToken).ConfigureAwait(false))
{
yield return streamingTextContent;
}
@@ -151,31 +152,45 @@ private static void ValidateMaxTokens(int? maxTokens)
return response;
}

- private IEnumerable<StreamingTextContent> ProcessTextResponseStream(Stream stream, string modelId)
-     => from response in this.ParseTextResponseStream(stream)
-        from textContent in this.GetTextStreamContentsFromResponse(response, modelId)
-        select GetStreamingTextContentFromTextContent(textContent);
-
- private IEnumerable<TextGenerationStreamResponse> ParseTextResponseStream(Stream responseStream)
-     => this._streamJsonParser.Parse(responseStream).Select(DeserializeResponse<TextGenerationStreamResponse>);
-
- private List<TextContent> GetTextStreamContentsFromResponse(TextGenerationStreamResponse response, string modelId)
- {
-     return new List<TextContent>
-     {
-         new(text: response.Token?.Text,
-             modelId: modelId,
-             innerContent: response,
-             metadata: new TextGenerationStreamMetadata(response))
-     };
- }
-
- private static StreamingTextContent GetStreamingTextContentFromTextContent(TextContent textContent)
-     => new(
-         text: textContent.Text,
-         modelId: textContent.ModelId,
-         innerContent: textContent.InnerContent,
-         metadata: textContent.Metadata);
+ private async IAsyncEnumerable<StreamingTextContent> ProcessTextResponseStreamAsync(Stream stream, string modelId, [EnumeratorCancellation] CancellationToken cancellationToken)
+ {
+     IAsyncEnumerator<TextGenerationStreamResponse>? responseEnumerator = null;
+
+     try
+     {
+         var responseEnumerable = this.ParseTextResponseStreamAsync(stream, cancellationToken);
+         responseEnumerator = responseEnumerable.GetAsyncEnumerator(cancellationToken);
+
+         while (await responseEnumerator.MoveNextAsync().ConfigureAwait(false))
+         {
+             var textContent = responseEnumerator.Current!;
+
+             yield return GetStreamingTextContentFromStreamResponse(textContent, modelId);
+         }
+     }
+     finally
+     {
+         if (responseEnumerator != null)
+         {
+             await responseEnumerator.DisposeAsync().ConfigureAwait(false);
+         }
+     }
+ }
+
+ private async IAsyncEnumerable<TextGenerationStreamResponse> ParseTextResponseStreamAsync(Stream responseStream, [EnumeratorCancellation] CancellationToken cancellationToken)
+ {
+     await foreach (var json in this._streamJsonParser.ParseAsync(responseStream, cancellationToken: cancellationToken))
+     {
+         yield return DeserializeResponse<TextGenerationStreamResponse>(json);
+     }
+ }
+
+ private static StreamingTextContent GetStreamingTextContentFromStreamResponse(TextGenerationStreamResponse response, string modelId)
+     => new(
+         text: response.Token?.Text,
+         modelId: modelId,
+         innerContent: response,
+         metadata: new TextGenerationStreamMetadata(response));

private TextGenerationRequest CreateTextRequest(
string prompt,

This file was deleted.
