.Net: Summarization and translation evaluation examples with Filters (#6262)

### Motivation and Context

This example demonstrates how to perform quality checks on LLM results
for tasks such as text summarization and translation with Semantic
Kernel Filters.

Metrics used in this example:
- [BERTScore](https://github.com/Tiiiger/bert_score) - leverages the
pre-trained contextual embeddings from BERT and matches words in
candidate and reference sentences by cosine similarity.
- [BLEU](https://en.wikipedia.org/wiki/BLEU) (BiLingual Evaluation
Understudy) - evaluates the quality of text which has been
machine-translated from one natural language to another.
- [METEOR](https://en.wikipedia.org/wiki/METEOR) (Metric for Evaluation
of Translation with Explicit ORdering) - evaluates the similarity
between the generated summary and the reference summary, taking into
account grammar and semantics.
- [COMET](https://unbabel.github.io/COMET) (Crosslingual Optimized
Metric for Evaluation of Translation) - is an open-source framework used
to train Machine Translation metrics that achieve high levels of
correlation with different types of human judgments.
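
For orientation, the standard BLEU definition is reproduced below, since the BLEU filter in this commit logs its individual components. Here $p_n$ are the modified n-gram precisions (the `Precisions` values in the response model), $w_n$ are the n-gram weights (commonly uniform, $w_n = 1/N$ with $N = 4$; the weights used by the evaluation server are an assumption here), $c$ is the candidate length, and $r$ is the reference length. The ratio $c/r$ corresponds to `LengthRatio` and $\mathrm{BP}$ to `BrevityPenalty`. Note that the filter compares `Precisions[0]` (the unigram precision $p_1$), not the combined score, against the threshold.

```latex
\mathrm{BLEU} = \mathrm{BP} \cdot \exp\!\left( \sum_{n=1}^{N} w_n \ln p_n \right),
\qquad
\mathrm{BP} =
\begin{cases}
1, & c > r \\
\exp\!\left(1 - \dfrac{r}{c}\right), & c \le r
\end{cases}
```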

In this example, SK Filters call a dedicated server that is responsible
for task evaluation using the metrics described above. If the evaluation
score for a specific metric doesn't meet the configured threshold, an
exception is thrown with the evaluation details.
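
The flow described above can be sketched without any Semantic Kernel dependency. This is a minimal illustration, not the code from this commit: `EvaluateAsync` is a hypothetical stub standing in for the HTTP call to the evaluation server, `InvokeWithQualityCheckAsync` is a hypothetical name for the filter body, and `InvalidOperationException` is used in place of `KernelException` so the snippet compiles on its own.

```csharp
using System;
using System.Threading.Tasks;

// Hypothetical stand-in for the evaluation server call; it returns a fixed score.
static Task<double> EvaluateAsync(string source, string output) => Task.FromResult(0.42);

// Mirrors the filter shape: run the wrapped function first, then gate its result on the score.
static async Task<string> InvokeWithQualityCheckAsync(Func<Task<string>> next, string source, double threshold)
{
    // Equivalent of `await next(context)` in a function invocation filter.
    var output = await next();

    var score = Math.Round(await EvaluateAsync(source, output), 4);

    if (score < threshold)
    {
        throw new InvalidOperationException($"Evaluation score ({score}) is lower than threshold ({threshold})");
    }

    return output;
}

// The stubbed score (0.42) is above the threshold (0.4), so the result is returned unchanged.
var result = await InvokeWithQualityCheckAsync(
    () => Task.FromResult("short summary"),
    "long source text",
    threshold: 0.4);

Console.WriteLine(result);
```

Raising the threshold above the stubbed score makes the call throw instead of returning the result. In the filters added by this commit, the same check runs after `await next(context)`.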

The [Hugging Face Evaluate Metric](https://github.com/huggingface/evaluate)
library is used to evaluate summarization and translation results.

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
dmytrostruk committed May 17, 2024
1 parent dbe6aa2 commit 51af5ee
Showing 20 changed files with 754 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/_typos.toml
@@ -27,6 +27,7 @@ EOF = "EOF" # End of File
ans = "ans" # Short for answers
arange = "arange" # Method in Python numpy package
prompty = "prompty" # prompty is a format name.
ist = "ist" # German language

[default.extend-identifiers]
ags = "ags" # Azure Graph Service
9 changes: 9 additions & 0 deletions dotnet/SK-dotnet.sln
@@ -307,6 +307,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.Memory.SqlServer
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "CodeInterpreterPlugin", "samples\Demos\CodeInterpreterPlugin\CodeInterpreterPlugin.csproj", "{3ED53702-0E53-473A-A0F4-645DB33541C2}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "QualityCheckWithFilters", "samples\Demos\QualityCheck\QualityCheckWithFilters\QualityCheckWithFilters.csproj", "{1D3EEB5B-0E06-4700-80D5-164956E43D0A}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TimePlugin", "samples\Demos\TimePlugin\TimePlugin.csproj", "{F312FCE1-12D7-4DEF-BC29-2FF6618509F3}"
EndProject
Global
@@ -748,6 +750,12 @@ Global
{3ED53702-0E53-473A-A0F4-645DB33541C2}.Publish|Any CPU.Build.0 = Debug|Any CPU
{3ED53702-0E53-473A-A0F4-645DB33541C2}.Release|Any CPU.ActiveCfg = Release|Any CPU
{3ED53702-0E53-473A-A0F4-645DB33541C2}.Release|Any CPU.Build.0 = Release|Any CPU
{1D3EEB5B-0E06-4700-80D5-164956E43D0A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{1D3EEB5B-0E06-4700-80D5-164956E43D0A}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1D3EEB5B-0E06-4700-80D5-164956E43D0A}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
{1D3EEB5B-0E06-4700-80D5-164956E43D0A}.Publish|Any CPU.Build.0 = Debug|Any CPU
{1D3EEB5B-0E06-4700-80D5-164956E43D0A}.Release|Any CPU.ActiveCfg = Release|Any CPU
{1D3EEB5B-0E06-4700-80D5-164956E43D0A}.Release|Any CPU.Build.0 = Release|Any CPU
{F312FCE1-12D7-4DEF-BC29-2FF6618509F3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{F312FCE1-12D7-4DEF-BC29-2FF6618509F3}.Debug|Any CPU.Build.0 = Debug|Any CPU
{F312FCE1-12D7-4DEF-BC29-2FF6618509F3}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
@@ -857,6 +865,7 @@ Global
{6B56D8EE-9991-43E3-90B2-B8F5C5CE77C2} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263}
{24B8041B-92C6-4BB3-A699-C593AF5A870F} = {24503383-A8C4-4255-9998-28D70FE8E99A}
{3ED53702-0E53-473A-A0F4-645DB33541C2} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263}
{1D3EEB5B-0E06-4700-80D5-164956E43D0A} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263}
{F312FCE1-12D7-4DEF-BC29-2FF6618509F3} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
@@ -0,0 +1,41 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using QualityCheckWithFilters.Models;
using QualityCheckWithFilters.Services;

namespace QualityCheckWithFilters.Filters;

/// <summary>
/// Filter which performs text summarization evaluation using BERTScore metric: https://huggingface.co/spaces/evaluate-metric/bertscore.
/// Evaluation result contains three values: precision, recall and F1 score.
/// The higher F1 score - the better the quality of the summary.
/// </summary>
internal sealed class BertSummarizationEvaluationFilter(
    EvaluationService evaluationService,
    ILogger logger,
    double threshold) : IFunctionInvocationFilter
{
    public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
    {
        await next(context);

        var sourceText = context.Result.RenderedPrompt!;
        var summary = context.Result.ToString();

        var request = new SummarizationEvaluationRequest { Sources = [sourceText], Summaries = [summary] };
        var response = await evaluationService.EvaluateAsync<SummarizationEvaluationRequest, BertSummarizationEvaluationResponse>(request);

        var precision = Math.Round(response.Precision[0], 4);
        var recall = Math.Round(response.Recall[0], 4);
        var f1 = Math.Round(response.F1[0], 4);

        logger.LogInformation("[BERT] Precision: {Precision}, Recall: {Recall}, F1: {F1}", precision, recall, f1);

        if (f1 < threshold)
        {
            throw new KernelException($"BERT summary evaluation score ({f1}) is lower than threshold ({threshold})");
        }
    }
}
@@ -0,0 +1,46 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using QualityCheckWithFilters.Models;
using QualityCheckWithFilters.Services;

namespace QualityCheckWithFilters.Filters;

/// <summary>
/// Filter which performs text summarization evaluation using BLEU metric: https://huggingface.co/spaces/evaluate-metric/bleu.
/// Evaluation result contains values like score, precisions, brevity penalty and length ratio.
/// The closer the score and precision values are to 1 - the better the quality of the summary.
/// </summary>
internal sealed class BleuSummarizationEvaluationFilter(
    EvaluationService evaluationService,
    ILogger logger,
    double threshold) : IFunctionInvocationFilter
{
    public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
    {
        await next(context);

        var sourceText = context.Result.RenderedPrompt!;
        var summary = context.Result.ToString();

        var request = new SummarizationEvaluationRequest { Sources = [sourceText], Summaries = [summary] };
        var response = await evaluationService.EvaluateAsync<SummarizationEvaluationRequest, BleuSummarizationEvaluationResponse>(request);

        var score = Math.Round(response.Score, 4);
        var precisions = response.Precisions.Select(l => Math.Round(l, 4)).ToList();
        var brevityPenalty = Math.Round(response.BrevityPenalty, 4);
        var lengthRatio = Math.Round(response.LengthRatio, 4);

        logger.LogInformation("[BLEU] Score: {Score}, Precisions: {Precisions}, Brevity penalty: {BrevityPenalty}, Length Ratio: {LengthRatio}",
            score,
            string.Join(", ", precisions),
            brevityPenalty,
            lengthRatio);

        if (precisions[0] < threshold)
        {
            throw new KernelException($"BLEU summary evaluation score ({precisions[0]}) is lower than threshold ({threshold})");
        }
    }
}
@@ -0,0 +1,40 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using QualityCheckWithFilters.Models;
using QualityCheckWithFilters.Services;

namespace QualityCheckWithFilters.Filters;

/// <summary>
/// Filter which performs text translation evaluation using COMET metric: https://huggingface.co/Unbabel/wmt22-cometkiwi-da.
/// COMET score ranges from 0 to 1, where higher values indicate better translation.
/// </summary>
internal sealed class CometTranslationEvaluationFilter(
    EvaluationService evaluationService,
    ILogger logger,
    double threshold) : IFunctionInvocationFilter
{
    public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
    {
        await next(context);

        var sourceText = context.Result.RenderedPrompt!;
        var translation = context.Result.ToString();

        logger.LogInformation("Translation: {Translation}", translation);

        var request = new TranslationEvaluationRequest { Sources = [sourceText], Translations = [translation] };
        var response = await evaluationService.EvaluateAsync<TranslationEvaluationRequest, CometTranslationEvaluationResponse>(request);

        var score = Math.Round(response.Scores[0], 4);

        logger.LogInformation("[COMET] Score: {Score}", score);

        if (score < threshold)
        {
            throw new KernelException($"COMET translation evaluation score ({score}) is lower than threshold ({threshold})");
        }
    }
}
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using QualityCheckWithFilters.Models;
using QualityCheckWithFilters.Services;

namespace QualityCheckWithFilters.Filters;

/// <summary>
/// Factory class for function invocation filters based on evaluation score type.
/// </summary>
internal sealed class FilterFactory
{
    private static readonly Dictionary<EvaluationScoreType, Func<EvaluationService, ILogger, double, IFunctionInvocationFilter>> s_filters = new()
    {
        [EvaluationScoreType.BERT] = (service, logger, threshold) => new BertSummarizationEvaluationFilter(service, logger, threshold),
        [EvaluationScoreType.BLEU] = (service, logger, threshold) => new BleuSummarizationEvaluationFilter(service, logger, threshold),
        [EvaluationScoreType.METEOR] = (service, logger, threshold) => new MeteorSummarizationEvaluationFilter(service, logger, threshold),
        [EvaluationScoreType.COMET] = (service, logger, threshold) => new CometTranslationEvaluationFilter(service, logger, threshold),
    };

    public static IFunctionInvocationFilter Create(EvaluationScoreType type, EvaluationService evaluationService, ILogger logger, double threshold)
        => s_filters[type].Invoke(evaluationService, logger, threshold);
}
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using QualityCheckWithFilters.Models;
using QualityCheckWithFilters.Services;

namespace QualityCheckWithFilters.Filters;

/// <summary>
/// Filter which performs text summarization evaluation using METEOR metric: https://huggingface.co/spaces/evaluate-metric/meteor.
/// METEOR score ranges from 0 to 1, where higher values indicate better similarity between original text and generated summary.
/// </summary>
internal sealed class MeteorSummarizationEvaluationFilter(
    EvaluationService evaluationService,
    ILogger logger,
    double threshold) : IFunctionInvocationFilter
{
    public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
    {
        await next(context);

        var sourceText = context.Result.RenderedPrompt!;
        var summary = context.Result.ToString();

        var request = new SummarizationEvaluationRequest { Sources = [sourceText], Summaries = [summary] };
        var response = await evaluationService.EvaluateAsync<SummarizationEvaluationRequest, MeteorSummarizationEvaluationResponse>(request);

        var score = Math.Round(response.Score, 4);

        logger.LogInformation("[METEOR] Score: {Score}", score);

        if (score < threshold)
        {
            throw new KernelException($"METEOR summary evaluation score ({score}) is lower than threshold ({threshold})");
        }
    }
}
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json.Serialization;

namespace QualityCheckWithFilters.Models;

/// <summary>Base request model with source texts.</summary>
internal class EvaluationRequest
{
    [JsonPropertyName("sources")]
    public List<string> Sources { get; set; }
}

/// <summary>Request model with generated summaries.</summary>
internal sealed class SummarizationEvaluationRequest : EvaluationRequest
{
    [JsonPropertyName("summaries")]
    public List<string> Summaries { get; set; }
}

/// <summary>Request model with generated translations.</summary>
internal sealed class TranslationEvaluationRequest : EvaluationRequest
{
    [JsonPropertyName("translations")]
    public List<string> Translations { get; set; }
}
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json.Serialization;

namespace QualityCheckWithFilters.Models;

/// <summary>Response model for BERTScore metric: https://huggingface.co/spaces/evaluate-metric/bertscore.</summary>
internal sealed class BertSummarizationEvaluationResponse
{
    [JsonPropertyName("precision")]
    public List<double> Precision { get; set; }

    [JsonPropertyName("recall")]
    public List<double> Recall { get; set; }

    [JsonPropertyName("f1")]
    public List<double> F1 { get; set; }
}

/// <summary>Response model for BLEU metric: https://huggingface.co/spaces/evaluate-metric/bleu.</summary>
internal sealed class BleuSummarizationEvaluationResponse
{
    [JsonPropertyName("bleu")]
    public double Score { get; set; }

    [JsonPropertyName("precisions")]
    public List<double> Precisions { get; set; }

    [JsonPropertyName("brevity_penalty")]
    public double BrevityPenalty { get; set; }

    [JsonPropertyName("length_ratio")]
    public double LengthRatio { get; set; }
}

/// <summary>Response model for METEOR metric: https://huggingface.co/spaces/evaluate-metric/meteor.</summary>
internal sealed class MeteorSummarizationEvaluationResponse
{
    [JsonPropertyName("meteor")]
    public double Score { get; set; }
}

/// <summary>Response model for COMET metric: https://huggingface.co/Unbabel/wmt22-cometkiwi-da.</summary>
internal sealed class CometTranslationEvaluationResponse
{
    [JsonPropertyName("scores")]
    public List<double> Scores { get; set; }

    [JsonPropertyName("system_score")]
    public double SystemScore { get; set; }
}
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;

namespace QualityCheckWithFilters.Models;

/// <summary>
/// Internal representation of evaluation score type to configure and run examples.
/// </summary>
internal readonly struct EvaluationScoreType(string endpoint) : IEquatable<EvaluationScoreType>
{
    public string Endpoint { get; } = endpoint;

    public static EvaluationScoreType BERT = new("bert-score");
    public static EvaluationScoreType BLEU = new("bleu-score");
    public static EvaluationScoreType METEOR = new("meteor-score");
    public static EvaluationScoreType COMET = new("comet-score");

    public static bool operator ==(EvaluationScoreType left, EvaluationScoreType right) => left.Equals(right);
    public static bool operator !=(EvaluationScoreType left, EvaluationScoreType right) => !(left == right);

    /// <inheritdoc/>
    public override bool Equals([NotNullWhen(true)] object? obj) => obj is EvaluationScoreType other && this == other;

    /// <inheritdoc/>
    public bool Equals(EvaluationScoreType other) => string.Equals(this.Endpoint, other.Endpoint, StringComparison.OrdinalIgnoreCase);

    /// <inheritdoc/>
    public override int GetHashCode() => StringComparer.OrdinalIgnoreCase.GetHashCode(this.Endpoint ?? string.Empty);

    /// <inheritdoc/>
    public override string ToString() => this.Endpoint ?? string.Empty;
}