From 34f201ab2e0719988d330749ff28fdae4fb17080 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 13 May 2024 15:14:14 -0400 Subject: [PATCH] .Net: Don't limit [KernelFunction] to public methods (#6206) A developer already need to opt-in a method on a plugin to being part of the plugin by specifying the [KernelFunction] attribute; requiring that method to also be public is superfluous, and means that a type's plugin surface area must be a subset of its public surface area. That prohibits patterns where a type wants to syntactically be a plugin but not expose those APIs via its .NET public surface area. (Curious to see if folks think this is controversial.) --- .../Functions/KernelFunctionAttribute.cs | 6 ++++- .../Functions/KernelPluginFactory.cs | 8 +++---- .../SemanticKernel.Core/KernelExtensions.cs | 24 ++++++++++++------- .../KernelFunctionFromMethodTests2.cs | 12 +++++----- 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionAttribute.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionAttribute.cs index 927c68b70840..88654212e438 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionAttribute.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionAttribute.cs @@ -14,11 +14,15 @@ namespace Microsoft.SemanticKernel; /// /// /// -/// When the system imports functions from an object, it searches for all public methods tagged with this attribute. +/// When the system imports functions from an object, it searches for all methods tagged with this attribute. /// If a method is not tagged with this attribute, it may still be imported directly via a /// or referencing the method directly. /// /// +/// Method visibility does not impact whether a method may be imported. Any method tagged with this attribute, regardless +/// of whether it's public or not, will be imported. +/// +/// /// A description of the method should be supplied using the . /// That description will be used both with LLM prompts and embedding comparisons; the quality of /// the description affects the planner's ability to reason about complex tasks. A diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelPluginFactory.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelPluginFactory.cs index 6ad62f9e122a..40ac04efe75c 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/KernelPluginFactory.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/KernelPluginFactory.cs @@ -25,7 +25,7 @@ public static class KernelPluginFactory /// /// A containing s for all relevant members of . /// - /// Public methods decorated with will be included in the plugin. + /// Methods decorated with will be included in the plugin. /// Attributed methods must all have different names; overloads are not supported. /// public static KernelPlugin CreateFromType(string? pluginName = null, IServiceProvider? serviceProvider = null) @@ -42,7 +42,7 @@ public static KernelPlugin CreateFromType(string? pluginName = null, IService /// The to use for logging. If null, no logging will be performed. /// A containing s for all relevant members of . /// - /// Public methods decorated with will be included in the plugin. + /// Methods decorated with will be included in the plugin. /// Attributed methods must all have different names; overloads are not supported. /// public static KernelPlugin CreateFromObject(object target, string? pluginName = null, ILoggerFactory? loggerFactory = null) @@ -52,7 +52,7 @@ public static KernelPlugin CreateFromObject(object target, string? pluginName = pluginName ??= target.GetType().Name; Verify.ValidPluginName(pluginName); - MethodInfo[] methods = target.GetType().GetMethods(BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static); + MethodInfo[] methods = target.GetType().GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static); // Filter out non-KernelFunctions and fail if two functions have the same name (with or without the same casing). var functions = new List(); @@ -65,7 +65,7 @@ public static KernelPlugin CreateFromObject(object target, string? pluginName = } if (functions.Count == 0) { - throw new ArgumentException($"The {target.GetType()} instance doesn't expose any public [KernelFunction]-attributed methods."); + throw new ArgumentException($"The {target.GetType()} instance doesn't implement any [KernelFunction]-attributed methods."); } if (loggerFactory?.CreateLogger(target.GetType()) is ILogger logger && diff --git a/dotnet/src/SemanticKernel.Core/KernelExtensions.cs b/dotnet/src/SemanticKernel.Core/KernelExtensions.cs index 8ea72b82603a..a05340a64775 100644 --- a/dotnet/src/SemanticKernel.Core/KernelExtensions.cs +++ b/dotnet/src/SemanticKernel.Core/KernelExtensions.cs @@ -140,7 +140,8 @@ public static class KernelExtensions /// /// A containing s for all relevant members of . /// - /// Public methods that have the attribute will be included in the plugin. + /// Methods that have the attribute will be included in the plugin. + /// See attribute for details. /// public static KernelPlugin CreatePluginFromType(this Kernel kernel, string? pluginName = null) { @@ -159,7 +160,8 @@ public static KernelPlugin CreatePluginFromType(this Kernel kernel, string? p /// /// A containing s for all relevant members of . /// - /// Public methods that have the attribute will be included in the plugin. + /// Methods that have the attribute will be included in the plugin. + /// See attribute for details. /// public static KernelPlugin CreatePluginFromObject(this Kernel kernel, object target, string? pluginName = null) { @@ -209,7 +211,8 @@ public static KernelPlugin CreatePluginFromFunctions(this Kernel kernel, string /// /// A containing s for all relevant members of . /// - /// Public methods that have the attribute will be included in the plugin. + /// Methods that have the attribute will be included in the plugin. + /// See attribute for details. /// public static KernelPlugin ImportPluginFromType(this Kernel kernel, string? pluginName = null) { @@ -227,7 +230,8 @@ public static KernelPlugin ImportPluginFromType(this Kernel kernel, string? p /// Service provider from which to resolve dependencies, such as . /// A containing s for all relevant members of . /// - /// Public methods that have the attribute will be included in the plugin. + /// Methods that have the attribute will be included in the plugin. + /// See attribute for details. /// public static KernelPlugin AddFromType(this ICollection plugins, string? pluginName = null, IServiceProvider? serviceProvider = null) { @@ -246,7 +250,8 @@ public static KernelPlugin AddFromType(this ICollection plugins /// /// The same instance as . /// - /// Public methods that have the attribute will be included in the plugin. + /// Methods that have the attribute will be included in the plugin. + /// See attribute for details. /// public static IKernelBuilderPlugins AddFromType(this IKernelBuilderPlugins plugins, string? pluginName = null) { @@ -281,7 +286,8 @@ public static IKernelBuilderPlugins Add(this IKernelBuilderPlugins plugins, Kern /// /// A containing s for all relevant members of . /// - /// Public methods that have the attribute will be included in the plugin. + /// Methods that have the attribute will be included in the plugin. + /// See attribute for details. /// public static KernelPlugin ImportPluginFromObject(this Kernel kernel, object target, string? pluginName = null) { @@ -299,7 +305,8 @@ public static KernelPlugin ImportPluginFromObject(this Kernel kernel, object tar /// Service provider from which to resolve dependencies, such as . /// A containing s for all relevant members of . /// - /// Public methods that have the attribute will be included in the plugin. + /// Methods that have the attribute will be included in the plugin. + /// See attribute for details. /// public static KernelPlugin AddFromObject(this ICollection plugins, object target, string? pluginName = null, IServiceProvider? serviceProvider = null) { @@ -318,7 +325,8 @@ public static KernelPlugin AddFromObject(this ICollection plugins, /// /// The same instance as . /// - /// Public methods that have the attribute will be included in the plugin. + /// Methods that have the attribute will be included in the plugin. + /// See attribute for details. /// public static IKernelBuilderPlugins AddFromObject(this IKernelBuilderPlugins plugins, object target, string? pluginName = null) { diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromMethodTests2.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromMethodTests2.cs index 33432d6f03ee..0cd64753780d 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromMethodTests2.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromMethodTests2.cs @@ -26,8 +26,8 @@ public void ItDoesntThrowForValidFunctionsViaDelegate() // Arrange var pluginInstance = new LocalExamplePlugin(); MethodInfo[] methods = pluginInstance.GetType() - .GetMethods(BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.InvokeMethod) - .Where(m => m.Name is not "GetType" and not "Equals" and not "GetHashCode" and not "ToString") + .GetMethods(BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.InvokeMethod) + .Where(m => m.Name is not ("GetType" or "Equals" or "GetHashCode" or "ToString" or "Finalize" or "MemberwiseClone")) .ToArray(); KernelFunction[] functions = (from method in methods select KernelFunctionFactory.CreateFromMethod(method, pluginInstance, "plugin")).ToArray(); @@ -43,8 +43,8 @@ public void ItDoesNotThrowForValidFunctionsViaPlugin() // Arrange var pluginInstance = new LocalExamplePlugin(); MethodInfo[] methods = pluginInstance.GetType() - .GetMethods(BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.InvokeMethod) - .Where(m => m.Name is not "GetType" and not "Equals" and not "GetHashCode" and not "ToString") + .GetMethods(BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.InvokeMethod) + .Where(m => m.Name is not ("GetType" or "Equals" or "GetHashCode" or "ToString" or "Finalize" or "MemberwiseClone")) .ToArray(); KernelFunction[] functions = [.. KernelPluginFactory.CreateFromObject(pluginInstance)]; @@ -329,13 +329,13 @@ public string Type05(string input) } [KernelFunction] - public string? Type05Nullable(string? input = null) + private string? Type05Nullable(string? input = null) { return ""; } [KernelFunction] - public string? Type05EmptyDefault(string? input = "") + internal string? Type05EmptyDefault(string? input = "") { return ""; }