Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net - Add support for instruction templating on Agents #4486

Merged
merged 12 commits into from
Jan 8, 2024
12 changes: 8 additions & 4 deletions dotnet/samples/KernelSyntaxExamples/Example70_Agents.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ private static async Task RunSimpleChatAsync()
await ChatAsync(
"Agents.ParrotAgent.yaml", // Defined under ./Resources/Agents
plugin: null, // No plugin
arguments: new KernelArguments { { "count", 3 } },
"Fortune favors the bold.",
"I came, I saw, I conquered.",
"Practice makes perfect.");
Expand All @@ -77,6 +78,7 @@ private static async Task RunWithMethodFunctionsAsync()
await ChatAsync(
"Agents.ToolAgent.yaml", // Defined under ./Resources/Agents
plugin,
arguments: null,
"Hello",
"What is the special soup?",
"What is the special drink?",
Expand All @@ -95,14 +97,15 @@ private static async Task RunWithPromptFunctionsAsync()
var function = KernelFunctionFactory.CreateFromPrompt(
"Correct any misspelling or gramatical errors provided in input: {{$input}}",
functionName: "spellChecker",
description: "Correct the spelling for the user input."
);
description: "Correct the spelling for the user input.");

var plugin = KernelPluginFactory.CreateFromFunctions("spelling", "Spelling functions", new[] { function });

// Call the common chat-loop
await ChatAsync(
"Agents.ToolAgent.yaml", // Defined under ./Resources/Agents
plugin,
arguments: null,
"Hello",
"Is this spelled correctly: exercize",
"What is the special soup?",
Expand All @@ -126,7 +129,7 @@ private static async Task RunAsFunctionAsync()
try
{
// Invoke agent plugin.
var response = await agent.AsPlugin().InvokeAsync("Practice makes perfect.");
var response = await agent.AsPlugin().InvokeAsync("Practice makes perfect.", new KernelArguments { { "count", 2 } });

// Display result.
Console.WriteLine(response ?? $"No response from agent: {agent.Id}");
Expand All @@ -149,6 +152,7 @@ private static async Task RunAsFunctionAsync()
private static async Task ChatAsync(
string resourcePath,
KernelPlugin? plugin = null,
KernelArguments? arguments = null,
params string[] messages)
{
// Read agent resource
Expand All @@ -170,7 +174,7 @@ private static async Task RunAsFunctionAsync()
Console.WriteLine($"[{agent.Id}]");

// Process each user message and agent response.
foreach (var response in messages.Select(m => thread.InvokeAsync(agent, m)))
foreach (var response in messages.Select(m => thread.InvokeAsync(agent, m, arguments)))
{
await foreach (var message in response)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
name: Parrot
instructions: |
Repeat the user message in the voice of a pirate and then end with a parrot sound.
template_format: semantic-kernel
template: |
Repeat the user message in the voice of a pirate and then end with {{$count}} parrot sounds.
description: A fun chat bot that repeats the user message in the voice of a pirate.
input_variables:
- name: count
description: The number of parrot sounds.
is_required: true
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
name: ToolRunner
instructions: |
template_format: semantic-kernel
template: |
Respond to the user using the single best tool.
If no tool is appropriate, let the user know you only provide responses using tools.
When reporting a tool result, start with, "The tool I used informed me that"
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Experimental/Agents/AgentBuilder.Static.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,6 @@ public partial class AgentBuilder
var restContext = new OpenAIRestContext(apiKey);
var resultModel = await restContext.GetAssistantModelAsync(agentId, cancellationToken).ConfigureAwait(false);

return new Agent(resultModel, restContext, plugins);
return new Agent(resultModel, null, restContext, plugins);
}
}
23 changes: 15 additions & 8 deletions dotnet/src/Experimental/Agents/AgentBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using Microsoft.SemanticKernel.Experimental.Agents.Exceptions;
using Microsoft.SemanticKernel.Experimental.Agents.Internal;
using Microsoft.SemanticKernel.Experimental.Agents.Models;
using YamlDotNet.Serialization;

namespace Microsoft.SemanticKernel.Experimental.Agents;

Expand All @@ -23,6 +22,7 @@ public partial class AgentBuilder

private string? _apiKey;
private Func<HttpClient>? _httpClientProvider;
private PromptTemplateConfig? _config;

/// <summary>
/// Initializes a new instance of the <see cref="AgentBuilder"/> class.
Expand Down Expand Up @@ -54,6 +54,7 @@ public async Task<IAgent> BuildAsync(CancellationToken cancellationToken = defau
await Agent.CreateAsync(
new OpenAIRestContext(this._apiKey!, this._httpClientProvider),
this._model,
this._config,
this._plugins,
cancellationToken).ConfigureAwait(false);
}
Expand All @@ -77,15 +78,21 @@ public AgentBuilder WithOpenAIChatCompletion(string model, string apiKey)
/// <returns><see cref="AgentBuilder"/> instance for fluid expression.</returns>
public AgentBuilder FromTemplate(string template)
{
var deserializer = new DeserializerBuilder().Build();
this._config = KernelFunctionYaml.ToPromptTemplateConfig(template);

var agentKernelModel = deserializer.Deserialize<AgentConfigurationModel>(template);
this.WithInstructions(this._config.Template.Trim());

return
this
.WithInstructions(agentKernelModel.Instructions.Trim())
.WithName(agentKernelModel.Name.Trim())
.WithDescription(agentKernelModel.Description.Trim());
if (!string.IsNullOrWhiteSpace(this._config.Name))
{
this.WithName(this._config.Name?.Trim());
}

if (!string.IsNullOrWhiteSpace(this._config.Description))
{
this.WithDescription(this._config.Description?.Trim());
}

return this;
}

/// <summary>
Expand Down
19 changes: 17 additions & 2 deletions dotnet/src/Experimental/Agents/AgentPlugin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,23 @@ protected AgentPlugin(string name, string? description = null)
/// <returns>The agent response</returns>
public async Task<string> InvokeAsync(string input, CancellationToken cancellationToken = default)
{
var args = new KernelArguments { { "input", input } };
var result = await this.First().InvokeAsync(this.Agent.Kernel, args, cancellationToken).ConfigureAwait(false);
return await this.InvokeAsync(input, arguments: null, cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Invoke plugin with user input
/// </summary>
/// <param name="input">The user input</param>
/// <param name="arguments">The arguments</param>
/// <param name="cancellationToken">A cancel token</param>
/// <returns>The agent response</returns>
public async Task<string> InvokeAsync(string input, KernelArguments? arguments, CancellationToken cancellationToken = default)
{
arguments ??= new KernelArguments();

arguments["input"] = input;

var result = await this.First().InvokeAsync(this.Agent.Kernel, arguments, cancellationToken).ConfigureAwait(false);
var response = result.GetValue<AgentResponse>()!;

return response.Message;
Expand Down
3 changes: 2 additions & 1 deletion dotnet/src/Experimental/Agents/Experimental.Agents.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
</ItemGroup>
<ItemGroup>
<PackageReference Include="System.Linq.Async" />
<PackageReference Include="YamlDotNet" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\Connectors\Connectors.OpenAI\Connectors.OpenAI.csproj" />
<ProjectReference Include="..\..\Extensions\PromptTemplates.Handlebars\PromptTemplates.Handlebars.csproj" />
<ProjectReference Include="..\..\Functions\Functions.Yaml\Functions.Yaml.csproj" />
<ProjectReference Include="..\..\SemanticKernel.Core\SemanticKernel.Core.csproj" />
</ItemGroup>
<ItemGroup>
Expand Down
5 changes: 5 additions & 0 deletions dotnet/src/Experimental/Agents/IAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ public interface IAgent
/// </summary>
public AgentPlugin AsPlugin();

/// <summary>
/// Expose the agent internally as a prompt-template
/// </summary>
internal IPromptTemplate AsPromptTemplate();

/// <summary>
/// Creates a new agent chat thread.
/// </summary>
Expand Down
8 changes: 5 additions & 3 deletions dotnet/src/Experimental/Agents/IAgentExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,19 @@ public static class IAgentExtensions
/// </summary>
/// <param name="agent">the agent</param>
/// <param name="input">the user input</param>
/// <param name="cancellationToken">a cancel token</param>
/// <returns>chat messages</returns>
/// <param name="arguments">Optional arguments for parameterized instructions</param>
/// <param name="cancellationToken">Optional cancellation token</param>
/// <returns>Chat messages</returns>
public static async IAsyncEnumerable<IChatMessage> InvokeAsync(
this IAgent agent,
string input,
KernelArguments? arguments = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
IAgentThread thread = await agent.NewThreadAsync(cancellationToken).ConfigureAwait(false);
try
{
await foreach (var message in thread.InvokeAsync(agent, input, cancellationToken))
await foreach (var message in thread.InvokeAsync(agent, input, arguments, cancellationToken))
{
yield return message;
}
Expand Down
6 changes: 4 additions & 2 deletions dotnet/src/Experimental/Agents/IAgentThread.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,20 @@ public interface IAgentThread
/// Advance the thread with the specified agent.
/// </summary>
/// <param name="agent">An agent instance.</param>
/// <param name="arguments">Optional arguments for parameterized instructions</param>
/// <param name="cancellationToken">A cancellation token</param>
/// <returns>The resulting agent message(s)</returns>
IAsyncEnumerable<IChatMessage> InvokeAsync(IAgent agent, CancellationToken cancellationToken = default);
IAsyncEnumerable<IChatMessage> InvokeAsync(IAgent agent, KernelArguments? arguments = null, CancellationToken cancellationToken = default);

/// <summary>
/// Advance the thread with the specified agent.
/// </summary>
/// <param name="agent">An agent instance.</param>
/// <param name="userMessage">The user message</param>
/// <param name="arguments">Optional arguments for parameterized instructions</param>
/// <param name="cancellationToken">A cancellation token</param>
/// <returns>The resulting agent message(s)</returns>
IAsyncEnumerable<IChatMessage> InvokeAsync(IAgent agent, string userMessage, CancellationToken cancellationToken = default);
IAsyncEnumerable<IChatMessage> InvokeAsync(IAgent agent, string userMessage, KernelArguments? arguments = null, CancellationToken cancellationToken = default);

/// <summary>
/// Delete current thread. Terminal state - Unable to perform any
Expand Down
42 changes: 37 additions & 5 deletions dotnet/src/Experimental/Agents/Internal/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Experimental.Agents.Exceptions;
using Microsoft.SemanticKernel.Experimental.Agents.Models;
using Microsoft.SemanticKernel.PromptTemplates.Handlebars;

namespace Microsoft.SemanticKernel.Experimental.Agents.Internal;

Expand Down Expand Up @@ -50,9 +51,16 @@ internal sealed class Agent : IAgent
public string Instructions => this._model.Instructions;

private static readonly Regex s_removeInvalidCharsRegex = new("[^0-9A-Za-z-]");
private static readonly Dictionary<string, IPromptTemplateFactory> s_templateFactories =
new(StringComparer.OrdinalIgnoreCase)
{
{ PromptTemplateConfig.SemanticKernelTemplateFormat, new KernelPromptTemplateFactory() },
{ HandlebarsPromptTemplateFactory.HandlebarsTemplateFormat, new HandlebarsPromptTemplateFactory() },
};

private readonly OpenAIRestContext _restContext;
private readonly AssistantModel _model;
private readonly IPromptTemplate _promptTemplate;

private AgentPlugin? _agentPlugin;
private bool _isDeleted;
Expand All @@ -62,33 +70,45 @@ internal sealed class Agent : IAgent
/// </summary>
/// <param name="restContext">A context for accessing OpenAI REST endpoint</param>
/// <param name="assistantModel">The assistant definition</param>
/// <param name="config">The template config</param>
/// <param name="plugins">Plugins to initialize as agent tools</param>
/// <param name="cancellationToken">A cancellation token</param>
/// <returns>An initialized <see cref="Agent"> instance.</see></returns>
public static async Task<IAgent> CreateAsync(
OpenAIRestContext restContext,
AssistantModel assistantModel,
PromptTemplateConfig? config,
IEnumerable<KernelPlugin>? plugins = null,
CancellationToken cancellationToken = default)
{
var resultModel = await restContext.CreateAssistantModelAsync(assistantModel, cancellationToken).ConfigureAwait(false);

return new Agent(resultModel, restContext, plugins);
return new Agent(resultModel, config, restContext, plugins);
}

/// <summary>
/// Initializes a new instance of the <see cref="Agent"/> class.
/// </summary>
internal Agent(
AssistantModel assistantModel,
PromptTemplateConfig? config,
OpenAIRestContext restContext,
IEnumerable<KernelPlugin>? plugins = null)
{
config ??=
new PromptTemplateConfig
{
Name = assistantModel.Name,
Description = assistantModel.Description,
Template = assistantModel.Instructions,
};

this._model = assistantModel;
this._restContext = restContext;
this._promptTemplate = this.DefinePromptTemplate(config);

IKernelBuilder builder = Kernel.CreateBuilder();
;

this.Kernel =
Kernel
.CreateBuilder()
Expand All @@ -103,6 +123,8 @@ internal sealed class Agent : IAgent

public AgentPlugin AsPlugin() => this._agentPlugin ??= this.DefinePlugin();

public IPromptTemplate AsPromptTemplate() => this._promptTemplate;

/// <inheritdoc/>
public Task<IAgentThread> NewThreadAsync(CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -146,19 +168,19 @@ public async Task DeleteAsync(CancellationToken cancellationToken = default)
/// Marshal thread run through <see cref="KernelFunction"/> interface.
/// </summary>
/// <param name="input">The user input</param>
/// <param name="arguments">Arguments for parameterized instructions</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns>An agent response (<see cref="AgentResponse"/></returns>
private async Task<AgentResponse> AskAsync(
[Description("The user message provided to the agent.")]
string input,
KernelArguments arguments,
CancellationToken cancellationToken = default)
{
var thread = await this.NewThreadAsync(cancellationToken).ConfigureAwait(false);
try
{
await thread.AddUserMessageAsync(input, cancellationToken).ConfigureAwait(false);

var messages = await thread.InvokeAsync(this, cancellationToken).ToArrayAsync(cancellationToken).ConfigureAwait(false);
var messages = await thread.InvokeAsync(this, input, arguments, cancellationToken).ToArrayAsync(cancellationToken).ConfigureAwait(false);
var response =
new AgentResponse
{
Expand All @@ -181,6 +203,16 @@ private AgentPluginImpl DefinePlugin()
return new AgentPluginImpl(this, functionAsk);
}

private IPromptTemplate DefinePromptTemplate(PromptTemplateConfig config)
{
if (!s_templateFactories.TryGetValue(config.TemplateFormat, out var factory))
{
factory = new KernelPromptTemplateFactory();
}

return factory.Create(config);
}

private void ThrowIfDeleted()
{
if (this._isDeleted)
Expand Down
14 changes: 10 additions & 4 deletions dotnet/src/Experimental/Agents/Internal/ChatThread.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ public async Task<IChatMessage> AddUserMessageAsync(string message, Cancellation
}

/// <inheritdoc/>
public IAsyncEnumerable<IChatMessage> InvokeAsync(IAgent agent, CancellationToken cancellationToken)
public IAsyncEnumerable<IChatMessage> InvokeAsync(IAgent agent, KernelArguments? arguments = null, CancellationToken cancellationToken = default)
{
return this.InvokeAsync(agent, string.Empty, cancellationToken);
return this.InvokeAsync(agent, string.Empty, arguments, cancellationToken);
}

/// <inheritdoc/>
public async IAsyncEnumerable<IChatMessage> InvokeAsync(IAgent agent, string userMessage, [EnumeratorCancellation] CancellationToken cancellationToken)
public async IAsyncEnumerable<IChatMessage> InvokeAsync(IAgent agent, string userMessage, KernelArguments? arguments = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
this.ThrowIfDeleted();

Expand All @@ -80,9 +80,15 @@ public async IAsyncEnumerable<IChatMessage> InvokeAsync(IAgent agent, string use
yield return await this.AddUserMessageAsync(userMessage, cancellationToken).ConfigureAwait(false);
}

// Define tools as part of the run definition, since there's no enforcement that an agent
// is initialized with the same tools every time.
var tools = agent.Plugins.SelectMany(p => p.Select(f => f.ToToolModel(p.Name)));
var runModel = await this._restContext.CreateRunAsync(this.Id, agent.Id, agent.Instructions, tools, cancellationToken).ConfigureAwait(false);

// Finalize prompt / agent instructions using provided parameters.
var instructions = await agent.AsPromptTemplate().RenderAsync(agent.Kernel, arguments, cancellationToken).ConfigureAwait(false);

// Create run using templated prompt
var runModel = await this._restContext.CreateRunAsync(this.Id, agent.Id, instructions, tools, cancellationToken).ConfigureAwait(false);
var run = new ChatRun(runModel, agent.Kernel, this._restContext);
var results = await run.GetResultAsync(cancellationToken).ConfigureAwait(false);

Expand Down
Loading
Loading