Skip to content

Commit

Permalink
.Net: Fix filters cloning when registered via Kernel properties (#6241)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

Based on: #6240

Since filters are cloned when they are registered through DI container,
in the same way they should be cloned when registered through Kernel
properties (i.e. `kernel.FunctionInvocationFilters`).

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
dmytrostruk committed May 14, 2024
1 parent 1692207 commit af207dc
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 6 deletions.
3 changes: 3 additions & 0 deletions dotnet/src/SemanticKernel.Abstractions/Kernel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ public sealed class Kernel
FunctionInvoked = this.FunctionInvoked,
PromptRendering = this.PromptRendering,
PromptRendered = this.PromptRendered,
_functionInvocationFilters = this._functionInvocationFilters is { Count: > 0 } ? new NonNullCollection<IFunctionInvocationFilter>(this._functionInvocationFilters) : null,
_promptRenderFilters = this._promptRenderFilters is { Count: > 0 } ? new NonNullCollection<IPromptRenderFilter>(this._promptRenderFilters) : null,
_autoFunctionInvocationFilters = this._autoFunctionInvocationFilters is { Count: > 0 } ? new NonNullCollection<IAutoFunctionInvocationFilter>(this._autoFunctionInvocationFilters) : null,
_data = this._data is { Count: > 0 } ? new Dictionary<string, object?>(this._data) : null,
_culture = this._culture,
};
Expand Down
15 changes: 9 additions & 6 deletions dotnet/src/SemanticKernel.UnitTests/Filters/FilterBaseTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,21 @@ protected Mock<ITextGenerationService> GetMockTextGeneration(string? textResult
protected sealed class FakeFunctionFilter(
Func<FunctionInvocationContext, Func<FunctionInvocationContext, Task>, Task>? onFunctionInvocation) : IFunctionInvocationFilter
{
private readonly Func<FunctionInvocationContext, Func<FunctionInvocationContext, Task>, Task>? _onFunctionInvocation = onFunctionInvocation;

public Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next) =>
this._onFunctionInvocation?.Invoke(context, next) ?? Task.CompletedTask;
onFunctionInvocation?.Invoke(context, next) ?? Task.CompletedTask;
}

protected sealed class FakePromptFilter(
Func<PromptRenderContext, Func<PromptRenderContext, Task>, Task>? onPromptRender) : IPromptRenderFilter
{
private readonly Func<PromptRenderContext, Func<PromptRenderContext, Task>, Task>? _onPromptRender = onPromptRender;

public Task OnPromptRenderAsync(PromptRenderContext context, Func<PromptRenderContext, Task> next) =>
this._onPromptRender?.Invoke(context, next) ?? Task.CompletedTask;
onPromptRender?.Invoke(context, next) ?? Task.CompletedTask;
}

protected sealed class FakeAutoFunctionFilter(
Func<AutoFunctionInvocationContext, Func<AutoFunctionInvocationContext, Task>, Task>? onAutoFunctionInvocation) : IAutoFunctionInvocationFilter
{
public Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func<AutoFunctionInvocationContext, Task> next) =>
onAutoFunctionInvocation?.Invoke(context, next) ?? Task.CompletedTask;
}
}
68 changes: 68 additions & 0 deletions dotnet/src/SemanticKernel.UnitTests/Filters/KernelFilterTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Xunit;

namespace SemanticKernel.UnitTests.Filters;

public class KernelFilterTests : FilterBaseTest
{
[Fact]
public void FiltersAreClonedWhenRegisteredWithDI()
{
// Arrange
var functionFilter = new FakeFunctionFilter(onFunctionInvocation: async (context, next) => { await next(context); });
var promptFilter = new FakePromptFilter(onPromptRender: async (context, next) => { await next(context); });
var autoFunctionFilter = new FakeAutoFunctionFilter(onAutoFunctionInvocation: async (context, next) => { await next(context); });

var builder = Kernel.CreateBuilder();

builder.Services.AddSingleton<IFunctionInvocationFilter>(functionFilter);
builder.Services.AddSingleton<IPromptRenderFilter>(promptFilter);
builder.Services.AddSingleton<IAutoFunctionInvocationFilter>(autoFunctionFilter);

var kernel = builder.Build();

// Act
var clonedKernel = kernel.Clone();

// Assert
Assert.Single(kernel.FunctionInvocationFilters);
Assert.Single(kernel.PromptRenderFilters);
Assert.Single(kernel.AutoFunctionInvocationFilters);

Assert.Single(clonedKernel.FunctionInvocationFilters);
Assert.Single(clonedKernel.PromptRenderFilters);
Assert.Single(clonedKernel.AutoFunctionInvocationFilters);
}

[Fact]
public void FiltersAreClonedWhenRegisteredWithKernelProperties()
{
// Arrange
var functionFilter = new FakeFunctionFilter(onFunctionInvocation: async (context, next) => { await next(context); });
var promptFilter = new FakePromptFilter(onPromptRender: async (context, next) => { await next(context); });
var autoFunctionFilter = new FakeAutoFunctionFilter(onAutoFunctionInvocation: async (context, next) => { await next(context); });

var builder = Kernel.CreateBuilder();

var kernel = builder.Build();

kernel.FunctionInvocationFilters.Add(functionFilter);
kernel.PromptRenderFilters.Add(promptFilter);
kernel.AutoFunctionInvocationFilters.Add(autoFunctionFilter);

// Act
var clonedKernel = kernel.Clone();

// Assert
Assert.Single(kernel.FunctionInvocationFilters);
Assert.Single(kernel.PromptRenderFilters);
Assert.Single(kernel.AutoFunctionInvocationFilters);

Assert.Single(clonedKernel.FunctionInvocationFilters);
Assert.Single(clonedKernel.PromptRenderFilters);
Assert.Single(clonedKernel.AutoFunctionInvocationFilters);
}
}

0 comments on commit af207dc

Please sign in to comment.