Skip to content

Commit

Permalink
.Net: Moved Onnx tests to integration tests (#5956)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

In one of my PRs I received HTTP 503 error during CI run in Onnx unit
tests. It appeared that some of the tests perform actual requests to
Hugging Face to download model files. It would be better to keep all
unit tests isolated and lightweight, while keep the tests that require
additional requests to perform as integration tests.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
dmytrostruk committed Apr 22, 2024
1 parent 0c40031 commit c84258a
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 86 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Text;
using Microsoft.SemanticKernel.Connectors.Onnx;
using Xunit;

namespace SemanticKernel.Connectors.Onnx.UnitTests;

public class BertOnnxTextEmbeddingGenerationServiceTests
{
[Fact]
public void VerifyOptionsDefaults()
{
var options = new BertOnnxOptions();
Assert.False(options.CaseSensitive);
Assert.Equal(512, options.MaximumTokens);
Assert.Equal("[CLS]", options.ClsToken);
Assert.Equal("[UNK]", options.UnknownToken);
Assert.Equal("[SEP]", options.SepToken);
Assert.Equal("[PAD]", options.PadToken);
Assert.Equal(NormalizationForm.FormD, options.UnicodeNormalization);
Assert.Equal(EmbeddingPoolingMode.Mean, options.PoolingMode);
Assert.False(options.NormalizeEmbeddings);
}

[Fact]
public void RoundtripOptionsProperties()
{
var options = new BertOnnxOptions()
{
CaseSensitive = true,
MaximumTokens = 128,
ClsToken = "<A>",
UnknownToken = "<B>",
SepToken = "<C>",
PadToken = "<D>",
UnicodeNormalization = NormalizationForm.FormKC,
PoolingMode = EmbeddingPoolingMode.MeanSquareRootTokensLength,
NormalizeEmbeddings = true,
};

Assert.True(options.CaseSensitive);
Assert.Equal(128, options.MaximumTokens);
Assert.Equal("<A>", options.ClsToken);
Assert.Equal("<B>", options.UnknownToken);
Assert.Equal("<C>", options.SepToken);
Assert.Equal("<D>", options.PadToken);
Assert.Equal(NormalizationForm.FormKC, options.UnicodeNormalization);
Assert.Equal(EmbeddingPoolingMode.MeanSquareRootTokensLength, options.PoolingMode);
Assert.True(options.NormalizeEmbeddings);
}

[Fact]
public void ValidateInvalidOptionsPropertiesThrow()
{
Assert.Throws<ArgumentOutOfRangeException>(() => new BertOnnxOptions() { MaximumTokens = 0 });
Assert.Throws<ArgumentOutOfRangeException>(() => new BertOnnxOptions() { MaximumTokens = -1 });

Assert.Throws<ArgumentNullException>(() => new BertOnnxOptions() { ClsToken = null! });
Assert.Throws<ArgumentException>(() => new BertOnnxOptions() { ClsToken = " " });

Assert.Throws<ArgumentNullException>(() => new BertOnnxOptions() { UnknownToken = null! });
Assert.Throws<ArgumentException>(() => new BertOnnxOptions() { UnknownToken = " " });

Assert.Throws<ArgumentNullException>(() => new BertOnnxOptions() { SepToken = null! });
Assert.Throws<ArgumentException>(() => new BertOnnxOptions() { SepToken = " " });

Assert.Throws<ArgumentNullException>(() => new BertOnnxOptions() { PadToken = null! });
Assert.Throws<ArgumentException>(() => new BertOnnxOptions() { PadToken = " " });

Assert.Throws<ArgumentOutOfRangeException>(() => new BertOnnxOptions() { PoolingMode = (EmbeddingPoolingMode)4 });
}
}
Original file line number Diff line number Diff line change
@@ -1,90 +1,27 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using Microsoft.SemanticKernel.Embeddings;
using Microsoft.SemanticKernel;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Numerics.Tensors;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.Onnx;
using Microsoft.SemanticKernel.Embeddings;
using System;
using Xunit;
using System.Numerics.Tensors;
using Microsoft.SemanticKernel.Connectors.Onnx;
using System.Text;
using System.Net.Http;
using System.Security.Cryptography;

namespace SemanticKernel.Connectors.Onnx.UnitTests;
namespace SemanticKernel.IntegrationTests.Connectors.Onnx;

public class BertOnnxTextEmbeddingGenerationServiceTests
{
private static readonly HttpClient s_client = new();

[Fact]
public void VerifyOptionsDefaults()
{
var options = new BertOnnxOptions();
Assert.False(options.CaseSensitive);
Assert.Equal(512, options.MaximumTokens);
Assert.Equal("[CLS]", options.ClsToken);
Assert.Equal("[UNK]", options.UnknownToken);
Assert.Equal("[SEP]", options.SepToken);
Assert.Equal("[PAD]", options.PadToken);
Assert.Equal(NormalizationForm.FormD, options.UnicodeNormalization);
Assert.Equal(EmbeddingPoolingMode.Mean, options.PoolingMode);
Assert.False(options.NormalizeEmbeddings);
}

[Fact]
public void RoundtripOptionsProperties()
{
var options = new BertOnnxOptions()
{
CaseSensitive = true,
MaximumTokens = 128,
ClsToken = "<A>",
UnknownToken = "<B>",
SepToken = "<C>",
PadToken = "<D>",
UnicodeNormalization = NormalizationForm.FormKC,
PoolingMode = EmbeddingPoolingMode.MeanSquareRootTokensLength,
NormalizeEmbeddings = true,
};

Assert.True(options.CaseSensitive);
Assert.Equal(128, options.MaximumTokens);
Assert.Equal("<A>", options.ClsToken);
Assert.Equal("<B>", options.UnknownToken);
Assert.Equal("<C>", options.SepToken);
Assert.Equal("<D>", options.PadToken);
Assert.Equal(NormalizationForm.FormKC, options.UnicodeNormalization);
Assert.Equal(EmbeddingPoolingMode.MeanSquareRootTokensLength, options.PoolingMode);
Assert.True(options.NormalizeEmbeddings);
}

[Fact]
public void ValidateInvalidOptionsPropertiesThrow()
{
Assert.Throws<ArgumentOutOfRangeException>(() => new BertOnnxOptions() { MaximumTokens = 0 });
Assert.Throws<ArgumentOutOfRangeException>(() => new BertOnnxOptions() { MaximumTokens = -1 });

Assert.Throws<ArgumentNullException>(() => new BertOnnxOptions() { ClsToken = null! });
Assert.Throws<ArgumentException>(() => new BertOnnxOptions() { ClsToken = " " });

Assert.Throws<ArgumentNullException>(() => new BertOnnxOptions() { UnknownToken = null! });
Assert.Throws<ArgumentException>(() => new BertOnnxOptions() { UnknownToken = " " });

Assert.Throws<ArgumentNullException>(() => new BertOnnxOptions() { SepToken = null! });
Assert.Throws<ArgumentException>(() => new BertOnnxOptions() { SepToken = " " });

Assert.Throws<ArgumentNullException>(() => new BertOnnxOptions() { PadToken = null! });
Assert.Throws<ArgumentException>(() => new BertOnnxOptions() { PadToken = " " });

Assert.Throws<ArgumentOutOfRangeException>(() => new BertOnnxOptions() { PoolingMode = (EmbeddingPoolingMode)4 });
}

[Fact]
public async Task ValidateEmbeddingsAreIdempotent()
public async Task ValidateEmbeddingsAreIdempotentAsync()
{
Func<Task<BertOnnxTextEmbeddingGenerationService>>[] funcs =
[
Expand All @@ -110,7 +47,9 @@ public async Task ValidateEmbeddingsAreIdempotent()

foreach (string input in inputs)
{
#pragma warning disable CA1308 // Normalize strings to uppercase
IList<ReadOnlyMemory<float>> results = await service.GenerateEmbeddingsAsync([input, input.ToUpperInvariant(), input.ToLowerInvariant()]);
#pragma warning restore CA1308 // Normalize strings to uppercase
for (int i = 1; i < results.Count; i++)
{
AssertEqualTolerance(results[0].Span, results[i].Span);
Expand All @@ -120,10 +59,10 @@ public async Task ValidateEmbeddingsAreIdempotent()
}

[Fact]
public async Task ValidateExpectedEmbeddingsForBgeMicroV2()
public async Task ValidateExpectedEmbeddingsForBgeMicroV2Async()
{
string modelPath = await GetTestFilePath(BgeMicroV2ModelUrl);
string vocabPath = await GetTestFilePath(BgeMicroV2VocabUrl);
string modelPath = await GetTestFilePathAsync(BgeMicroV2ModelUrl);
string vocabPath = await GetTestFilePathAsync(BgeMicroV2VocabUrl);

using Stream modelStream = File.OpenRead(modelPath);
using Stream vocabStream = File.OpenRead(vocabPath);
Expand Down Expand Up @@ -178,7 +117,7 @@ public async Task ValidateExpectedEmbeddingsForBgeMicroV2()
}

[Fact]
public async Task ValidateExpectedEmbeddingsForAllMiniLML6V2()
public async Task ValidateExpectedEmbeddingsForAllMiniLML6V2Async()
{
using BertOnnxTextEmbeddingGenerationService service = await GetAllMiniLML6V2Async();

Expand All @@ -203,7 +142,7 @@ public async Task ValidateExpectedEmbeddingsForAllMiniLML6V2()
}

[Fact]
public async Task ValidateSimilarityScoresOrderedForBgeMicroV2()
public async Task ValidateSimilarityScoresOrderedForBgeMicroV2Async()
{
using BertOnnxTextEmbeddingGenerationService service = await GetBgeMicroV2ServiceAsync();

Expand Down Expand Up @@ -265,7 +204,7 @@ public async Task ValidateSimilarityScoresOrderedForBgeMicroV2()
}

[Fact]
public async Task ValidateServiceMayBeUsedConcurrently()
public async Task ValidateServiceMayBeUsedConcurrentlyAsync()
{
using BertOnnxTextEmbeddingGenerationService service = await GetBgeMicroV2ServiceAsync();

Expand Down Expand Up @@ -340,7 +279,7 @@ private static bool IsEqualWithTolerance(float expected, float actual)
diff <= MathF.Max(MathF.Abs(expected), MathF.Abs(actual)) * Tolerance;
}

private static async Task<string> GetTestFilePath(string url)
private static async Task<string> GetTestFilePathAsync(string url)
{
// Rather than downloading each model on each use, try to cache it into a temporary file.
// The file's name is computed as a hash of the url.
Expand All @@ -350,15 +289,17 @@ private static async Task<string> GetTestFilePath(string url)

if (!File.Exists(path))
{
using Stream responseStream = await s_client.GetStreamAsync(url);
await using Stream responseStream = await s_client.GetStreamAsync(new Uri(url));
try
{
using FileStream dest = File.OpenWrite(path);
await using FileStream dest = File.OpenWrite(path);
await responseStream.CopyToAsync(dest);
}
catch
{
#pragma warning disable CA1031
try { File.Delete(path); } catch { } // if something goes wrong, try not to leave a bad file in place
#pragma warning restore CA1031
throw;
}
}
Expand All @@ -371,12 +312,12 @@ private static async Task<string> GetTestFilePath(string url)

private static async Task<BertOnnxTextEmbeddingGenerationService> GetBgeMicroV2ServiceAsync() =>
await BertOnnxTextEmbeddingGenerationService.CreateAsync(
await GetTestFilePath(BgeMicroV2ModelUrl),
await GetTestFilePath(BgeMicroV2VocabUrl));
await GetTestFilePathAsync(BgeMicroV2ModelUrl),
await GetTestFilePathAsync(BgeMicroV2VocabUrl));

private static async Task<BertOnnxTextEmbeddingGenerationService> GetAllMiniLML6V2Async() =>
await BertOnnxTextEmbeddingGenerationService.CreateAsync(
await GetTestFilePath("https://huggingface.co/optimum/all-MiniLM-L6-v2/resolve/1024484/model.onnx"),
await GetTestFilePath("https://huggingface.co/optimum/all-MiniLM-L6-v2/raw/1024484/vocab.txt"),
await GetTestFilePathAsync("https://huggingface.co/optimum/all-MiniLM-L6-v2/resolve/1024484/model.onnx"),
await GetTestFilePathAsync("https://huggingface.co/optimum/all-MiniLM-L6-v2/raw/1024484/vocab.txt"),
new BertOnnxOptions { NormalizeEmbeddings = true });
}
1 change: 1 addition & 0 deletions dotnet/src/IntegrationTests/IntegrationTests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Connectors\Connectors.Google\Connectors.Google.csproj" />
<ProjectReference Include="..\Connectors\Connectors.Onnx\Connectors.Onnx.csproj" />
<ProjectReference Include="..\Connectors\Connectors.OpenAI\Connectors.OpenAI.csproj" />
<ProjectReference Include="..\Connectors\Connectors.HuggingFace\Connectors.HuggingFace.csproj" />
<ProjectReference Include="..\Connectors\Connectors.Memory.Chroma\Connectors.Memory.Chroma.csproj" />
Expand Down

0 comments on commit c84258a

Please sign in to comment.