Skip to content

Commit

Permalink
.Net: Enhance Service Selector to use ExecutionSettings from InvokeAsync Calls (#4141)
Browse files Browse the repository at this point in the history

### Motivation and Context

Closes #4127

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
markwallace-microsoft committed Dec 11, 2023
1 parent ae64b16 commit a9c958b
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Microsoft.Extensions.DependencyInjection;
Expand All @@ -20,7 +21,10 @@ internal sealed class OrderedAIServiceSelector : IAIServiceSelector
[NotNullWhen(true)] out T? service,
out PromptExecutionSettings? serviceSettings) where T : class, IAIService
{
var executionSettings = function.ExecutionSettings;
// Allow the execution settings from the kernel arguments to take precedence
var executionSettings = arguments.ExecutionSettings is not null
? new List<PromptExecutionSettings> { arguments.ExecutionSettings }
: function.ExecutionSettings;
if (executionSettings is null || executionSettings.Count == 0)
{
service = GetAnyService(kernel);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,158 @@ await foreach (var chunk in kernel.InvokeStreamingAsync<StreamingChatMessageCont
}
}

[Fact]
public async Task InvokeAsyncUsesPromptExecutionSettingsAsync()
{
    // Arrange: a single text generation service that always yields "Result".
    Mock<ITextGenerationService> completionMock = new();
    completionMock
        .Setup(m => m.GetTextContentsAsync(It.IsAny<string>(), It.IsAny<PromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()))
        .ReturnsAsync(new List<TextContent> { new TextContent("Result") });

    KernelBuilder builder = new();
    builder.Services.AddTransient<ITextGenerationService>((sp) => completionMock.Object);
    Kernel kernel = builder.Build();

    // The execution settings travel with the function itself.
    KernelFunction function = KernelFunctionFactory.CreateFromPrompt("Anything", new OpenAIPromptExecutionSettings { MaxTokens = 1000 });

    // Act
    var result = await kernel.InvokeAsync(function);

    // Assert: the function-level settings (MaxTokens = 1000) reached the service.
    Assert.Equal("Result", result.GetValue<string>());
    completionMock.Verify(m => m.GetTextContentsAsync("Anything", It.Is<OpenAIPromptExecutionSettings>(settings => settings.MaxTokens == 1000), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()), Times.Once());
}

[Fact]
public async Task InvokeAsyncUsesKernelArgumentsExecutionSettingsAsync()
{
    // Arrange: one mocked text generation service returning a canned value.
    Mock<ITextGenerationService> completionMock = new();
    completionMock
        .Setup(m => m.GetTextContentsAsync(It.IsAny<string>(), It.IsAny<PromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()))
        .ReturnsAsync(new List<TextContent> { new TextContent("Result") });

    KernelBuilder builder = new();
    builder.Services.AddTransient<ITextGenerationService>((sp) => completionMock.Object);
    Kernel kernel = builder.Build();

    // Function carries MaxTokens = 1000; the arguments below carry 2000.
    KernelFunction function = KernelFunctionFactory.CreateFromPrompt("Anything", new OpenAIPromptExecutionSettings { MaxTokens = 1000 });

    // Act: settings supplied via KernelArguments should win over the function's own.
    var result = await kernel.InvokeAsync(function, new KernelArguments(new OpenAIPromptExecutionSettings { MaxTokens = 2000 }));

    // Assert: the service saw the argument-level settings (MaxTokens = 2000).
    Assert.Equal("Result", result.GetValue<string>());
    completionMock.Verify(m => m.GetTextContentsAsync("Anything", It.Is<OpenAIPromptExecutionSettings>(settings => settings.MaxTokens == 2000), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()), Times.Once());
}

[Fact]
public async Task InvokeAsyncWithServiceIdUsesKernelArgumentsExecutionSettingsAsync()
{
    // Arrange: a keyed text generation service registered under "service1".
    Mock<ITextGenerationService> completionMock = new();
    completionMock
        .Setup(m => m.GetTextContentsAsync(It.IsAny<string>(), It.IsAny<PromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()))
        .ReturnsAsync(new List<TextContent> { new TextContent("Result") });

    KernelBuilder builder = new();
    builder.Services.AddKeyedSingleton<ITextGenerationService>("service1", completionMock.Object);
    Kernel kernel = builder.Build();

    // Function targets "service1" with MaxTokens = 1000.
    KernelFunction function = KernelFunctionFactory.CreateFromPrompt("Anything", new OpenAIPromptExecutionSettings { ServiceId = "service1", MaxTokens = 1000 });

    // Act: argument-level settings (no service id, MaxTokens = 2000) take precedence.
    var result = await kernel.InvokeAsync(function, new KernelArguments(new OpenAIPromptExecutionSettings { MaxTokens = 2000 }));

    // Assert: the keyed service received the overriding MaxTokens = 2000.
    Assert.Equal("Result", result.GetValue<string>());
    completionMock.Verify(m => m.GetTextContentsAsync("Anything", It.Is<OpenAIPromptExecutionSettings>(settings => settings.MaxTokens == 2000), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()), Times.Once());
}

[Fact]
public async Task InvokeAsyncWithMultipleServicesUsesKernelArgumentsExecutionSettingsAsync()
{
    // Arrange: two keyed services, each answering with a distinct result string.
    Mock<ITextGenerationService> serviceMock1 = new();
    serviceMock1
        .Setup(m => m.GetTextContentsAsync(It.IsAny<string>(), It.IsAny<PromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()))
        .ReturnsAsync(new List<TextContent> { new TextContent("Result1") });
    Mock<ITextGenerationService> serviceMock2 = new();
    serviceMock2
        .Setup(m => m.GetTextContentsAsync(It.IsAny<string>(), It.IsAny<PromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()))
        .ReturnsAsync(new List<TextContent> { new TextContent("Result2") });

    KernelBuilder builder = new();
    builder.Services.AddKeyedSingleton<ITextGenerationService>("service1", serviceMock1.Object);
    builder.Services.AddKeyedSingleton<ITextGenerationService>("service2", serviceMock2.Object);
    Kernel kernel = builder.Build();

    // Each function pins its own service id and token budget.
    KernelFunction function1 = KernelFunctionFactory.CreateFromPrompt("Prompt1", new OpenAIPromptExecutionSettings { ServiceId = "service1", MaxTokens = 1000 });
    KernelFunction function2 = KernelFunctionFactory.CreateFromPrompt("Prompt2", new OpenAIPromptExecutionSettings { ServiceId = "service2", MaxTokens = 2000 });

    // Act
    var result1 = await kernel.InvokeAsync(function1);
    var result2 = await kernel.InvokeAsync(function2);

    // Assert: each invocation was routed to the service named in its function settings.
    Assert.Equal("Result1", result1.GetValue<string>());
    serviceMock1.Verify(m => m.GetTextContentsAsync("Prompt1", It.Is<OpenAIPromptExecutionSettings>(settings => settings.MaxTokens == 1000), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()), Times.Once());
    Assert.Equal("Result2", result2.GetValue<string>());
    serviceMock2.Verify(m => m.GetTextContentsAsync("Prompt2", It.Is<OpenAIPromptExecutionSettings>(settings => settings.MaxTokens == 2000), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()), Times.Once());
}

[Fact]
public async Task InvokeAsyncWithMultipleServicesUsesServiceFromKernelArgumentsExecutionSettingsAsync()
{
    // Arrange: two keyed services with distinguishable responses.
    Mock<ITextGenerationService> serviceMock1 = new();
    serviceMock1
        .Setup(m => m.GetTextContentsAsync(It.IsAny<string>(), It.IsAny<PromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()))
        .ReturnsAsync(new List<TextContent> { new TextContent("Result1") });
    Mock<ITextGenerationService> serviceMock2 = new();
    serviceMock2
        .Setup(m => m.GetTextContentsAsync(It.IsAny<string>(), It.IsAny<PromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()))
        .ReturnsAsync(new List<TextContent> { new TextContent("Result2") });

    KernelBuilder builder = new();
    builder.Services.AddKeyedSingleton<ITextGenerationService>("service1", serviceMock1.Object);
    builder.Services.AddKeyedSingleton<ITextGenerationService>("service2", serviceMock2.Object);
    Kernel kernel = builder.Build();

    // The function carries no settings of its own; selection comes from the arguments.
    KernelFunction function = KernelFunctionFactory.CreateFromPrompt("Prompt");

    // Act: each call names a different service via KernelArguments.
    var result1 = await kernel.InvokeAsync(function, new(new OpenAIPromptExecutionSettings { ServiceId = "service1", MaxTokens = 1000 }));
    var result2 = await kernel.InvokeAsync(function, new(new OpenAIPromptExecutionSettings { ServiceId = "service2", MaxTokens = 2000 }));

    // Assert: each call reached the service named in its argument-level settings.
    Assert.Equal("Result1", result1.GetValue<string>());
    serviceMock1.Verify(m => m.GetTextContentsAsync("Prompt", It.Is<OpenAIPromptExecutionSettings>(settings => settings.MaxTokens == 1000), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()), Times.Once());
    Assert.Equal("Result2", result2.GetValue<string>());
    serviceMock2.Verify(m => m.GetTextContentsAsync("Prompt", It.Is<OpenAIPromptExecutionSettings>(settings => settings.MaxTokens == 2000), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()), Times.Once());
}

[Fact]
public async Task InvokeAsyncWithMultipleServicesUsesKernelArgumentsExecutionSettingsOverrideAsync()
{
    // Arrange: two keyed services with distinguishable responses.
    Mock<ITextGenerationService> serviceMock1 = new();
    serviceMock1
        .Setup(m => m.GetTextContentsAsync(It.IsAny<string>(), It.IsAny<PromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()))
        .ReturnsAsync(new List<TextContent> { new TextContent("Result1") });
    Mock<ITextGenerationService> serviceMock2 = new();
    serviceMock2
        .Setup(m => m.GetTextContentsAsync(It.IsAny<string>(), It.IsAny<PromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()))
        .ReturnsAsync(new List<TextContent> { new TextContent("Result2") });

    KernelBuilder builder = new();
    builder.Services.AddKeyedSingleton<ITextGenerationService>("service1", serviceMock1.Object);
    builder.Services.AddKeyedSingleton<ITextGenerationService>("service2", serviceMock2.Object);
    Kernel kernel = builder.Build();

    // Functions target one service each; the arguments below deliberately target the other.
    KernelFunction function1 = KernelFunctionFactory.CreateFromPrompt("Prompt1", new OpenAIPromptExecutionSettings { ServiceId = "service1", MaxTokens = 1000 });
    KernelFunction function2 = KernelFunctionFactory.CreateFromPrompt("Prompt2", new OpenAIPromptExecutionSettings { ServiceId = "service2", MaxTokens = 2000 });

    // Act: argument-level settings override the function-level service selection.
    var result1 = await kernel.InvokeAsync(function1, new(new OpenAIPromptExecutionSettings { ServiceId = "service2", MaxTokens = 2000 }));
    var result2 = await kernel.InvokeAsync(function2, new(new OpenAIPromptExecutionSettings { ServiceId = "service1", MaxTokens = 1000 }));

    // Assert: the services are swapped relative to the functions' own settings.
    Assert.Equal("Result2", result1.GetValue<string>());
    serviceMock2.Verify(m => m.GetTextContentsAsync("Prompt1", It.Is<OpenAIPromptExecutionSettings>(settings => settings.MaxTokens == 2000), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()), Times.Once());
    Assert.Equal("Result1", result2.GetValue<string>());
    serviceMock1.Verify(m => m.GetTextContentsAsync("Prompt2", It.Is<OpenAIPromptExecutionSettings>(settings => settings.MaxTokens == 1000), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()), Times.Once());
}

#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
#pragma warning disable IDE1006 // Naming Styles
private async IAsyncEnumerable<T> ToAsyncEnumerable<T>(IEnumerable<T> enumeration)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

namespace SemanticKernel.UnitTests.Functions;

public class OrderedAIServiceConfigurationProviderTests
public class OrderedAIServiceSelectorTests
{
[Fact]
public void ItThrowsAnSKExceptionForNoServices()
Expand Down

0 comments on commit a9c958b

Please sign in to comment.