Skip to content

Commit

Permalink
.Net Hugging Face TGI Chat Completion Message API Support (#5785)
Browse files Browse the repository at this point in the history
### Motivation and Context

Closes #5403 

1. Adding support to Chat Completion for TGI (Text Generation Inference)
deployments.
2. Adding missing unit tests for streaming and non-streaming scenarios
(Text/Chat Completion).
3. Updating metadata and usage details for the Hugging Face clients.

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
RogerBarreto committed Apr 16, 2024
1 parent bafc65e commit e416946
Show file tree
Hide file tree
Showing 41 changed files with 3,947 additions and 371 deletions.
96 changes: 92 additions & 4 deletions dotnet/samples/KernelSyntaxExamples/Example20_HuggingFace.cs
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.HuggingFace;
using Microsoft.SemanticKernel.Embeddings;
using xRetry;
using Xunit;
using Xunit.Abstractions;

#pragma warning disable format // Format item can be simplified
#pragma warning disable CA1861 // Avoid constant arrays as arguments

namespace Examples;

// The following example shows how to use Semantic Kernel with HuggingFace API.
public class Example20_HuggingFace(ITestOutputHelper output) : BaseTest(output)
public class Example20_HuggingFace : BaseTest
{
/// <summary>
/// This example uses HuggingFace Inference API to access hosted models.
Expand Down Expand Up @@ -65,13 +69,17 @@ public async Task RunStreamingExampleAsync()
Kernel kernel = Kernel.CreateBuilder()
.AddHuggingFaceTextGeneration(
model: Model,
//endpoint: Endpoint,
apiKey: TestConfiguration.HuggingFace.ApiKey)
.Build();

var questionAnswerFunction = kernel.CreateFunctionFromPrompt("Question: {{$input}}; Answer:");
var settings = new HuggingFacePromptExecutionSettings { UseCache = false };

var questionAnswerFunction = kernel.CreateFunctionFromPrompt("Question: {{$input}}; Answer:", new HuggingFacePromptExecutionSettings
{
UseCache = false
});

await foreach (string text in kernel.InvokeStreamingAsync<string>(questionAnswerFunction, new() { ["input"] = "What is New York?" }))
await foreach (string text in kernel.InvokePromptStreamingAsync<string>("Question: {{$input}}; Answer:", new(settings) { ["input"] = "What is New York?" }))
{
this.Write(text);
}
Expand Down Expand Up @@ -112,4 +120,84 @@ public async Task RunLlamaExampleAsync()

WriteLine(result.GetValue<string>());
}

/// <summary>
/// Follow steps in <see href="https://huggingface.co/docs/text-generation-inference/main/en/quicktour"/> to setup HuggingFace local Text Generation Inference HTTP server.
/// </summary>
[Fact(Skip = "Requires TGI (text generation inference) deployment")]
public async Task RunTGI_ChatCompletionAsync()
{
    WriteLine("\n======== HuggingFace - TGI Chat Completion ========\n");

    // This example was run against one of the chat completion (Message API) supported models from HuggingFace, listed in here: <see href="https://huggingface.co/docs/text-generation-inference/main/en/supported_models"/>
    // Starting a Local Docker i.e:
    // docker run --gpus all --shm-size 1g -p 8080:80 -v "F:\temp\huggingface:/data" ghcr.io/huggingface/text-generation-inference:1.4 --model-id teknium/OpenHermes-2.5-Mistral-7B

    // HuggingFace local HTTP server endpoint
    var localEndpoint = new Uri("http://localhost:8080");

    const string ModelId = "teknium/OpenHermes-2.5-Mistral-7B";

    // Wire the HuggingFace chat completion connector against the local TGI deployment.
    var builder = Kernel.CreateBuilder();
    builder.AddHuggingFaceChatCompletion(model: ModelId, endpoint: localEndpoint);
    Kernel kernel = builder.Build();

    var chatService = kernel.GetRequiredService<IChatCompletionService>();

    // System prompt plus a single user turn.
    var history = new ChatHistory("You are a helpful assistant.");
    history.AddUserMessage("What is deep learning?");

    var reply = await chatService.GetChatMessageContentAsync(history);

    WriteLine(reply.Role);
    WriteLine(reply.Content);
}

/// <summary>
/// Follow steps in <see href="https://huggingface.co/docs/text-generation-inference/main/en/quicktour"/> to setup HuggingFace local Text Generation Inference HTTP server.
/// </summary>
[Fact(Skip = "Requires TGI (text generation inference) deployment")]
public async Task RunTGI_StreamingChatCompletionAsync()
{
    WriteLine("\n======== HuggingFace - TGI Chat Completion Streaming ========\n");

    // This example was run against one of the chat completion (Message API) supported models from HuggingFace, listed in here: <see href="https://huggingface.co/docs/text-generation-inference/main/en/supported_models"/>
    // Starting a Local Docker i.e:
    // docker run --gpus all --shm-size 1g -p 8080:80 -v "F:\temp\huggingface:/data" ghcr.io/huggingface/text-generation-inference:1.4 --model-id teknium/OpenHermes-2.5-Mistral-7B

    // HuggingFace local HTTP server endpoint
    var localEndpoint = new Uri("http://localhost:8080");

    const string ModelId = "teknium/OpenHermes-2.5-Mistral-7B";

    // Wire the HuggingFace chat completion connector against the local TGI deployment.
    var builder = Kernel.CreateBuilder();
    builder.AddHuggingFaceChatCompletion(model: ModelId, endpoint: localEndpoint);
    Kernel kernel = builder.Build();

    var chatService = kernel.GetRequiredService<IChatCompletionService>();

    // System prompt plus a single user turn.
    var history = new ChatHistory("You are a helpful assistant.");
    history.AddUserMessage("What is deep learning?");

    // Print the author role once (from the first chunk), then stream the content pieces as they arrive.
    AuthorRole? streamedRole = null;
    await foreach (var chunk in chatService.GetStreamingChatMessageContentsAsync(history))
    {
        if (streamedRole is null)
        {
            streamedRole = chunk.Role;
            Write(streamedRole);
        }

        Write(chunk.Content);
    }
}

/// <summary>
/// Initializes the example, forwarding the xUnit output helper to <c>BaseTest</c> so Write/WriteLine render in the test output.
/// </summary>
public Example20_HuggingFace(ITestOutputHelper output) : base(output)
{
}
}
5 changes: 4 additions & 1 deletion dotnet/samples/KernelSyntaxExamples/Example86_ImageToText.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ public async Task ImageToTextAsync()

// Read image content from a file
ReadOnlyMemory<byte> imageData = await EmbeddedResource.ReadAllAsync(ImageFilePath);
ImageContent imageContent = new(new BinaryData(imageData), "image/jpeg");
ImageContent imageContent = new(new BinaryData(imageData))
{
MimeType = "image/jpeg"
};

// Convert image to text
var textContent = await imageToText.GetTextContentAsync(imageContent, executionSettings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ public void FromExecutionSettingsWhenSerializedHasPropertiesShouldPopulateSpecia
Assert.Equal(0.5, huggingFaceExecutionSettings.Temperature);
Assert.Equal(50, huggingFaceExecutionSettings.TopK);
Assert.Equal(100, huggingFaceExecutionSettings.MaxTokens);
Assert.Equal(10.0, huggingFaceExecutionSettings.MaxTime);
Assert.Equal(0.9, huggingFaceExecutionSettings.TopP);
Assert.Equal(1.0, huggingFaceExecutionSettings.RepetitionPenalty);
Assert.Equal(10.0f, huggingFaceExecutionSettings.MaxTime);
Assert.Equal(0.9f, huggingFaceExecutionSettings.TopP);
Assert.Equal(1.0f, huggingFaceExecutionSettings.RepetitionPenalty);
Assert.True(huggingFaceExecutionSettings.UseCache);
Assert.Equal(1, huggingFaceExecutionSettings.ResultsPerPrompt);
Assert.False(huggingFaceExecutionSettings.WaitForModel);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Linq;
using System.Net.Http;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.HuggingFace;
using Microsoft.SemanticKernel.Connectors.HuggingFace.Core;
using Xunit;

namespace SemanticKernel.Connectors.HuggingFace.UnitTests;

/// <summary>
/// Unit tests for the <see cref="HuggingFaceChatCompletionService"/> class.
/// </summary>
public sealed class HuggingFaceChatCompletionTests : IDisposable
{
    private readonly HttpMessageHandlerStub _messageHandlerStub;
    private readonly HttpClient _httpClient;

    public HuggingFaceChatCompletionTests()
    {
        // Every test exercises the service against a stubbed HTTP handler that replays
        // a canned TGI chat completion payload, so no network calls are made.
        this._messageHandlerStub = new HttpMessageHandlerStub();
        this._messageHandlerStub.ResponseToReturn.Content = new StringContent(HuggingFaceTestHelper.GetTestResponse("chatcompletion_test_response.json"));

        this._httpClient = new HttpClient(this._messageHandlerStub, false);
        this._httpClient.BaseAddress = new Uri("https://fake-random-test-host/fake-path");
    }

    /// <summary>
    /// Verifies the configured model id is serialized into the request body.
    /// </summary>
    [Fact]
    public async Task ShouldContainModelInRequestBodyAsync()
    {
        // Arrange
        string modelId = "fake-model234";
        var sut = new HuggingFaceChatCompletionService(modelId, httpClient: this._httpClient);
        var chatHistory = CreateSampleChatHistory();

        // Act
        await sut.GetChatMessageContentAsync(chatHistory);

        // Assert
        Assert.NotNull(this._messageHandlerStub.RequestContent);
        var requestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent);

        Assert.Contains(modelId, requestContent, StringComparison.Ordinal);
    }

    /// <summary>
    /// Verifies no Authorization header is sent when no API key is configured.
    /// </summary>
    [Fact]
    public async Task NoAuthorizationHeaderShouldBeAddedIfApiKeyIsNotProvidedAsync()
    {
        // Arrange
        var sut = new HuggingFaceChatCompletionService("fake-model", apiKey: null, httpClient: this._httpClient);

        // Act
        await sut.GetChatMessageContentAsync("fake-text");

        // Assert
        Assert.False(this._messageHandlerStub.RequestHeaders?.Contains("Authorization"));
    }

    /// <summary>
    /// Verifies the API key is sent as a Bearer token in the Authorization header.
    /// </summary>
    [Fact]
    public async Task AuthorizationHeaderShouldBeAddedIfApiKeyIsProvidedAsync()
    {
        // Arrange
        var sut = new HuggingFaceChatCompletionService("fake-model", apiKey: "fake-api-key", httpClient: this._httpClient);

        // Act
        await sut.GetChatMessageContentAsync("fake-text");

        // Assert
        Assert.True(this._messageHandlerStub.RequestHeaders?.Contains("Authorization"));

        var values = this._messageHandlerStub.RequestHeaders!.GetValues("Authorization");

        var value = values.SingleOrDefault();
        Assert.Equal("Bearer fake-api-key", value);
    }

    /// <summary>
    /// Verifies the Semantic Kernel User-Agent header is attached to requests.
    /// </summary>
    [Fact]
    public async Task UserAgentHeaderShouldBeUsedAsync()
    {
        // Arrange
        var sut = new HuggingFaceChatCompletionService("fake-model", httpClient: this._httpClient);
        var chatHistory = CreateSampleChatHistory();

        // Act
        await sut.GetChatMessageContentAsync(chatHistory);

        // Assert
        Assert.True(this._messageHandlerStub.RequestHeaders?.Contains("User-Agent"));

        var values = this._messageHandlerStub.RequestHeaders!.GetValues("User-Agent");

        var value = values.SingleOrDefault();
        Assert.Equal("Semantic-Kernel", value);
    }

    /// <summary>
    /// Verifies an explicitly provided endpoint takes effect for the outgoing request.
    /// </summary>
    [Fact]
    public async Task ProvidedEndpointShouldBeUsedAsync()
    {
        // Arrange
        var sut = new HuggingFaceChatCompletionService("fake-model", endpoint: new Uri("https://fake-random-test-host/fake-path"), httpClient: this._httpClient);
        var chatHistory = CreateSampleChatHistory();

        // Act
        await sut.GetChatMessageContentAsync(chatHistory);

        // Assert
        Assert.StartsWith("https://fake-random-test-host/fake-path", this._messageHandlerStub.RequestUri?.AbsoluteUri, StringComparison.OrdinalIgnoreCase);
    }

    /// <summary>
    /// Verifies the HttpClient.BaseAddress is used when no endpoint is passed to the service.
    /// </summary>
    [Fact]
    public async Task HttpClientBaseAddressShouldBeUsedAsync()
    {
        // Arrange
        this._httpClient.BaseAddress = new Uri("https://fake-random-test-host/fake-path");

        var sut = new HuggingFaceChatCompletionService("fake-model", httpClient: this._httpClient);
        var chatHistory = CreateSampleChatHistory();

        // Act
        await sut.GetChatMessageContentAsync(chatHistory);

        // Assert
        Assert.StartsWith("https://fake-random-test-host/fake-path", this._messageHandlerStub.RequestUri?.AbsoluteUri, StringComparison.OrdinalIgnoreCase);
    }

    /// <summary>
    /// Verifies construction fails fast when neither an endpoint nor a BaseAddress is available.
    /// </summary>
    [Fact]
    public void ShouldThrowIfNotEndpointIsProvided()
    {
        // Arrange
        this._httpClient.BaseAddress = null;

        // Act & Assert
        Assert.Throws<ArgumentNullException>(() => new HuggingFaceChatCompletionService("fake-model", httpClient: this._httpClient));
    }

    /// <summary>
    /// Verifies every chat history message (content and role) is serialized into the request payload in order.
    /// </summary>
    [Fact]
    public async Task ShouldSendPromptToServiceAsync()
    {
        // Arrange
        var sut = new HuggingFaceChatCompletionService("fake-model", httpClient: this._httpClient);
        var chatHistory = CreateSampleChatHistory();

        // Act
        await sut.GetChatMessageContentAsync(chatHistory);

        // Assert
        var requestPayload = JsonSerializer.Deserialize<ChatCompletionRequest>(this._messageHandlerStub.RequestContent);
        Assert.NotNull(requestPayload);

        Assert.Equal(chatHistory.Count, requestPayload.Messages!.Count);
        for (var i = 0; i < chatHistory.Count; i++)
        {
            Assert.Equal(chatHistory[i].Content, requestPayload.Messages[i].Content);
            Assert.Equal(chatHistory[i].Role.ToString(), requestPayload.Messages[i].Role);
        }
    }

    /// <summary>
    /// Verifies the service surfaces the content from the stubbed TGI response as a single chat message.
    /// </summary>
    [Fact]
    public async Task ShouldHandleServiceResponseAsync()
    {
        // Arrange
        var sut = new HuggingFaceChatCompletionService("fake-model", endpoint: new Uri("https://fake-random-test-host/fake-path"), httpClient: this._httpClient);
        var chatHistory = CreateSampleChatHistory();

        // Act
        var contents = await sut.GetChatMessageContentsAsync(chatHistory);

        // Assert
        Assert.NotNull(contents);

        var content = contents.SingleOrDefault();
        Assert.NotNull(content);

        Assert.Equal("This is a testing chat completion response", content.Content);
    }

    /// <summary>
    /// Verifies the model id reported on the result comes from the service response, not the configured value.
    /// </summary>
    [Fact]
    public async Task GetChatShouldHaveModelIdFromResponseAsync()
    {
        // Arrange
        var sut = new HuggingFaceChatCompletionService("fake-model", endpoint: new Uri("https://fake-random-test-host/fake-path"), httpClient: this._httpClient);
        var chatHistory = CreateSampleChatHistory();

        // Act
        var content = await sut.GetChatMessageContentAsync(chatHistory);

        // Assert
        Assert.NotNull(content.ModelId);
        Assert.Equal("teknium/OpenHermes-2.5-Mistral-7B", content.ModelId);
    }

    /// <summary>
    /// Builds a small multi-turn chat history used by most tests.
    /// </summary>
    private static ChatHistory CreateSampleChatHistory()
    {
        var chatHistory = new ChatHistory();
        chatHistory.AddUserMessage("Hello");
        chatHistory.AddAssistantMessage("Hi");
        chatHistory.AddUserMessage("How are you?");
        return chatHistory;
    }

    public void Dispose()
    {
        this._httpClient.Dispose();
        this._messageHandlerStub.Dispose();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Connectors.HuggingFace;
using Microsoft.SemanticKernel.Connectors.HuggingFace.Client;
using Microsoft.SemanticKernel.Connectors.HuggingFace.Core;
using Xunit;

namespace SemanticKernel.Connectors.HuggingFace.UnitTests;
Expand Down
Loading

0 comments on commit e416946

Please sign in to comment.