Skip to content

Commit

Permalink
.Net: make planner constructors more similar (#2849)
Browse files Browse the repository at this point in the history
### Motivation and Context

Fixes #2847,  #2519

Contains breaking changes.

This addresses inconsistencies in planner construction. The planners
sometimes took a prompt override through a delegate in the config and
sometimes as a constructor argument. Additionally, the ActionPlanner
used to take a loggerFactory as a constructor argument while the other
planners did not.

The goal of these changes is to make working with planners more
predictable.

### Description

* Added `GetPromptTemplate` delegate function to the planner base
config. All planners will use this to get the prompt override.
* Removed logger factory argument from ActionPlanner constructor. Used
the Kernel's logger factory to create a new logger within each planner
constructor.
* Renamed private `Config` private instance field to `_config` 

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄

---------

Co-authored-by: Shawn Callegari <36091529+shawncal@users.noreply.github.com>
  • Loading branch information
hario90 and shawncal committed Sep 21, 2023
1 parent 59f3346 commit fd391fa
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.Diagnostics;
Expand Down Expand Up @@ -52,6 +53,25 @@ public async Task InvalidJsonThrowsAsync()
await Assert.ThrowsAsync<SKException>(() => planner.CreatePlanAsync("goal"));
}

[Fact]
public void UsesPromptDelegateWhenProvided()
{
// Arrange
var kernel = new Mock<IKernel>();
kernel.Setup(x => x.LoggerFactory).Returns(NullLoggerFactory.Instance);
var getPromptTemplateMock = new Mock<Func<string>>();
var config = new ActionPlannerConfig()
{
GetPromptTemplate = getPromptTemplateMock.Object
};

// Act
var planner = new Microsoft.SemanticKernel.Planning.ActionPlanner(kernel.Object, config);

// Assert
getPromptTemplateMock.Verify(x => x(), Times.Once());
}

[Fact]
public async Task MalformedJsonThrowsAsync()
{
Expand Down Expand Up @@ -170,6 +190,7 @@ private Mock<IKernel> CreateMockKernelAndFunctionFlowWithTestString(string testP
// Mock Functions
kernel.Setup(x => x.Functions).Returns(functions.Object);
kernel.Setup(x => x.CreateNewContext()).Returns(context);
kernel.Setup(x => x.LoggerFactory).Returns(NullLoggerFactory.Instance);

kernel.Setup(x => x.RegisterSemanticFunction(
It.IsAny<string>(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
Expand All @@ -9,6 +10,7 @@
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Planning.Sequential;
using Microsoft.SemanticKernel.SemanticFunctions;
using Moq;
using Xunit;
Expand Down Expand Up @@ -195,6 +197,24 @@ public async Task InvalidXMLThrowsAsync()
await Assert.ThrowsAsync<SKException>(async () => await planner.CreatePlanAsync("goal"));
}

[Fact]
public void UsesPromptDelegateWhenProvided()
{
// Arrange
var kernel = new Mock<IKernel>();
var getPromptTemplateMock = new Mock<Func<string>>();
var config = new SequentialPlannerConfig()
{
GetPromptTemplate = getPromptTemplateMock.Object
};

// Act
var planner = new Microsoft.SemanticKernel.Planning.SequentialPlanner(kernel.Object, config);

// Assert
getPromptTemplateMock.Verify(x => x(), Times.Once());
}

// Method to create Mock<ISKFunction> objects
private static Mock<ISKFunction> CreateMockFunction(FunctionView functionView)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Planning.Stepwise;
using Moq;
using Xunit;

namespace SemanticKernel.Extensions.UnitTests.Planning.StepwisePlanner;

public sealed class StepwisePlannerTests
{
[Fact]
public void UsesPromptDelegateWhenProvided()
{
// Arrange
var kernel = new Mock<IKernel>();
kernel.Setup(x => x.LoggerFactory).Returns(NullLoggerFactory.Instance);
var getPromptTemplateMock = new Mock<Func<string>>();
var config = new StepwisePlannerConfig()
{
GetPromptTemplate = getPromptTemplateMock.Object
};

// Act
var planner = new Microsoft.SemanticKernel.Planning.StepwisePlanner(kernel.Object, config);

// Assert
getPromptTemplateMock.Verify(x => x(), Times.Once());
}
}
27 changes: 11 additions & 16 deletions dotnet/src/Extensions/Planning.ActionPlanner/ActionPlanner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Orchestration;
Expand Down Expand Up @@ -55,19 +54,18 @@ public sealed class ActionPlanner : IActionPlanner
/// </summary>
/// <param name="kernel">The semantic kernel instance.</param>
/// <param name="config">The planner configuration.</param>
/// <param name="prompt">Optional prompt override</param>
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use for logging. If null, no logging will be performed.</param>
public ActionPlanner(
IKernel kernel,
ActionPlannerConfig? config = null,
string? prompt = null,
ILoggerFactory? loggerFactory = null)
ActionPlannerConfig? config = null)
{
Verify.NotNull(kernel);
this._kernel = kernel;

this._logger = loggerFactory is not null ? loggerFactory.CreateLogger(typeof(ActionPlanner)) : NullLogger.Instance;
// Set up Config with default values and excluded skills
this.Config = config ?? new();
this.Config.ExcludedPlugins.Add(PluginName);

string promptTemplate = prompt ?? EmbeddedResource.Read("skprompt.txt");
string promptTemplate = this.Config.GetPromptTemplate?.Invoke() ?? EmbeddedResource.Read("skprompt.txt");

this._plannerFunction = kernel.CreateSemanticFunction(
pluginName: PluginName,
Expand All @@ -83,12 +81,9 @@ public sealed class ActionPlanner : IActionPlanner

kernel.ImportPlugin(this, pluginName: PluginName);

this._kernel = kernel;
// Create context and logger
this._context = kernel.CreateNewContext();

// Set up Config with default values and excluded plugins
this._config = config ?? new();
this._config.ExcludedPlugins.Add(PluginName);
this._logger = this._kernel.LoggerFactory.CreateLogger(this.GetType());
}

/// <inheritdoc />
Expand Down Expand Up @@ -248,7 +243,7 @@ No parameters.
/// <summary>
/// The configuration for the ActionPlanner
/// </summary>
private ActionPlannerConfig _config { get; }
private ActionPlannerConfig Config { get; }

/// <summary>
/// Native function that filters out good JSON from planner result in case additional text is present
Expand Down Expand Up @@ -321,8 +316,8 @@ private IOrderedEnumerable<FunctionView> GetAvailableFunctions(SKContext context
{
Verify.NotNull(context.Functions);

var excludedPlugins = this._config.ExcludedPlugins ?? new();
var excludedFunctions = this._config.ExcludedFunctions ?? new();
var excludedPlugins = this.Config.ExcludedPlugins ?? new();
var excludedFunctions = this.Config.ExcludedFunctions ?? new();

var availableFunctions = context.Functions.GetFunctionViews()
.Where(s => !excludedPlugins.Contains(s.PluginName, StringComparer.CurrentCultureIgnoreCase)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,18 @@ public sealed class SequentialPlanner : ISequentialPlanner
/// </summary>
/// <param name="kernel">The semantic kernel instance.</param>
/// <param name="config">The planner configuration.</param>
/// <param name="prompt">Optional prompt override</param>
public SequentialPlanner(
IKernel kernel,
SequentialPlannerConfig? config = null,
string? prompt = null)
SequentialPlannerConfig? config = null)
{
Verify.NotNull(kernel);
this.Config = config ?? new();

// Set up config with default value and excluded skills
this.Config = config ?? new();
this.Config.ExcludedPlugins.Add(RestrictedPluginName);

string promptTemplate = prompt ?? EmbeddedResource.Read("skprompt.txt");
// Set up prompt template
string promptTemplate = this.Config.GetPromptTemplate?.Invoke() ?? EmbeddedResource.Read("skprompt.txt");

this._functionFlowFunction = kernel.CreateSemanticFunction(
promptTemplate: promptTemplate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,6 @@ public sealed class StepwisePlannerConfig : PlannerConfigBase
/// </summary>
public int MinIterationTimeMs { get; set; } = 0;

/// <summary>
/// Delegate to get the prompt template string.
/// </summary>
public Func<string>? GetPromptTemplate { get; set; } = null;

/// <summary>
/// The configuration to use for the prompt template.
/// </summary>
Expand Down
6 changes: 6 additions & 0 deletions dotnet/src/SemanticKernel/Planning/PlannerConfigBase.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;

namespace Microsoft.SemanticKernel.Planning;
Expand All @@ -18,4 +19,9 @@ public abstract class PlannerConfigBase
/// A list of functions to exclude from the plan creation request.
/// </summary>
public HashSet<string> ExcludedFunctions { get; } = new();

/// <summary>
/// Delegate to get the prompt template string.
/// </summary>
public Func<string>? GetPromptTemplate { get; set; } = null;
}

0 comments on commit fd391fa

Please sign in to comment.