.Net: Service ID in execution_settings should be promoted as the key of the list/dictionary (#4211)

### Motivation and Context

Resolves #3981 

The code will be changed further to support
#4212

### Description

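Execution settings are now keyed by service id instead of carrying the id as a property. Concretely: `PromptTemplateConfig.ExecutionSettings` changes from `List<PromptExecutionSettings>` to `Dictionary<string, PromptExecutionSettings>`, `KernelFunction.ExecutionSettings` becomes an internal `IReadOnlyDictionary<string, PromptExecutionSettings>`, the `execution_settings` node in JSON/YAML prompt configurations changes from an array to a map keyed by service id, and a reserved `default` key (`PromptExecutionSettings.DefaultServiceId`) replaces the old convention that the first entry with an empty service id acted as the default. `OrderedAIServiceSelector` resolves services against these keys, and the new `PromptTemplateConfig.AddExecutionSettings` helper enforces one entry per service id.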

### Contribution Checklist


- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
markwallace-microsoft committed Dec 13, 2023
1 parent 6ad5d26 commit 3a86d7c
Showing 70 changed files with 275 additions and 231 deletions.
4 changes: 2 additions & 2 deletions dotnet/samples/KernelSyntaxExamples/Example61_MultipleLLMs.cs
@@ -88,10 +88,10 @@ public static async Task RunByFirstModelIdAsync(Kernel kernel, params string[] modelIds)

     var prompt = "Hello AI, what can you do for me?";

-    var modelSettings = new List<PromptExecutionSettings>();
+    var modelSettings = new Dictionary<string, PromptExecutionSettings>();
     foreach (var modelId in modelIds)
     {
-        modelSettings.Add(new PromptExecutionSettings() { ModelId = modelId });
+        modelSettings.Add(modelId, new PromptExecutionSettings() { ModelId = modelId });
     }
     var promptConfig = new PromptTemplateConfig(prompt) { Name = "HelloAI", ExecutionSettings = modelSettings };
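For orientation, a minimal sketch of how the reworked sample can be wired up and invoked. It continues the snippet above and assumes AI services were registered under service ids matching the dictionary keys, using the `KernelFunctionFactory.CreateFromPrompt` overload that accepts a `PromptTemplateConfig`:

```csharp
// Sketch only: "kernel" and the registered services come from the surrounding sample.
KernelFunction function = KernelFunctionFactory.CreateFromPrompt(promptConfig);
FunctionResult result = await kernel.InvokeAsync(function);
Console.WriteLine(result.GetValue<string>());
```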
@@ -93,7 +93,6 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesSnakeCase()
     { "stop_sequences", new [] { "foo", "bar" } },
     { "chat_system_prompt", "chat system prompt" },
     { "max_tokens", 128 },
-    { "service_id", "service" },
     { "token_selection_biases", new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } } },
     { "seed", 123456 },
 }
@@ -124,7 +123,6 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesAsStrings()
     { "stop_sequences", new [] { "foo", "bar" } },
     { "chat_system_prompt", "chat system prompt" },
     { "max_tokens", "128" },
-    { "service_id", "service" },
     { "token_selection_biases", new Dictionary<string, string>() { { "1", "2" }, { "3", "4" } } }
 }
 };
@@ -149,7 +147,6 @@ public void ItCreatesOpenAIExecutionSettingsFromJsonSnakeCase()
     ""stop_sequences"": [ ""foo"", ""bar"" ],
     ""chat_system_prompt"": ""chat system prompt"",
     ""token_selection_biases"": { ""1"": 2, ""3"": 4 },
-    ""service_id"": ""service"",
     ""max_tokens"": 128
 }";
 var actualSettings = JsonSerializer.Deserialize<PromptExecutionSettings>(json);
@@ -172,7 +169,6 @@ private static void AssertExecutionSettings(OpenAIPromptExecutionSettings executionSettings)
     Assert.Equal(new string[] { "foo", "bar" }, executionSettings.StopSequences);
     Assert.Equal("chat system prompt", executionSettings.ChatSystemPrompt);
     Assert.Equal(new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } }, executionSettings.TokenSelectionBiases);
-    Assert.Equal("service", executionSettings.ServiceId);
     Assert.Equal(128, executionSettings.MaxTokens);
 }
 }
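With `service_id` gone from the payload, the id travels only as the dictionary key. A minimal sketch of the new wire shape (the `service` key name is illustrative):

```csharp
using Microsoft.SemanticKernel;

var config = PromptTemplateConfig.FromJson(@"{
    ""name"": ""Example"",
    ""execution_settings"": {
        ""service"": { ""max_tokens"": 128 }
    }
}");
// The service id is recovered from the key, not from the settings body.
PromptExecutionSettings settings = config.ExecutionSettings["service"];
```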
@@ -1,5 +1,6 @@
 // Copyright (c) Microsoft. All rights reserved.

+using System.Collections.Generic;
 using System.Text.Json;
 using Markdig;
 using Markdig.Syntax;
@@ -52,10 +53,13 @@ internal static PromptTemplateConfig CreateFromPromptMarkdown(string text, strin

     case "sk.execution_settings":
         var modelSettings = codeBlock.Lines.ToString();
-        var executionSettings = JsonSerializer.Deserialize<PromptExecutionSettings>(modelSettings);
-        if (executionSettings is not null)
+        var settingsDictionary = JsonSerializer.Deserialize<Dictionary<string, PromptExecutionSettings>>(modelSettings);
+        if (settingsDictionary is not null)
         {
-            promptFunctionModel.ExecutionSettings.Add(executionSettings);
+            foreach (var keyValue in settingsDictionary)
+            {
+                promptFunctionModel.ExecutionSettings.Add(keyValue.Key, keyValue.Value);
+            }
         }
         break;
 }
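In effect the factory now deserializes each `sk.execution_settings` block as a map keyed by service id. A minimal sketch (hypothetical key name):

```csharp
using System.Collections.Generic;
using System.Text.Json;
using Microsoft.SemanticKernel;

string block = @"{ ""service1"": { ""model_id"": ""gpt4"", ""temperature"": 0.7 } }";
var settingsDictionary = JsonSerializer.Deserialize<Dictionary<string, PromptExecutionSettings>>(block);
// Each entry lands in PromptTemplateConfig.ExecutionSettings under its service-id key.
```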
@@ -19,8 +19,8 @@ public void ItShouldCreatePromptFunctionConfigFromMarkdown()
     Assert.Equal("TellMeAbout", model.Name);
     Assert.Equal("Hello AI, tell me about {{$input}}", model.Template);
     Assert.Equal(2, model.ExecutionSettings.Count);
-    Assert.Equal("gpt4", model.ExecutionSettings[0].ModelId);
-    Assert.Equal("gpt3.5", model.ExecutionSettings[1].ModelId);
+    Assert.Equal("gpt4", model.ExecutionSettings["service1"].ModelId);
+    Assert.Equal("gpt3.5", model.ExecutionSettings["service2"].ModelId);
 }

[Fact]
@@ -45,15 +45,19 @@ public void ItShouldCreatePromptFunctionFromMarkdown()
 These are AI execution settings
 ```sk.execution_settings
 {
-    ""model_id"": ""gpt4"",
-    ""temperature"": 0.7
+    ""service1"" : {
+        ""model_id"": ""gpt4"",
+        ""temperature"": 0.7
+    }
 }
 ```
 These are more AI execution settings
 ```sk.execution_settings
 {
-    ""model_id"": ""gpt3.5"",
-    ""temperature"": 0.8
+    ""service2"" : {
+        ""model_id"": ""gpt3.5"",
+        ""temperature"": 0.8
+    }
 }
 ```
 ";
@@ -63,7 +63,7 @@ public void ItShouldSupportCreatingOpenAIExecutionSettings()
     var promptFunctionModel = deserializer.Deserialize<PromptTemplateConfig>(this._yaml);

     // Act
-    var executionSettings = OpenAIPromptExecutionSettings.FromExecutionSettings(promptFunctionModel.ExecutionSettings[0]);
+    var executionSettings = OpenAIPromptExecutionSettings.FromExecutionSettings(promptFunctionModel.ExecutionSettings["service1"]);

     // Assert
     Assert.NotNull(executionSettings);
@@ -99,14 +99,16 @@ public void ItShouldSupportCreatingOpenAIExecutionSettings()
     description: The language to generate the greeting in
     default: English
 execution_settings:
-  - model_id: gpt-4
+  service1:
+    model_id: gpt-4
     temperature: 1.0
     top_p: 0.0
     presence_penalty: 0.0
     frequency_penalty: 0.0
     max_tokens: 256
     stop_sequences: []
-  - model_id: gpt-3.5
+  service2:
+    model_id: gpt-3.5
     temperature: 1.0
     top_p: 0.0
     presence_penalty: 0.0
@@ -128,14 +130,16 @@ public void ItShouldSupportCreatingOpenAIExecutionSettings()
     description: The language to generate the greeting in
     default: English
 execution_settings:
-  - model_id: gpt-4
+  service1:
+    model_id: gpt-4
     temperature: 1.0
     top_p: 0.0
     presence_penalty: 0.0
     frequency_penalty: 0.0
     max_tokens: 256
     stop_sequences: []
-  - model_id: random-model
+  service2:
+    model_id: random-model
     temperaturex: 1.0
     top_q: 0.0
     rando_penalty: 0.0
@@ -31,7 +31,8 @@ public void ItShouldCreatePromptFunctionFromYamlWithCustomModelSettings()
     Assert.Equal(2, semanticFunctionConfig.InputVariables.Count);
     Assert.Equal("language", semanticFunctionConfig.InputVariables[1].Name);
     Assert.Equal(2, semanticFunctionConfig.ExecutionSettings.Count);
-    Assert.Equal("gpt-3.5", semanticFunctionConfig.ExecutionSettings[1].ModelId);
+    Assert.Equal("gpt-4", semanticFunctionConfig.ExecutionSettings["service1"].ModelId);
+    Assert.Equal("gpt-3.5", semanticFunctionConfig.ExecutionSettings["service2"].ModelId);
 }

private readonly string _yaml = @"
@@ -47,14 +48,16 @@ public void ItShouldCreatePromptFunctionFromYamlWithCustomModelSettings()
     description: The language to generate the greeting in
     default: English
 execution_settings:
-  - model_id: gpt-4
+  service1:
+    model_id: gpt-4
     temperature: 1.0
     top_p: 0.0
     presence_penalty: 0.0
     frequency_penalty: 0.0
     max_tokens: 256
     stop_sequences: []
-  - model_id: gpt-3.5
+  service2:
+    model_id: gpt-3.5
     temperature: 1.0
     top_p: 0.0
     presence_penalty: 0.0
@@ -437,12 +437,11 @@ public async Task MultipleServiceLoadPromptConfigTestAsync()
     var azurePromptModel = PromptTemplateConfig.FromJson(
         @"{
             ""name"": ""FishMarket2"",
-            ""execution_settings"": [
-                {
-                    ""max_tokens"": 256,
-                    ""service_id"": ""azure-text-davinci-003""
+            ""execution_settings"": {
+                ""azure-text-davinci-003"": {
+                    ""max_tokens"": 256
                 }
-            ]
+            }
         }");
     azurePromptModel.Template = prompt;
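For the key to resolve, it must match the id the AI service was registered under via keyed DI (the `IKeyedServiceProvider` path used by the selector below). A hedged registration sketch; the builder call and `azureTextService` instance are illustrative:

```csharp
// Sketch only: "azureTextService" is a hypothetical ITextGenerationService instance.
IKernelBuilder builder = Kernel.CreateBuilder();
builder.Services.AddKeyedSingleton<ITextGenerationService>("azure-text-davinci-003", azureTextService);
Kernel kernel = builder.Build();
```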
3 changes: 2 additions & 1 deletion dotnet/src/IntegrationTests/prompts/GenerateStory.yaml
@@ -13,4 +13,5 @@ input_variables:
 output_variable:
   description: The generated story.
 execution_settings:
-  - temperature: 0.6
+  default:
+    temperature: 0.6
11 changes: 7 additions & 4 deletions dotnet/src/IntegrationTests/prompts/GenerateStoryHandlebars.yaml
@@ -13,8 +13,11 @@ input_variables:
 output_variable:
   description: The generated story.
 execution_settings:
-  - model_id: gpt-4
+  service1:
+    model_id: gpt-4
     temperature: 0.6
-  - model_id: gpt-3
-    temperature: 0.4
-  - temperature: 0.5
+  service2:
+    model_id: gpt-3
+    temperature: 0.4
+  default:
+    temperature: 0.5
@@ -15,6 +15,11 @@ namespace Microsoft.SemanticKernel;
 /// </summary>
 public class PromptExecutionSettings
 {
+    /// <summary>
+    /// Default service identifier.
+    /// </summary>
+    public const string DefaultServiceId = "default";
+
     /// <summary>
     /// Service identifier.
     /// This identifies a service and is set when the AI service is registered.
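A minimal sketch of how the reserved key is intended to be used (service name illustrative):

```csharp
var settings = new Dictionary<string, PromptExecutionSettings>
{
    // Fallback entry, applied when no service-specific entry matches.
    [PromptExecutionSettings.DefaultServiceId] = new() { ModelId = "gpt-3.5-turbo" },
    // Entry used when the service registered as "service1" is selected.
    ["service1"] = new() { ModelId = "gpt-4" },
};
```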
@@ -74,7 +74,7 @@ public abstract class KernelFunction
 /// <summary>
 /// Gets the prompt execution settings.
 /// </summary>
-public IReadOnlyList<PromptExecutionSettings>? ExecutionSettings { get; }
+internal IReadOnlyDictionary<string, PromptExecutionSettings>? ExecutionSettings { get; }

 /// <summary>
 /// Initializes a new instance of the <see cref="KernelFunction"/> class.
@@ -87,7 +87,7 @@ public abstract class KernelFunction
 /// The <see cref="PromptExecutionSettings"/> to use with the function. These will apply unless they've been
 /// overridden by settings passed into the invocation of the function.
 /// </param>
-internal KernelFunction(string name, string description, IReadOnlyList<KernelParameterMetadata> parameters, KernelReturnParameterMetadata? returnParameter = null, List<PromptExecutionSettings>? executionSettings = null)
+internal KernelFunction(string name, string description, IReadOnlyList<KernelParameterMetadata> parameters, KernelReturnParameterMetadata? returnParameter = null, Dictionary<string, PromptExecutionSettings>? executionSettings = null)
 {
     Verify.NotNull(name);
     Verify.ParametersUniqueness(parameters);
@@ -21,8 +21,8 @@ public sealed class PromptTemplateConfig

     /// <summary>Lazily-initialized input variables.</summary>
     private List<InputVariable>? _inputVariables;
-    /// <summary>Lazily-initialized execution settings.</summary>
-    private List<PromptExecutionSettings>? _executionSettings;
+
+    /// <summary>Lazily-initialized execution settings. The key is the service id or "default" for the default execution settings.</summary>
+    private Dictionary<string, PromptExecutionSettings>? _executionSettings;

     /// <summary>
     /// Name of the kernel function.
@@ -72,7 +72,7 @@ public List<InputVariable> InputVariables
     /// Prompt execution settings.
     /// </summary>
     [JsonPropertyName("execution_settings")]
-    public List<PromptExecutionSettings> ExecutionSettings
+    public Dictionary<string, PromptExecutionSettings> ExecutionSettings
     {
         get => this._executionSettings ??= new();
         set
@@ -82,6 +83,11 @@ public List<PromptExecutionSettings> ExecutionSettings
         }
     }

+    /// <summary>
+    /// Default execution settings.
+    /// </summary>
+    public PromptExecutionSettings? DefaultExecutionSettings => this._executionSettings is not null && this._executionSettings.TryGetValue(PromptExecutionSettings.DefaultServiceId, out PromptExecutionSettings? settings) ? settings : null;
+
     /// <summary>
     /// Initializes a new instance of the <see cref="PromptTemplateConfig"/> class.
     /// </summary>
@@ -97,6 +103,26 @@ public PromptTemplateConfig(string template)
         this.Template = template;
     }

+    /// <summary>
+    /// Adds the <see cref="PromptExecutionSettings"/> to the <see cref="ExecutionSettings"/> dictionary.
+    /// </summary>
+    /// <remarks>
+    /// The <see cref="PromptExecutionSettings.ServiceId"/> is used as the key if provided. Otherwise, the key is "default".
+    /// </remarks>
+    /// <param name="settings"></param>
+    public void AddExecutionSettings(PromptExecutionSettings settings)
+    {
+        Verify.NotNull(settings);
+
+        var key = settings.ServiceId ?? PromptExecutionSettings.DefaultServiceId;
+        if (this.ExecutionSettings.ContainsKey(key))
+        {
+            throw new ArgumentException($"Execution settings for service id '{key}' already exists.");
+        }
+
+        this.ExecutionSettings[key] = settings;
+    }
+
     /// <summary>
     /// Return the input variables metadata.
     /// </summary>
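A short usage sketch of the new helper and the `default` fallback introduced above:

```csharp
var config = new PromptTemplateConfig("Hello {{$input}}");

// No ServiceId, so this entry is stored under the "default" key.
config.AddExecutionSettings(new PromptExecutionSettings { ModelId = "gpt-3.5-turbo" });

// The ServiceId becomes the dictionary key; a second entry for the same key throws.
config.AddExecutionSettings(new PromptExecutionSettings { ServiceId = "service1", ModelId = "gpt-4" });

PromptExecutionSettings? defaults = config.DefaultExecutionSettings; // the "default" entry
```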
@@ -94,13 +94,13 @@ public static class AIServiceExtensions
     var message = new StringBuilder($"Required service of type {typeof(T)} not registered.");
     if (function.ExecutionSettings is not null)
     {
-        string serviceIds = string.Join("|", function.ExecutionSettings.Select(model => model.ServiceId));
+        string serviceIds = string.Join("|", function.ExecutionSettings.Values.Select(model => model.ServiceId));
         if (!string.IsNullOrEmpty(serviceIds))
         {
             message.Append($" Expected serviceIds: {serviceIds}.");
         }

-        string modelIds = string.Join("|", function.ExecutionSettings.Select(model => model.ModelId));
+        string modelIds = string.Join("|", function.ExecutionSettings.Values.Select(model => model.ModelId));
         if (!string.IsNullOrEmpty(modelIds))
         {
             message.Append($" Expected modelIds: {modelIds}.");
@@ -1,5 +1,6 @@
 // Copyright (c) Microsoft. All rights reserved.

+using System;
 using System.Collections.Generic;
 using System.Diagnostics.CodeAnalysis;
 using System.Linq;
@@ -23,7 +24,7 @@ internal sealed class OrderedAIServiceSelector : IAIServiceSelector
 {
     // Allow the execution settings from the kernel arguments to take precedence
     var executionSettings = arguments.ExecutionSettings is not null
-        ? new List<PromptExecutionSettings> { arguments.ExecutionSettings }
+        ? new Dictionary<string, PromptExecutionSettings> { { arguments.ExecutionSettings.ServiceId ?? PromptExecutionSettings.DefaultServiceId, arguments.ExecutionSettings } }
         : function.ExecutionSettings;
     if (executionSettings is null || executionSettings.Count == 0)
     {
@@ -37,18 +38,24 @@ internal sealed class OrderedAIServiceSelector : IAIServiceSelector
     else
     {
         PromptExecutionSettings? defaultExecutionSettings = null;
-        foreach (var settings in executionSettings)
+        foreach (var keyValue in executionSettings)
         {
-            if (!string.IsNullOrEmpty(settings.ServiceId))
+            var settings = keyValue.Value;
+            var serviceId = keyValue.Value.ServiceId ?? keyValue.Key;
+            if (string.IsNullOrEmpty(serviceId) || serviceId!.Equals(PromptExecutionSettings.DefaultServiceId, StringComparison.OrdinalIgnoreCase))
             {
-                service = (kernel.Services as IKeyedServiceProvider)?.GetKeyedService<T>(settings.ServiceId);
+                defaultExecutionSettings ??= settings;
+            }
+            else if (!string.IsNullOrEmpty(serviceId))
+            {
+                service = (kernel.Services as IKeyedServiceProvider)?.GetKeyedService<T>(serviceId);
                 if (service is not null)
                 {
                     serviceSettings = settings;
                     return true;
                 }
             }
-            else if (!string.IsNullOrEmpty(settings.ModelId))
+            if (!string.IsNullOrEmpty(settings.ModelId))
             {
                 service = this.GetServiceByModelId<T>(kernel, settings.ModelId!);
                 if (service is not null)
@@ -57,11 +64,6 @@
                     return true;
                 }
             }
-            else
-            {
-                // First execution settings with empty or null service id is the default
-                defaultExecutionSettings ??= settings;
-            }
         }

         if (defaultExecutionSettings is not null)
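Net behavior: settings passed with the invocation's `KernelArguments` win outright; otherwise entries are matched by service id against the kernel's keyed services, then by model id, with a `default` (or id-less) entry kept as the fallback. A sketch of a per-invocation override, assuming the `KernelArguments` constructor that accepts execution settings:

```csharp
// Sketch only: overrides the function's own execution-settings dictionary for this call.
var arguments = new KernelArguments(new PromptExecutionSettings { ServiceId = "service2" });
FunctionResult result = await kernel.InvokeAsync(function, arguments);
```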
@@ -66,7 +66,7 @@ internal sealed class KernelFunctionFromPrompt : KernelFunction

     if (executionSettings is not null)
     {
-        promptConfig.ExecutionSettings.Add(executionSettings);
+        promptConfig.ExecutionSettings.Add(executionSettings.ServiceId ?? PromptExecutionSettings.DefaultServiceId, executionSettings);
     }

     var factory = promptTemplateFactory ?? new KernelPromptTemplateFactory(loggerFactory);
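A minimal sketch of the call path this touches, assuming the `CreateFunctionFromPrompt` overload that takes a single `PromptExecutionSettings` (prompt text illustrative):

```csharp
// The single settings object is keyed by its ServiceId if set,
// otherwise by PromptExecutionSettings.DefaultServiceId ("default").
KernelFunction function = kernel.CreateFunctionFromPrompt(
    "Hello AI, what can you do for me?",
    new PromptExecutionSettings { ServiceId = "service1", ModelId = "gpt-4" });
```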