Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom context in plan execution #826

Merged
merged 7 commits into from
May 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ public ContextVariables Clone()
return clone;
}

#region private ================================================================================
internal const string MainKey = "INPUT";

private const string MainKey = "INPUT";
#region private ================================================================================

// Important: names are case insensitive
private readonly ConcurrentDictionary<string, string> _variables = new(StringComparer.OrdinalIgnoreCase);
Expand Down
179 changes: 179 additions & 0 deletions dotnet/src/SemanticKernel.UnitTests/Planning/PlanTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,185 @@ public async Task CanExecutePlanWithOneStepAndStateAsync()
mockFunction.Verify(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default), Times.Once);
}

[Fact]
public async Task CanExecutePlanWithStateAsync()
{
// Arrange
var kernel = new Mock<IKernel>();
var log = new Mock<ILogger>();
var memory = new Mock<ISemanticTextMemory>();
var skills = new Mock<ISkillCollection>();

var returnContext = new SKContext(
new ContextVariables(),
memory.Object,
skills.Object,
log.Object
);

var mockFunction = new Mock<ISKFunction>();
mockFunction.Setup(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default))
.Callback<SKContext, CompleteRequestSettings, ILogger, CancellationToken?>((c, s, l, ct) =>
{
c.Variables.Get("type", out var t);
returnContext.Variables.Update($"Here is a {t} about " + c.Variables.Input);
})
.Returns(() => Task.FromResult(returnContext));

var planStep = new Plan(mockFunction.Object);
planStep.Parameters.Set("type", string.Empty);
var plan = new Plan(string.Empty);
plan.AddSteps(planStep);
plan.State.Set("input", "Cleopatra");
plan.State.Set("type", "poem");

// Act
var result = await plan.InvokeAsync();

// Assert
Assert.NotNull(result);
Assert.Equal($"Here is a poem about Cleopatra", result.Result);
mockFunction.Verify(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default), Times.Once);
}

[Fact]
public async Task CanExecutePlanWithCustomContextAsync()
{
// Arrange
var kernel = new Mock<IKernel>();
var log = new Mock<ILogger>();
var memory = new Mock<ISemanticTextMemory>();
var skills = new Mock<ISkillCollection>();

var returnContext = new SKContext(
new ContextVariables(),
memory.Object,
skills.Object,
log.Object
);

var mockFunction = new Mock<ISKFunction>();
mockFunction.Setup(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default))
.Callback<SKContext, CompleteRequestSettings, ILogger, CancellationToken?>((c, s, l, ct) =>
{
c.Variables.Get("type", out var t);
returnContext.Variables.Update($"Here is a {t} about " + c.Variables.Input);
})
.Returns(() => Task.FromResult(returnContext));

var plan = new Plan(mockFunction.Object);
plan.State.Set("input", "Cleopatra");
plan.State.Set("type", "poem");

// Act
var result = await plan.InvokeAsync();

// Assert
Assert.NotNull(result);
Assert.Equal($"Here is a poem about Cleopatra", result.Result);
mockFunction.Verify(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default), Times.Once);

plan = new Plan(mockFunction.Object);
plan.State.Set("input", "Cleopatra");
plan.State.Set("type", "poem");

var contextOverride = new SKContext(
new ContextVariables(),
memory.Object,
skills.Object,
log.Object
);
contextOverride.Variables.Set("type", "joke");
contextOverride.Variables.Update("Medusa");

// Act
result = await plan.InvokeAsync(contextOverride);

// Assert
Assert.NotNull(result);
Assert.Equal($"Here is a joke about Medusa", result.Result);
mockFunction.Verify(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default), Times.Exactly(2));
}

[Fact]
public async Task CanExecutePlanWithCustomStateAsync()
{
// Arrange
var kernel = new Mock<IKernel>();
var log = new Mock<ILogger>();
var memory = new Mock<ISemanticTextMemory>();
var skills = new Mock<ISkillCollection>();

var returnContext = new SKContext(
new ContextVariables(),
memory.Object,
skills.Object,
log.Object
);

var mockFunction = new Mock<ISKFunction>();
mockFunction.Setup(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default))
.Callback<SKContext, CompleteRequestSettings, ILogger, CancellationToken?>((c, s, l, ct) =>
{
c.Variables.Get("type", out var t);
returnContext.Variables.Update($"Here is a {t} about " + c.Variables.Input);
})
.Returns(() => Task.FromResult(returnContext));

var planStep = new Plan(mockFunction.Object);
planStep.Parameters.Set("type", string.Empty);
var plan = new Plan("A plan");
plan.State.Set("input", "Medusa");
plan.State.Set("type", "joke");
plan.AddSteps(planStep);

// Act
var result = await plan.InvokeAsync();

// Assert
Assert.NotNull(result);
Assert.Equal($"Here is a joke about Medusa", result.Result);
mockFunction.Verify(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default), Times.Once);

planStep = new Plan(mockFunction.Object);
plan = new Plan("A plan");
planStep.Parameters.Set("input", "Medusa");
planStep.Parameters.Set("type", "joke");
plan.State.Set("input", "Cleopatra"); // state input will not override parameter
plan.State.Set("type", "poem");
plan.AddSteps(planStep);

// Act
result = await plan.InvokeAsync();

// Assert
Assert.NotNull(result);
Assert.Equal($"Here is a poem about Medusa", result.Result);
mockFunction.Verify(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default), Times.Exactly(2));

planStep = new Plan(mockFunction.Object);
plan = new Plan("A plan");
planStep.Parameters.Set("input", "Cleopatra");
planStep.Parameters.Set("type", "poem");
plan.AddSteps(planStep);
var contextOverride = new SKContext(
new ContextVariables(),
memory.Object,
skills.Object,
log.Object
);
contextOverride.Variables.Set("type", "joke");
contextOverride.Variables.Update("Medusa"); // context input will not override parameters

// Act
result = await plan.InvokeAsync(contextOverride);

// Assert
Assert.NotNull(result);
Assert.Equal($"Here is a joke about Cleopatra", result.Result);
mockFunction.Verify(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default), Times.Exactly(3));
}

[Fact]
public async Task CanExecutePlanWithJoinedResultAsync()
{
Expand Down
71 changes: 54 additions & 17 deletions dotnet/src/SemanticKernel/Planning/Plan.cs
Original file line number Diff line number Diff line change
Expand Up @@ -470,46 +470,83 @@ private SKContext UpdateContextWithOutputs(SKContext context)
/// <returns>The context variables for the next step in the plan.</returns>
private ContextVariables GetNextStepVariables(ContextVariables variables, Plan step)
{
// If the current step is passing to another plan, we set the default input to an empty string.
// Otherwise, we use the description from the current plan as the default input.
// We then set the input to the value from the SKContext, or the input from the Plan.State, or the default input.
var defaultInput = step.Steps.Count > 0 ? string.Empty : this.Description ?? string.Empty;
var planInput = string.IsNullOrEmpty(variables.Input) ? this.State.Input : variables.Input;
var stepInput = string.IsNullOrEmpty(planInput) ? defaultInput : planInput;
var stepVariables = new ContextVariables(stepInput);
// Priority for Input
// - Parameters (expand from variables if needed)
// - SKContext.Variables
// - Plan.State
// - Empty if sending to another plan
// - Plan.Description

var input = string.Empty;
if (!string.IsNullOrEmpty(step.Parameters.Input))
{
input = this.ExpandFromVariables(variables, step.Parameters.Input);
}
else if (!string.IsNullOrEmpty(variables.Input))
{
input = variables.Input;
}
else if (!string.IsNullOrEmpty(this.State.Input))
{
input = this.State.Input;
}
else if (step.Steps.Count > 0)
{
input = string.Empty;
}
else if (!string.IsNullOrEmpty(this.Description))
{
input = this.Description;
}

var stepVariables = new ContextVariables(input);

// Priority for remaining stepVariables is:
// - Parameters (pull from State by a key value)
// - Parameters (from context)
// - Parameters (from State)
// - Function Parameters (pull from variables or state by a key value)
// - Step Parameters (pull from variables or state by a key value)
var functionParameters = step.Describe();
foreach (var param in functionParameters.Parameters)
{
if (variables.Get(param.Name, out var value) && !string.IsNullOrEmpty(value))
if (param.Name.Equals(ContextVariables.MainKey, StringComparison.OrdinalIgnoreCase))
{
continue;
}

if (variables.Get(param.Name, out var value))
{
stepVariables.Set(param.Name, value);
}
else if (this.State.Get(param.Name, out value) && !string.IsNullOrEmpty(value))
else if (this.State.Get(param.Name, out value))
{
stepVariables.Set(param.Name, value);
}
}

foreach (var item in step.Parameters)
{
if (!string.IsNullOrEmpty(item.Value))
// Don't overwrite variable values that are already set
if (stepVariables.Get(item.Key, out _))
{
var value = this.ExpandFromVariables(variables, item.Value);
stepVariables.Set(item.Key, value);
continue;
}

var expandedValue = this.ExpandFromVariables(variables, item.Value);
if (!expandedValue.Equals(item.Value, StringComparison.OrdinalIgnoreCase))
{
stepVariables.Set(item.Key, expandedValue);
}
else if (variables.Get(item.Key, out var value) && !string.IsNullOrEmpty(value))
else if (variables.Get(item.Key, out var value))
{
stepVariables.Set(item.Key, value);
}
else if (this.State.Get(item.Key, out value) && !string.IsNullOrEmpty(value))
else if (this.State.Get(item.Key, out value))
{
stepVariables.Set(item.Key, value);
}
else
{
stepVariables.Set(item.Key, expandedValue);
lemillermicrosoft marked this conversation as resolved.
Show resolved Hide resolved
}
}

return stepVariables;
Expand Down