Skip to content

Commit

Permalink
.Net: Don't limit [KernelFunction] to public methods (#6206)
Browse files Browse the repository at this point in the history
A developer already need to opt-in a method on a plugin to being part of
the plugin by specifying the [KernelFunction] attribute; requiring that
method to also be public is superfluous, and means that a type's plugin
surface area must be a subset of its public surface area. That prohibits
patterns where a type wants to syntactically be a plugin but not expose
those APIs via its .NET public surface area.

(Curious to see if folks think this is controversial.)
  • Loading branch information
stephentoub committed May 13, 2024
1 parent 8a8cd95 commit 34f201a
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@ namespace Microsoft.SemanticKernel;
/// </summary>
/// <remarks>
/// <para>
/// When the system imports functions from an object, it searches for all public methods tagged with this attribute.
/// When the system imports functions from an object, it searches for all methods tagged with this attribute.
/// If a method is not tagged with this attribute, it may still be imported directly via a <see cref="Delegate"/>
/// or <see cref="MethodInfo"/> referencing the method directly.
/// </para>
/// <para>
/// Method visibility does not impact whether a method may be imported. Any method tagged with this attribute, regardless
/// of whether it's public or not, will be imported.
/// </para>
/// <para>
/// A description of the method should be supplied using the <see cref="DescriptionAttribute"/>.
/// That description will be used both with LLM prompts and embedding comparisons; the quality of
/// the description affects the planner's ability to reason about complex tasks. A <see cref="DescriptionAttribute"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public static class KernelPluginFactory
/// </param>
/// <returns>A <see cref="KernelPlugin"/> containing <see cref="KernelFunction"/>s for all relevant members of <typeparamref name="T"/>.</returns>
/// <remarks>
/// Public methods decorated with <see cref="KernelFunctionAttribute"/> will be included in the plugin.
/// Methods decorated with <see cref="KernelFunctionAttribute"/> will be included in the plugin.
/// Attributed methods must all have different names; overloads are not supported.
/// </remarks>
public static KernelPlugin CreateFromType<T>(string? pluginName = null, IServiceProvider? serviceProvider = null)
Expand All @@ -42,7 +42,7 @@ public static KernelPlugin CreateFromType<T>(string? pluginName = null, IService
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use for logging. If null, no logging will be performed.</param>
/// <returns>A <see cref="KernelPlugin"/> containing <see cref="KernelFunction"/>s for all relevant members of <paramref name="target"/>.</returns>
/// <remarks>
/// Public methods decorated with <see cref="KernelFunctionAttribute"/> will be included in the plugin.
/// Methods decorated with <see cref="KernelFunctionAttribute"/> will be included in the plugin.
/// Attributed methods must all have different names; overloads are not supported.
/// </remarks>
public static KernelPlugin CreateFromObject(object target, string? pluginName = null, ILoggerFactory? loggerFactory = null)
Expand All @@ -52,7 +52,7 @@ public static KernelPlugin CreateFromObject(object target, string? pluginName =
pluginName ??= target.GetType().Name;
Verify.ValidPluginName(pluginName);

MethodInfo[] methods = target.GetType().GetMethods(BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static);
MethodInfo[] methods = target.GetType().GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static);

// Filter out non-KernelFunctions and fail if two functions have the same name (with or without the same casing).
var functions = new List<KernelFunction>();
Expand All @@ -65,7 +65,7 @@ public static KernelPlugin CreateFromObject(object target, string? pluginName =
}
if (functions.Count == 0)
{
throw new ArgumentException($"The {target.GetType()} instance doesn't expose any public [KernelFunction]-attributed methods.");
throw new ArgumentException($"The {target.GetType()} instance doesn't implement any [KernelFunction]-attributed methods.");
}

if (loggerFactory?.CreateLogger(target.GetType()) is ILogger logger &&
Expand Down
24 changes: 16 additions & 8 deletions dotnet/src/SemanticKernel.Core/KernelExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ public static class KernelExtensions
/// </param>
/// <returns>A <see cref="KernelPlugin"/> containing <see cref="KernelFunction"/>s for all relevant members of <typeparamref name="T"/>.</returns>
/// <remarks>
/// Public methods that have the <see cref="KernelFunctionFromPrompt"/> attribute will be included in the plugin.
/// Methods that have the <see cref="KernelFunctionAttribute"/> attribute will be included in the plugin.
/// See <see cref="KernelFunctionAttribute"/> attribute for details.
/// </remarks>
public static KernelPlugin CreatePluginFromType<T>(this Kernel kernel, string? pluginName = null)
{
Expand All @@ -159,7 +160,8 @@ public static KernelPlugin CreatePluginFromType<T>(this Kernel kernel, string? p
/// </param>
/// <returns>A <see cref="KernelPlugin"/> containing <see cref="KernelFunction"/>s for all relevant members of <paramref name="target"/>.</returns>
/// <remarks>
/// Public methods that have the <see cref="KernelFunctionFromPrompt"/> attribute will be included in the plugin.
/// Methods that have the <see cref="KernelFunctionAttribute"/> attribute will be included in the plugin.
/// See <see cref="KernelFunctionAttribute"/> attribute for details.
/// </remarks>
public static KernelPlugin CreatePluginFromObject(this Kernel kernel, object target, string? pluginName = null)
{
Expand Down Expand Up @@ -209,7 +211,8 @@ public static KernelPlugin CreatePluginFromFunctions(this Kernel kernel, string
/// </param>
/// <returns>A <see cref="KernelPlugin"/> containing <see cref="KernelFunction"/>s for all relevant members of <typeparamref name="T"/>.</returns>
/// <remarks>
/// Public methods that have the <see cref="KernelFunctionFromPrompt"/> attribute will be included in the plugin.
/// Methods that have the <see cref="KernelFunctionAttribute"/> attribute will be included in the plugin.
/// See <see cref="KernelFunctionAttribute"/> attribute for details.
/// </remarks>
public static KernelPlugin ImportPluginFromType<T>(this Kernel kernel, string? pluginName = null)
{
Expand All @@ -227,7 +230,8 @@ public static KernelPlugin ImportPluginFromType<T>(this Kernel kernel, string? p
/// <param name="serviceProvider">Service provider from which to resolve dependencies, such as <see cref="ILoggerFactory"/>.</param>
/// <returns>A <see cref="KernelPlugin"/> containing <see cref="KernelFunction"/>s for all relevant members of <typeparamref name="T"/>.</returns>
/// <remarks>
/// Public methods that have the <see cref="KernelFunctionFromPrompt"/> attribute will be included in the plugin.
/// Methods that have the <see cref="KernelFunctionAttribute"/> attribute will be included in the plugin.
/// See <see cref="KernelFunctionAttribute"/> attribute for details.
/// </remarks>
public static KernelPlugin AddFromType<T>(this ICollection<KernelPlugin> plugins, string? pluginName = null, IServiceProvider? serviceProvider = null)
{
Expand All @@ -246,7 +250,8 @@ public static KernelPlugin AddFromType<T>(this ICollection<KernelPlugin> plugins
/// </param>
/// <returns>The same instance as <paramref name="plugins"/>.</returns>
/// <remarks>
/// Public methods that have the <see cref="KernelFunctionFromPrompt"/> attribute will be included in the plugin.
/// Methods that have the <see cref="KernelFunctionAttribute"/> attribute will be included in the plugin.
/// See <see cref="KernelFunctionAttribute"/> attribute for details.
/// </remarks>
public static IKernelBuilderPlugins AddFromType<T>(this IKernelBuilderPlugins plugins, string? pluginName = null)
{
Expand Down Expand Up @@ -281,7 +286,8 @@ public static IKernelBuilderPlugins Add(this IKernelBuilderPlugins plugins, Kern
/// </param>
/// <returns>A <see cref="KernelPlugin"/> containing <see cref="KernelFunction"/>s for all relevant members of <paramref name="target"/>.</returns>
/// <remarks>
/// Public methods that have the <see cref="KernelFunctionFromPrompt"/> attribute will be included in the plugin.
/// Methods that have the <see cref="KernelFunctionAttribute"/> attribute will be included in the plugin.
/// See <see cref="KernelFunctionAttribute"/> attribute for details.
/// </remarks>
public static KernelPlugin ImportPluginFromObject(this Kernel kernel, object target, string? pluginName = null)
{
Expand All @@ -299,7 +305,8 @@ public static KernelPlugin ImportPluginFromObject(this Kernel kernel, object tar
/// <param name="serviceProvider">Service provider from which to resolve dependencies, such as <see cref="ILoggerFactory"/>.</param>
/// <returns>A <see cref="KernelPlugin"/> containing <see cref="KernelFunction"/>s for all relevant members of <paramref name="target"/>.</returns>
/// <remarks>
/// Public methods that have the <see cref="KernelFunctionFromPrompt"/> attribute will be included in the plugin.
/// Methods that have the <see cref="KernelFunctionAttribute"/> attribute will be included in the plugin.
/// See <see cref="KernelFunctionAttribute"/> attribute for details.
/// </remarks>
public static KernelPlugin AddFromObject(this ICollection<KernelPlugin> plugins, object target, string? pluginName = null, IServiceProvider? serviceProvider = null)
{
Expand All @@ -318,7 +325,8 @@ public static KernelPlugin AddFromObject(this ICollection<KernelPlugin> plugins,
/// </param>
/// <returns>The same instance as <paramref name="plugins"/>.</returns>
/// <remarks>
/// Public methods that have the <see cref="KernelFunctionFromPrompt"/> attribute will be included in the plugin.
/// Methods that have the <see cref="KernelFunctionAttribute"/> attribute will be included in the plugin.
/// See <see cref="KernelFunctionAttribute"/> attribute for details.
/// </remarks>
public static IKernelBuilderPlugins AddFromObject(this IKernelBuilderPlugins plugins, object target, string? pluginName = null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ public void ItDoesntThrowForValidFunctionsViaDelegate()
// Arrange
var pluginInstance = new LocalExamplePlugin();
MethodInfo[] methods = pluginInstance.GetType()
.GetMethods(BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.InvokeMethod)
.Where(m => m.Name is not "GetType" and not "Equals" and not "GetHashCode" and not "ToString")
.GetMethods(BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.InvokeMethod)
.Where(m => m.Name is not ("GetType" or "Equals" or "GetHashCode" or "ToString" or "Finalize" or "MemberwiseClone"))
.ToArray();

KernelFunction[] functions = (from method in methods select KernelFunctionFactory.CreateFromMethod(method, pluginInstance, "plugin")).ToArray();
Expand All @@ -43,8 +43,8 @@ public void ItDoesNotThrowForValidFunctionsViaPlugin()
// Arrange
var pluginInstance = new LocalExamplePlugin();
MethodInfo[] methods = pluginInstance.GetType()
.GetMethods(BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.InvokeMethod)
.Where(m => m.Name is not "GetType" and not "Equals" and not "GetHashCode" and not "ToString")
.GetMethods(BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.InvokeMethod)
.Where(m => m.Name is not ("GetType" or "Equals" or "GetHashCode" or "ToString" or "Finalize" or "MemberwiseClone"))
.ToArray();

KernelFunction[] functions = [.. KernelPluginFactory.CreateFromObject(pluginInstance)];
Expand Down Expand Up @@ -329,13 +329,13 @@ public string Type05(string input)
}

[KernelFunction]
public string? Type05Nullable(string? input = null)
private string? Type05Nullable(string? input = null)
{
return "";
}

[KernelFunction]
public string? Type05EmptyDefault(string? input = "")
internal string? Type05EmptyDefault(string? input = "")
{
return "";
}
Expand Down

0 comments on commit 34f201a

Please sign in to comment.