diff --git a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs index bbe59ed77f..52325dc6c1 100644 --- a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs +++ b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs @@ -14,6 +14,9 @@ namespace Microsoft.ML.EntryPoints [BestFriend] internal static class EntryPointUtils { + private static readonly FuncStaticMethodInfo1 _isValueWithinRangeMethodInfo + = new FuncStaticMethodInfo1(IsValueWithinRange); + private static bool IsValueWithinRange(TlcModule.RangeAttribute range, object obj) { T val; @@ -33,13 +36,12 @@ public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, objec { Contracts.AssertValue(range); Contracts.AssertValue(val); - Func fn = IsValueWithinRange; // Avoid trying to cast double as float. If range // was specified using floats, but value being checked // is double, change range to be of type double if (range.Type == typeof(float) && val is double) range.CastToDouble(); - return Utils.MarshalInvoke(fn, range.Type, range, val); + return Utils.MarshalInvoke(_isValueWithinRangeMethodInfo, range.Type, range, val); } /// diff --git a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`4.cs b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`4.cs new file mode 100644 index 0000000000..cb0ae6451b --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`4.cs @@ -0,0 +1,91 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// The method is an instance method on an object of type . + /// One generic type argument. + /// A return value of . + /// + /// + /// The type of the receiver of the instance method. + /// The type of the first parameter of the method. + /// The type of the second parameter of the method. + /// The type of the return value of the method. + internal sealed class FuncInstanceMethodInfo1 : FuncMethodInfo1 + where TTarget : class + { + private static readonly string _targetTypeCheckMessage = $"Should have a target type of '{typeof(TTarget)}'"; + + public FuncInstanceMethodInfo1(Func function) + : this(function.Method) + { + } + + private FuncInstanceMethodInfo1(MethodInfo methodInfo) + : base(methodInfo) + { + Contracts.CheckParam(!GenericMethodDefinition.IsStatic, nameof(methodInfo), "Should be an instance method"); + Contracts.CheckParam(GenericMethodDefinition.DeclaringType == typeof(TTarget), nameof(methodInfo), _targetTypeCheckMessage); + } + + /// + /// Creates a representing the + /// for a generic instance method. This helper method allows the instance to be created prior to the creation of + /// any instances of the target type. The following example shows the creation of an instance representing the + /// method: + /// + /// + /// FuncInstanceMethodInfo1<object, object, int>.Create(obj => obj.Equals) + /// + /// + /// The expression which creates the delegate for an instance of the target type. + /// A representing the + /// for the generic instance method. + public static FuncInstanceMethodInfo1 Create(Expression>> expression) + { + if (!(expression is { Body: UnaryExpression { Operand: MethodCallExpression methodCallExpression } })) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + + // Verify that we are calling MethodInfo.CreateDelegate(Type, object) + Contracts.CheckParam(methodCallExpression.Method.DeclaringType == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.Name == nameof(MethodInfo.CreateDelegate), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters().Length == 2, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters()[0].ParameterType == typeof(Type), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters()[1].ParameterType == typeof(object), nameof(expression), "Unexpected expression form"); + + // Verify that we are creating a delegate of type Func + Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form"); + + // Check the MethodInfo + Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + + var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value; + Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form"); + + return new FuncInstanceMethodInfo1(methodInfo); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo2`4.cs b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo2`4.cs new file mode 100644 index 0000000000..1a4f9c72ff --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo2`4.cs @@ -0,0 +1,91 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// The method is an instance method on an object of type . + /// Two generic type arguments. + /// A return value of . + /// + /// + /// The type of the receiver of the instance method. + /// The type of the first parameter of the method. + /// The type of the second parameter of the method. + /// The type of the return value of the method. + internal sealed class FuncInstanceMethodInfo2 : FuncMethodInfo2 + where TTarget : class + { + private static readonly string _targetTypeCheckMessage = $"Should have a target type of '{typeof(TTarget)}'"; + + public FuncInstanceMethodInfo2(Func function) + : this(function.Method) + { + } + + private FuncInstanceMethodInfo2(MethodInfo methodInfo) + : base(methodInfo) + { + Contracts.CheckParam(!GenericMethodDefinition.IsStatic, nameof(methodInfo), "Should be an instance method"); + Contracts.CheckParam(GenericMethodDefinition.DeclaringType == typeof(TTarget), nameof(methodInfo), _targetTypeCheckMessage); + } + + /// + /// Creates a representing the + /// for a generic instance method. This helper method allows the instance to be created prior to the creation of + /// any instances of the target type. The following example shows the creation of an instance representing the + /// method: + /// + /// + /// FuncInstanceMethodInfo1<object, object, int>.Create(obj => obj.Equals) + /// + /// + /// The expression which creates the delegate for an instance of the target type. + /// A representing the + /// for the generic instance method. + public static FuncInstanceMethodInfo2 Create(Expression>> expression) + { + if (!(expression is { Body: UnaryExpression { Operand: MethodCallExpression methodCallExpression } })) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + + // Verify that we are calling MethodInfo.CreateDelegate(Type, object) + Contracts.CheckParam(methodCallExpression.Method.DeclaringType == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.Name == nameof(MethodInfo.CreateDelegate), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters().Length == 2, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters()[0].ParameterType == typeof(Type), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters()[1].ParameterType == typeof(object), nameof(expression), "Unexpected expression form"); + + // Verify that we are creating a delegate of type Func + Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form"); + + // Check the MethodInfo + Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + + var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value; + Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form"); + + return new FuncInstanceMethodInfo2(methodInfo); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo3`4.cs b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo3`4.cs new file mode 100644 index 0000000000..51b8dbbe34 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo3`4.cs @@ -0,0 +1,91 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// The method is an instance method on an object of type . + /// Three generic type arguments. + /// A return value of . + /// + /// + /// The type of the receiver of the instance method. + /// The type of the first parameter of the method. + /// The type of the second parameter of the method. + /// The type of the return value of the method. + internal sealed class FuncInstanceMethodInfo3 : FuncMethodInfo3 + where TTarget : class + { + private static readonly string _targetTypeCheckMessage = $"Should have a target type of '{typeof(TTarget)}'"; + + public FuncInstanceMethodInfo3(Func function) + : this(function.Method) + { + } + + private FuncInstanceMethodInfo3(MethodInfo methodInfo) + : base(methodInfo) + { + Contracts.CheckParam(!GenericMethodDefinition.IsStatic, nameof(methodInfo), "Should be an instance method"); + Contracts.CheckParam(GenericMethodDefinition.DeclaringType == typeof(TTarget), nameof(methodInfo), _targetTypeCheckMessage); + } + + /// + /// Creates a representing the + /// for a generic instance method. This helper method allows the instance to be created prior to the creation of + /// any instances of the target type. The following example shows the creation of an instance representing the + /// method: + /// + /// + /// FuncInstanceMethodInfo1<object, object, int>.Create(obj => obj.Equals) + /// + /// + /// The expression which creates the delegate for an instance of the target type. + /// A representing the + /// for the generic instance method. + public static FuncInstanceMethodInfo3 Create(Expression>> expression) + { + if (!(expression is { Body: UnaryExpression { Operand: MethodCallExpression methodCallExpression } })) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + + // Verify that we are calling MethodInfo.CreateDelegate(Type, object) + Contracts.CheckParam(methodCallExpression.Method.DeclaringType == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.Name == nameof(MethodInfo.CreateDelegate), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters().Length == 2, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters()[0].ParameterType == typeof(Type), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters()[1].ParameterType == typeof(object), nameof(expression), "Unexpected expression form"); + + // Verify that we are creating a delegate of type Func + Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form"); + + // Check the MethodInfo + Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + + var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value; + Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form"); + + return new FuncInstanceMethodInfo3(methodInfo); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncMethodInfo1`3.cs b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo1`3.cs new file mode 100644 index 0000000000..daf997ae3e --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo1`3.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Collections.Immutable; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// One generic type argument. + /// A return value of . + /// + /// + /// The type of the first parameter of the method. + /// The type of the second parameter of the method. + /// The type of the return value of the method. + internal abstract class FuncMethodInfo1 : FuncMethodInfo + { + private ImmutableDictionary _instanceMethodInfo; + + private protected FuncMethodInfo1(MethodInfo methodInfo) + : base(methodInfo) + { + _instanceMethodInfo = ImmutableDictionary.Empty; + + Contracts.CheckParam(GenericMethodDefinition.GetGenericArguments().Length == 1, nameof(methodInfo), + "Should have exactly one generic type parameter but does not"); + } + + public MethodInfo MakeGenericMethod(Type typeArg1) + { + return ImmutableInterlocked.GetOrAdd( + ref _instanceMethodInfo, + typeArg1, + (typeArg, methodInfo) => methodInfo.MakeGenericMethod(typeArg), + GenericMethodDefinition); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncMethodInfo2`3.cs b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo2`3.cs new file mode 100644 index 0000000000..022930127d --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo2`3.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Collections.Immutable; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// Two generic type arguments. + /// A return value of . + /// + /// + /// The type of the first parameter of the method. + /// The type of the second parameter of the method. + /// The type of the return value of the method. + internal abstract class FuncMethodInfo2 : FuncMethodInfo + { + private ImmutableDictionary<(Type, Type), MethodInfo> _instanceMethodInfo; + + private protected FuncMethodInfo2(MethodInfo methodInfo) + : base(methodInfo) + { + _instanceMethodInfo = ImmutableDictionary<(Type, Type), MethodInfo>.Empty; + + Contracts.CheckParam(GenericMethodDefinition.GetGenericArguments().Length == 2, nameof(methodInfo), + "Should have exactly two generic type parameters but does not"); + } + + public MethodInfo MakeGenericMethod(Type typeArg1, Type typeArg2) + { + return ImmutableInterlocked.GetOrAdd( + ref _instanceMethodInfo, + (typeArg1, typeArg2), + (args, methodInfo) => methodInfo.MakeGenericMethod(args.Item1, args.Item2), + GenericMethodDefinition); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncMethodInfo3`3.cs b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo3`3.cs new file mode 100644 index 0000000000..fb287a8a58 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo3`3.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Collections.Immutable; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// Three generic type arguments. + /// A return value of . + /// + /// + /// The type of the first parameter of the method. + /// The type of the second parameter of the method. + /// The type of the return value of the method. + internal abstract class FuncMethodInfo3 : FuncMethodInfo + { + private ImmutableDictionary<(Type, Type, Type), MethodInfo> _instanceMethodInfo; + + private protected FuncMethodInfo3(MethodInfo methodInfo) + : base(methodInfo) + { + _instanceMethodInfo = ImmutableDictionary<(Type, Type, Type), MethodInfo>.Empty; + + Contracts.CheckParam(GenericMethodDefinition.GetGenericArguments().Length == 3, nameof(methodInfo), + "Should have exactly three generic type parameters but does not"); + } + + public MethodInfo MakeGenericMethod(Type typeArg1, Type typeArg2, Type typeArg3) + { + return ImmutableInterlocked.GetOrAdd( + ref _instanceMethodInfo, + (typeArg1, typeArg2, typeArg3), + (args, methodInfo) => methodInfo.MakeGenericMethod(args.Item1, args.Item2, args.Item3), + GenericMethodDefinition); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncMethodInfo`3.cs b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo`3.cs new file mode 100644 index 0000000000..e77d026eec --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo`3.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + internal abstract class FuncMethodInfo + { + private protected FuncMethodInfo(MethodInfo methodInfo) + { + Contracts.CheckValue(methodInfo, nameof(methodInfo)); + + Contracts.CheckParam(methodInfo.IsGenericMethod, nameof(methodInfo), "Should be generic but is not"); + + GenericMethodDefinition = methodInfo.GetGenericMethodDefinition(); + Contracts.CheckParam(typeof(TResult).IsAssignableFrom(GenericMethodDefinition.ReturnType), nameof(methodInfo), "Cannot be generic on return type"); + } + + protected MethodInfo GenericMethodDefinition { get; } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo1`3.cs b/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo1`3.cs new file mode 100644 index 0000000000..8c64785c1a --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo1`3.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// The method is static. + /// One generic type argument. + /// A return value of . + /// + /// + /// The type of the first parameter of the method. + /// The type of the second parameter of the method. + /// The type of the return value of the method. + internal sealed class FuncStaticMethodInfo1 : FuncMethodInfo1 + { + public FuncStaticMethodInfo1(Func function) + : base(function.Method) + { + Contracts.CheckParam(GenericMethodDefinition.IsStatic, nameof(function), "Should be a static method"); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo2`3.cs b/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo2`3.cs new file mode 100644 index 0000000000..d48b5d778c --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo2`3.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// The method is static. + /// Two generic type arguments. + /// A return value of . + /// + /// + /// The type of the first parameter of the method. + /// The type of the second parameter of the method. + /// The type of the return value of the method. + internal sealed class FuncStaticMethodInfo2 : FuncMethodInfo2 + { + public FuncStaticMethodInfo2(Func function) + : base(function.Method) + { + Contracts.CheckParam(GenericMethodDefinition.IsStatic, nameof(function), "Should be a static method"); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo3`3.cs b/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo3`3.cs new file mode 100644 index 0000000000..4612927134 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo3`3.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// The method is static. + /// Three generic type arguments. + /// A return value of . + /// + /// + /// The type of the first parameter of the method. + /// The type of the second parameter of the method. + /// The type of the return value of the method. + internal sealed class FuncStaticMethodInfo3 : FuncMethodInfo3 + { + public FuncStaticMethodInfo3(Func function) + : base(function.Method) + { + Contracts.CheckParam(GenericMethodDefinition.IsStatic, nameof(function), "Should be a static method"); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/Utils.cs b/src/Microsoft.ML.Core/Utilities/Utils.cs index 4912d7f84d..698a97fee1 100644 --- a/src/Microsoft.ML.Core/Utilities/Utils.cs +++ b/src/Microsoft.ML.Core/Utilities/Utils.cs @@ -966,14 +966,6 @@ private static MethodInfo MarshalInvokeCheckAndCreate(Type genArg, Delegat return meth; } - private static MethodInfo MarshalInvokeCheckAndCreate(Type[] genArgs, Delegate func) - { - var meth = MarshalActionInvokeCheckAndCreate(genArgs, func); - if (meth.ReturnType != typeof(TRet)) - throw Contracts.ExceptParam(nameof(func), "Cannot be generic on return type"); - return meth; - } - // REVIEW: n-argument versions? The multi-column re-application problem? // Think about how to address these. @@ -1053,10 +1045,58 @@ public static TResult MarshalInvoke(FuncStaticMethodInfo3 /// A two-argument version of . /// - public static TRet MarshalInvoke(Func func, Type genArg, TArg1 arg1, TArg2 arg2) + public static TResult MarshalInvoke(FuncInstanceMethodInfo1 func, TTarget target, Type genArg, TArg1 arg1, TArg2 arg2) + where TTarget : class { - var meth = MarshalInvokeCheckAndCreate(genArg, func); - return (TRet)meth.Invoke(func.Target, new object[] { arg1, arg2 }); + var meth = func.MakeGenericMethod(genArg); + return (TResult)meth.Invoke(target, new object[] { arg1, arg2 }); + } + + /// + /// A two-argument version of . + /// + public static TResult MarshalInvoke(FuncStaticMethodInfo1 func, Type genArg, TArg1 arg1, TArg2 arg2) + { + var meth = func.MakeGenericMethod(genArg); + return (TResult)meth.Invoke(null, new object[] { arg1, arg2 }); + } + + /// + /// A two-argument, two-type-parameter version of . + /// + public static TResult MarshalInvoke(FuncInstanceMethodInfo2 func, TTarget target, Type genArg1, Type genArg2, TArg1 arg1, TArg2 arg2) + where TTarget : class + { + var meth = func.MakeGenericMethod(genArg1, genArg2); + return (TResult)meth.Invoke(target, new object[] { arg1, arg2 }); + } + + /// + /// A two-argument, two-type-parameter version of . + /// + public static TResult MarshalInvoke(FuncStaticMethodInfo2 func, Type genArg1, Type genArg2, TArg1 arg1, TArg2 arg2) + { + var meth = func.MakeGenericMethod(genArg1, genArg2); + return (TResult)meth.Invoke(null, new object[] { arg1, arg2 }); + } + + /// + /// A two-argument, three-type-parameter version of . + /// + public static TResult MarshalInvoke(FuncInstanceMethodInfo3 func, TTarget target, Type genArg1, Type genArg2, Type genArg3, TArg1 arg1, TArg2 arg2) + where TTarget : class + { + var meth = func.MakeGenericMethod(genArg1, genArg2, genArg3); + return (TResult)meth.Invoke(target, new object[] { arg1, arg2 }); + } + + /// + /// A two-argument, three-type-parameter version of . + /// + public static TResult MarshalInvoke(FuncStaticMethodInfo3 func, Type genArg1, Type genArg2, Type genArg3, TArg1 arg1, TArg2 arg2) + { + var meth = func.MakeGenericMethod(genArg1, genArg2, genArg3); + return (TResult)meth.Invoke(null, new object[] { arg1, arg2 }); } /// @@ -1141,17 +1181,6 @@ public static TRet MarshalInvoke - /// A 2 argument and n type version of . - /// - public static TRet MarshalInvoke( - Func func, - Type[] genArgs, TArg1 arg1, TArg2 arg2) - { - var meth = MarshalInvokeCheckAndCreate(genArgs, func); - return (TRet)meth.Invoke(func.Target, new object[] { arg1, arg2}); - } - private static MethodInfo MarshalActionInvokeCheckAndCreate(Type genArg, Delegate func) { Contracts.CheckValue(genArg, nameof(genArg)); @@ -1164,18 +1193,6 @@ private static MethodInfo MarshalActionInvokeCheckAndCreate(Type genArg, Delegat return meth; } - private static MethodInfo MarshalActionInvokeCheckAndCreate(Type[] typeArguments, Delegate func) - { - Contracts.CheckValue(typeArguments, nameof(typeArguments)); - Contracts.CheckValue(func, nameof(func)); - var meth = func.GetMethodInfo(); - Contracts.CheckParam(meth.IsGenericMethod, nameof(func), "Should be generic but is not"); - Contracts.CheckParam(meth.GetGenericArguments().Length == typeArguments.Length, nameof(func), - "Method should have exactly the same number of generic type parameters as list passed in but it does not."); - meth = meth.GetGenericMethodDefinition().MakeGenericMethod(typeArguments); - return meth; - } - /// /// This is akin to , except applied to /// instead of . diff --git a/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs b/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs index ed6739eb7d..12cb66c79c 100644 --- a/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs @@ -19,6 +19,9 @@ namespace Microsoft.ML.Data.Commands { internal sealed class TypeInfoCommand : ICommand { + private static readonly FuncInstanceMethodInfo1 _kindReportMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.KindReport); + internal const string LoadName = "TypeInfo"; internal const string Summary = "Displays information about the standard primitive " + "non-key types, and conversions between them."; @@ -89,7 +92,7 @@ public void Run() for (int i = 0; i < types.Length; ++i) { ch.AssertValue(types[i]); - var info = Utils.MarshalInvoke(KindReport, types[i].RawType, ch, types[i]); + var info = Utils.MarshalInvoke(_kindReportMethodInfo, this, types[i].RawType, ch, types[i]); var dstKinds = new HashSet(); Delegate del; diff --git a/src/Microsoft.ML.Data/Data/DataViewUtils.cs b/src/Microsoft.ML.Data/Data/DataViewUtils.cs index ad74142865..0524d4c728 100644 --- a/src/Microsoft.ML.Data/Data/DataViewUtils.cs +++ b/src/Microsoft.ML.Data/Data/DataViewUtils.cs @@ -289,6 +289,12 @@ public static DataViewRowCursor ConsolidateGeneric(IChannelProvider provider, Da /// private sealed class Splitter { + private static readonly FuncStaticMethodInfo1 _getPoolCoreMethodInfo + = new FuncStaticMethodInfo1(GetPoolCore); + + private static readonly FuncInstanceMethodInfo1 _createInPipeMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.CreateInPipe); + private readonly DataViewSchema _schema; private readonly object[] _cachePools; @@ -476,9 +482,7 @@ private static DataViewRowCursor ConsolidateCore(IChannelProvider provider, Data private static object GetPool(DataViewType type, object[] pools, int poolIdx) { - Func func = GetPoolCore; - var method = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.RawType); - return method.Invoke(null, new object[] { pools, poolIdx }); + return Utils.MarshalInvoke(_getPoolCoreMethodInfo, type.RawType, pools, poolIdx); } private static MadeObjectPool GetPoolCore(object[] pools, int poolIdx) @@ -519,9 +523,6 @@ private DataViewRowCursor[] SplitCore(IChannelProvider ch, DataViewRowCursor inp int[] colToActive; Utils.BuildSubsetMaps(_schema, input.IsColumnActive, out activeToCol, out colToActive); - Func createFunc = CreateInPipe; - var inGenMethod = createFunc.GetMethodInfo().GetGenericMethodDefinition(); - object[] arguments = new object[] { input, 0 }; // Only one set of in-pipes, one per column, as well as for extra side information. InPipe[] inPipes = new InPipe[activeToCol.Length + (int)ExtraIndex._Lim]; // There are as many sets of out pipes as there are output cursors. @@ -537,9 +538,8 @@ private DataViewRowCursor[] SplitCore(IChannelProvider ch, DataViewRowCursor inp var column = input.Schema[activeToCol[c]]; ch.Assert(input.IsColumnActive(column)); ch.Assert(column.Type.IsCacheable()); - arguments[1] = activeToCol[c]; var inPipe = inPipes[c] = - (InPipe)inGenMethod.MakeGenericMethod(column.Type.RawType).Invoke(this, arguments); + Utils.MarshalInvoke(_createInPipeMethodInfo, this, column.Type.RawType, input, activeToCol[c]); for (int i = 0; i < cthd; ++i) outPipes[i][c] = inPipe.CreateOutPipe(column.Type); } diff --git a/src/Microsoft.ML.Data/Data/RowCursorUtils.cs b/src/Microsoft.ML.Data/Data/RowCursorUtils.cs index 4279b3066d..a34070d791 100644 --- a/src/Microsoft.ML.Data/Data/RowCursorUtils.cs +++ b/src/Microsoft.ML.Data/Data/RowCursorUtils.cs @@ -16,6 +16,12 @@ namespace Microsoft.ML.Data [BestFriend] internal static class RowCursorUtils { + private static readonly FuncStaticMethodInfo1 _getGetterAsDelegateCoreMethodInfo + = new FuncStaticMethodInfo1(GetGetterAsDelegateCore); + + private static readonly FuncStaticMethodInfo1> _getIsNewGroupDelegateCoreMethodInfo + = new FuncStaticMethodInfo1>(GetIsNewGroupDelegateCore); + /// /// Returns an appropriate for a row given an active column /// index, but as a delegate. The type parameter for the delegate will correspond to the @@ -30,8 +36,7 @@ public static Delegate GetGetterAsDelegate(DataViewRow row, int col) Contracts.CheckParam(0 <= col && col < row.Schema.Count, nameof(col)); Contracts.CheckParam(row.IsColumnActive(row.Schema[col]), nameof(col), "column was not active"); - Func getGetter = GetGetterAsDelegateCore; - return Utils.MarshalInvoke(getGetter, row.Schema[col].Type.RawType, row, col); + return Utils.MarshalInvoke(_getGetterAsDelegateCoreMethodInfo, row.Schema[col].Type.RawType, row, col); } private static Delegate GetGetterAsDelegateCore(DataViewRow row, int col) @@ -302,7 +307,7 @@ public static Func GetIsNewGroupDelegate(DataViewRow cursor, int col) Contracts.Check(0 <= col && col < cursor.Schema.Count); DataViewType type = cursor.Schema[col].Type; Contracts.Check(type is KeyDataViewType); - return Utils.MarshalInvoke(GetIsNewGroupDelegateCore, type.RawType, cursor, col); + return Utils.MarshalInvoke(_getIsNewGroupDelegateCoreMethodInfo, type.RawType, cursor, col); } private static Func GetIsNewGroupDelegateCore(DataViewRow cursor, int col) diff --git a/src/Microsoft.ML.Data/DataDebuggerPreview.cs b/src/Microsoft.ML.Data/DataDebuggerPreview.cs index f74d8aeeb4..524891e843 100644 --- a/src/Microsoft.ML.Data/DataDebuggerPreview.cs +++ b/src/Microsoft.ML.Data/DataDebuggerPreview.cs @@ -16,6 +16,9 @@ namespace Microsoft.ML.Data /// public sealed class DataDebuggerPreview { + private static readonly FuncInstanceMethodInfo1>> _makeSetterMethodInfo + = FuncInstanceMethodInfo1>>.Create(target => target.MakeSetter); + internal static class Defaults { public const int MaxRows = 100; @@ -42,7 +45,7 @@ internal DataDebuggerPreview(IDataView data, int maxRows = Defaults.MaxRows) { var setters = new Action>[n]; for (int i = 0; i < n; i++) - setters[i] = Utils.MarshalInvoke(MakeSetter, data.Schema[i].Type.RawType, cursor, i); + setters[i] = Utils.MarshalInvoke(_makeSetterMethodInfo, this, data.Schema[i].Type.RawType, cursor, i); int count = 0; while (count < maxRows && cursor.MoveNext()) diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs index 80c1b59a06..b488fd5eb0 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs @@ -28,6 +28,9 @@ namespace Microsoft.ML.Data.IO [BestFriend] internal sealed class BinarySaver : IDataSaver { + private static readonly FuncInstanceMethodInfo1 _loadValueMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.LoadValue); + public sealed class Arguments { [Argument(ArgumentType.LastOccurrenceWins, HelpText = "The compression scheme to use for the blocks", ShortName = "comp")] @@ -887,22 +890,20 @@ public bool TryLoadTypeAndValue(Stream stream, out DataViewType type, out object value = null; return false; } - type = codec.Type; - Func, object> func = LoadValue; - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(codec.Type.RawType); - value = (meth.Invoke(this, new object[] { stream, codec })); + type = codec.Type; + value = Utils.MarshalInvoke(_loadValueMethodInfo, this, type.RawType, stream, codec); return true; } /// /// Deserializes and returns a value given a stream and codec. /// - private object LoadValue(Stream stream, IValueCodec codec) + private object LoadValue(Stream stream, IValueCodec codec) { _host.Assert(typeof(T) == codec.Type.RawType); T value = default(T); - using (var reader = codec.OpenReader(stream, 1)) + using (var reader = ((IValueCodec)codec).OpenReader(stream, 1)) { reader.MoveNext(); reader.Get(ref value); diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index 9c5e09d2bc..8d79076c05 100644 --- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -19,6 +19,9 @@ namespace Microsoft.ML.Data [BestFriend] internal static class DataViewConstructionUtils { + private static readonly FuncStaticMethodInfo1 _getAnnotationInfoMethodInfo + = new FuncStaticMethodInfo1(GetAnnotationInfo); + public static IDataView CreateFromList(IHostEnvironment env, IList data, SchemaDefinition schemaDefinition = null) where TRow : class @@ -73,7 +76,7 @@ internal static SchemaDefinition GetSchemaDefinition(IHostEnvironment env, { foreach (var annotation in annotations.Schema) { - var info = Utils.MarshalInvoke(GetAnnotationInfo, annotation.Type.RawType, annotation.Name, annotations); + var info = Utils.MarshalInvoke(_getAnnotationInfoMethodInfo, annotation.Type.RawType, annotation.Name, annotations); schemaDefinitionCol.AddAnnotation(annotation.Name , info); } } @@ -191,6 +194,9 @@ private static readonly FuncInstanceMethodInfo1, Delegate, De private static readonly FuncInstanceMethodInfo1, Delegate, Delegate> _createDirectGetterDelegateMethodInfo = FuncInstanceMethodInfo1, Delegate, Delegate>.Create(target => target.CreateDirectGetterDelegate); + private static readonly FuncInstanceMethodInfo1, Delegate, DataViewType, Delegate> _createKeyGetterDelegateMethodInfo + = FuncInstanceMethodInfo1, Delegate, DataViewType, Delegate>.Create(target => target.CreateKeyGetterDelegate); + private readonly int _colCount; private readonly Delegate[] _getters; protected readonly IHost Host; @@ -274,8 +280,7 @@ private Delegate CreateGetter(DataViewType colType, InternalSchemaDefinition.Col else { var keyRawType = colType.RawType; - Func delForKey = CreateKeyGetterDelegate; - return Utils.MarshalInvoke(delForKey, keyRawType, peek, colType); + return Utils.MarshalInvoke(_createKeyGetterDelegateMethodInfo, this, keyRawType, peek, colType); } } else if (DataViewTypeManager.Knows(colType)) diff --git a/src/Microsoft.ML.Data/DataView/Transposer.cs b/src/Microsoft.ML.Data/DataView/Transposer.cs index c424ec45be..77dc58275b 100644 --- a/src/Microsoft.ML.Data/DataView/Transposer.cs +++ b/src/Microsoft.ML.Data/DataView/Transposer.cs @@ -892,6 +892,9 @@ private Func CreateInputPredicate(Func pred, out bool[] ac /// private abstract class Splitter { + private static readonly FuncStaticMethodInfo1 _createCoreMethodInfo + = new FuncStaticMethodInfo1(CreateCore); + private readonly IDataView _view; private readonly int _col; public abstract int ColumnCount { get; } @@ -922,7 +925,7 @@ public static Splitter Create(IDataView view, int col) Contracts.Assert(type is PrimitiveDataViewType || vectorSize > 0); const int defaultSplitThreshold = 16; if (vectorSize <= defaultSplitThreshold) - return Utils.MarshalInvoke(CreateCore, type.RawType, view, col); + return Utils.MarshalInvoke(_createCoreMethodInfo, type.RawType, view, col); else { // There are serious practical problems with trying to save many thousands of columns. @@ -1321,6 +1324,9 @@ public override ValueGetter GetGetter(DataViewSchema.Column colu internal static class TransposerUtils { + private static readonly FuncStaticMethodInfo1 _getRowCursorShimCoreMethodInfo + = new FuncStaticMethodInfo1(GetRowCursorShimCore); + private static readonly FuncInstanceMethodInfo1 _slotCursorGetGetterMethodInfo = FuncInstanceMethodInfo1.Create(target => target.GetGetter); @@ -1387,7 +1393,7 @@ public static DataViewRowCursor GetRowCursorShim(IChannelProvider provider, Slot Contracts.CheckValue(provider, nameof(provider)); provider.CheckValue(cursor, nameof(cursor)); - return Utils.MarshalInvoke(GetRowCursorShimCore, cursor.GetSlotType().ItemType.RawType, provider, cursor); + return Utils.MarshalInvoke(_getRowCursorShimCoreMethodInfo, cursor.GetSlotType().ItemType.RawType, provider, cursor); } private static DataViewRowCursor GetRowCursorShimCore(IChannelProvider provider, SlotCursor cursor) diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index effd0eac04..9265e60424 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -644,6 +644,9 @@ internal sealed class SchemaBindableCalibratedModelParameters _getPredictorGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.GetPredictorGetter); + private readonly SchemaBindableCalibratedModelParameters _parent; private readonly ISchemaBoundRowMapper _predictor; private readonly int _scoreCol; @@ -697,7 +700,7 @@ DataViewRow ISchemaBoundRowMapper.GetRow(DataViewRow input, IEnumerable, type.RawType, predictorRow, column.Index); + getters[column.Index] = Utils.MarshalInvoke(_getPredictorGetterMethodInfo, this, type.RawType, predictorRow, column.Index); } if (hasProbabilityColumn) diff --git a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs index 7d13b204b1..8a98c9bad6 100644 --- a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs +++ b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs @@ -91,6 +91,9 @@ private static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadConte /// private sealed class BindableMapper : ISchemaBindableMapper, ICanSaveModel, IPredictor { + private static readonly FuncInstanceMethodInfo1 _getValueGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.GetValueGetter); + private readonly int _topContributionsCount; private readonly int _bottomContributionsCount; private readonly bool _normalize; @@ -202,12 +205,10 @@ public Delegate GetContributionGetter(DataViewRow input, int colSrc) Contracts.Check(0 <= colSrc && colSrc < input.Schema.Count); var typeSrc = input.Schema[colSrc].Type; - Func>> del = GetValueGetter; // REVIEW: Assuming Feature contributions will be VBuffer. // For multiclass LR it needs to be(VBuffer[]. - var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeSrc.RawType); - return (Delegate)meth.Invoke(this, new object[] { input, colSrc }); + return Utils.MarshalInvoke(_getValueGetterMethodInfo, this, typeSrc.RawType, input, colSrc); } private ReadOnlyMemory GetSlotName(int index, VBuffer> slotNames) diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index 3f2e517d31..abc1eee4d4 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -31,6 +31,9 @@ public abstract class ThresholdArgumentsBase : ScorerArgumentsBase [BestFriend] private protected sealed class BindingsImpl : BindingsBase { + private static readonly FuncStaticMethodInfo1 _keyValueMetadataFromMetadataMethodInfo + = new FuncStaticMethodInfo1(KeyValueMetadataFromMetadata); + // Column index of the score column in Mapper's schema. public readonly int ScoreColumnIndex; // The type of the derived column. @@ -66,7 +69,7 @@ private BindingsImpl(DataViewSchema input, ISchemaBoundRowMapper mapper, string if (trainLabelColumn?.Type is VectorDataViewType trainLabelColVecType && (ulong)trainLabelColVecType.Size == predColKeyType.Count) { Contracts.Assert(trainLabelColVecType.Size > 0); - _predColMetadata = Utils.MarshalInvoke(KeyValueMetadataFromMetadata, trainLabelColVecType.RawType, + _predColMetadata = Utils.MarshalInvoke(_keyValueMetadataFromMetadataMethodInfo, trainLabelColVecType.RawType, scoreColMetadata, trainLabelColumn.Value); } } diff --git a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs index 2bb2c40256..1bc5561942 100644 --- a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Reflection; using Microsoft.ML.CommandLine; +using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; namespace Microsoft.ML.Data @@ -18,6 +19,9 @@ namespace Microsoft.ML.Data /// internal abstract class RowToRowScorerBase : RowToRowMapperTransformBase, IDataScorerTransform { + private static readonly FuncStaticMethodInfo1 _getGetterFromRowMethodInfo + = new FuncStaticMethodInfo1(GetGetterFromRow); + [BestFriend] private protected abstract class BindingsBase : ScorerBindingsBase { @@ -204,9 +208,7 @@ protected static Delegate GetGetterFromRow(DataViewRow row, int col) Contracts.Assert(row.IsColumnActive(row.Schema[col])); var type = row.Schema[col].Type; - Func> del = GetGetterFromRow; - var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.RawType); - return (Delegate)meth.Invoke(null, new object[] { row, col }); + return Utils.MarshalInvoke(_getGetterFromRowMethodInfo, type.RawType, row, col); } protected static ValueGetter GetGetterFromRow(DataViewRow output, int col) diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index 77ae754110..c8104cd502 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -34,6 +34,9 @@ namespace Microsoft.ML.Data internal abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper, ICanSaveModel, ICanSaveSummary, IBindableCanSavePfa, IBindableCanSaveOnnx { + private static readonly FuncInstanceMethodInfo2 _getValueGetterMethodInfo + = FuncInstanceMethodInfo2.Create(target => target.GetValueGetter); + // The ctor guarantees that Predictor is non-null. It also ensures that either // ValueMapper or FloatPredictor is non-null (or both). With these guarantees, // the score value type (_scoreType) can be determined. @@ -157,9 +160,7 @@ protected virtual Delegate GetPredictionGetter(DataViewRow input, int colSrc) Contracts.Assert(0 <= colSrc && colSrc < input.Schema.Count); var typeSrc = input.Schema[colSrc].Type; - Func> del = GetValueGetter; - var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeSrc.RawType, ScoreType.RawType); - return (Delegate)meth.Invoke(this, new object[] { input, colSrc }); + return Utils.MarshalInvoke(_getValueGetterMethodInfo, this, typeSrc.RawType, ScoreType.RawType, input, colSrc); } private ValueGetter GetValueGetter(DataViewRow input, int colSrc) diff --git a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs index 2144ef33bf..a8b9b8dc18 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs @@ -195,6 +195,9 @@ private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx { + private static readonly FuncStaticMethodInfo1 _makeGetterMethodInfo + = new FuncStaticMethodInfo1(MakeGetter); + private readonly DataViewSchema _schema; private readonly (string outputColumnName, string inputColumnName)[] _columns; @@ -213,14 +216,14 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func(DataViewRow row, int index) - => input.GetGetter(input.Schema[index]); - input.Schema.TryGetColumnIndex(_columns[iinfo].inputColumnName, out int colIndex); var type = input.Schema[colIndex].Type; - return Utils.MarshalInvoke(MakeGetter, type.RawType, input, colIndex); + return Utils.MarshalInvoke(_makeGetterMethodInfo, type.RawType, input, colIndex); } + private static Delegate MakeGetter(DataViewRow row, int index) + => row.GetGetter(row.Schema[index]); + protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() { var result = new DataViewSchema.DetachedColumn[_columns.Length]; diff --git a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs index be7e3fc551..87b8ff18a0 100644 --- a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs @@ -164,6 +164,9 @@ private protected override IRowMapper MakeRowMapper(DataViewSchema schema) private class Mapper : OneToOneMapperBase { + private static readonly FuncInstanceMethodInfo1 _getValueGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.GetValueGetter); + private readonly FeatureContributionCalculatingTransformer _parent; private readonly VBuffer> _slotNames; private readonly int _featureColumnIndex; @@ -204,7 +207,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func. // For multiclass LR it needs to be VBuffer[]. - return Utils.MarshalInvoke(GetValueGetter, _featureColumnType.RawType, input, ColMapNewToOld[iinfo]); + return Utils.MarshalInvoke(_getValueGetterMethodInfo, this, _featureColumnType.RawType, input, ColMapNewToOld[iinfo]); } private Delegate GetValueGetter(DataViewRow input, int colSrc) diff --git a/src/Microsoft.ML.Data/Transforms/NAFilter.cs b/src/Microsoft.ML.Data/Transforms/NAFilter.cs index 5e34f3e8c6..21a52335ac 100644 --- a/src/Microsoft.ML.Data/Transforms/NAFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/NAFilter.cs @@ -250,6 +250,12 @@ private sealed class Cursor : LinkedRowFilterCursorBase { private abstract class Value { + private static readonly FuncStaticMethodInfo1 _createOneMethodInfo + = new FuncStaticMethodInfo1(CreateOne); + + private static readonly FuncStaticMethodInfo1 _createVecMethodInfo + = new FuncStaticMethodInfo1(CreateVec); + protected readonly Cursor Cursor; protected Value(Cursor cursor) @@ -267,18 +273,20 @@ public static Value Create(Cursor cursor, ColInfo info) Contracts.AssertValue(cursor); Contracts.AssertValue(info); - MethodInfo meth; + FuncStaticMethodInfo1 method; + Type genericArgument; if (info.Type is VectorDataViewType vecType) { - Func> d = CreateVec; - meth = d.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(vecType.ItemType.RawType); + method = _createVecMethodInfo; + genericArgument = vecType.ItemType.RawType; } else { - Func> d = CreateOne; - meth = d.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(info.Type.RawType); + method = _createOneMethodInfo; + genericArgument = info.Type.RawType; } - return (Value)meth.Invoke(null, new object[] { cursor, info }); + + return Utils.MarshalInvoke(method, genericArgument, cursor, info); } private static ValueOne CreateOne(Cursor cursor, ColInfo info) diff --git a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs index 60e32327ff..ac38f6c7fe 100644 --- a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs @@ -451,6 +451,12 @@ private static readonly FuncInstanceMethodInfo1 _makeOneTrivia private static readonly FuncInstanceMethodInfo1 _makeVecTrivialGetterMethodInfo = FuncInstanceMethodInfo1.Create(target => target.MakeVecTrivialGetter); + private static readonly FuncInstanceMethodInfo1 _makeVecGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.MakeVecGetter); + + private static readonly FuncInstanceMethodInfo1 _getSrcGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.GetSrcGetter); + private readonly SlotsDroppingTransformer _parent; private readonly int[] _cols; private readonly DataViewType[] _srcTypes; @@ -780,9 +786,7 @@ private Delegate MakeVecGetter(DataViewRow input, int iinfo) VectorDataViewType vectorType = (VectorDataViewType)_srcTypes[iinfo]; Host.Assert(!_suppressed[iinfo]); - Func>> del = MakeVecGetter; - var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(vectorType.ItemType.RawType); - return (Delegate)methodInfo.Invoke(this, new object[] { input, iinfo }); + return Utils.MarshalInvoke(_makeVecGetterMethodInfo, this, vectorType.ItemType.RawType, input, iinfo); } private ValueGetter> MakeVecGetter(DataViewRow input, int iinfo) @@ -816,9 +820,7 @@ private Delegate GetSrcGetter(DataViewType typeDst, DataViewRow row, int iinfo) Host.CheckValue(typeDst, nameof(typeDst)); Host.CheckValue(row, nameof(row)); - Func> del = GetSrcGetter; - var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeDst.RawType); - return (Delegate)methodInfo.Invoke(this, new object[] { row, iinfo }); + return Utils.MarshalInvoke(_getSrcGetterMethodInfo, this, typeDst.RawType, row, iinfo); } protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs index 2651af83a3..3fe8d0bc37 100644 --- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs @@ -473,6 +473,9 @@ private sealed class ColumnTmp : OneToOneColumn { } + private static readonly FuncInstanceMethodInfo1 _getSrcGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.GetSrcGetter); + private readonly Bindings _bindings; // The ColInfos are exposed to sub-classes. They should be considered readonly. @@ -700,9 +703,7 @@ protected Delegate GetSrcGetter(DataViewType typeDst, DataViewRow row, int iinfo Host.CheckValue(typeDst, nameof(typeDst)); Host.CheckValue(row, nameof(row)); - Func> del = GetSrcGetter; - var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeDst.RawType); - return (Delegate)methodInfo.Invoke(this, new object[] { row, iinfo }); + return Utils.MarshalInvoke(_getSrcGetterMethodInfo, this, typeDst.RawType, row, iinfo); } /// diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index df9f0d29fd..cfc315187b 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -701,6 +701,9 @@ private protected override IRowMapper MakeRowMapper(DataViewSchema schema) private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx, ISaveAsPfa { + private static readonly FuncInstanceMethodInfo1 _makeGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.MakeGetter); + private readonly DataViewType[] _types; private readonly ValueToKeyMappingTransformer _parent; private readonly ColInfo[] _infos; @@ -756,7 +759,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func, type.RawType, input, iinfo); + return Utils.MarshalInvoke(_makeGetterMethodInfo, this, type.RawType, input, iinfo); } private Delegate MakeGetter(DataViewRow row, int src) => _termMap[src].GetMappingGetter(row); diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs index 120510de4a..9dbe1bb2da 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs @@ -20,6 +20,9 @@ public sealed partial class ValueToKeyMappingTransformer /// private abstract class Builder { + private static readonly FuncStaticMethodInfo1 _createCoreMethodInfo + = new FuncStaticMethodInfo1(CreateCore); + /// /// The item type we are building into a term map. /// @@ -51,7 +54,7 @@ public static Builder Create(DataViewType type, ValueToKeyMappingEstimator.KeyOr Contracts.AssertValue(itemType); if (itemType is TextDataViewType) return new TextImpl(sorted); - return Utils.MarshalInvoke(CreateCore, itemType.RawType, itemType, sorted); + return Utils.MarshalInvoke(_createCoreMethodInfo, itemType.RawType, itemType, sorted); } private static Builder CreateCore(PrimitiveDataViewType type, bool sorted) @@ -1074,6 +1077,12 @@ public override void AddMetadata(DataViewSchema.Annotations.Builder builder) /// private sealed class KeyImpl : Base { + private static readonly FuncInstanceMethodInfo1, DataViewType, DataViewSchema.Annotations.Builder, bool> _addMetadataCoreMethodInfo + = FuncInstanceMethodInfo1, DataViewType, DataViewSchema.Annotations.Builder, bool>.Create(target => target.AddMetadataCore); + + private static readonly FuncInstanceMethodInfo1, PrimitiveDataViewType, TextWriter, bool> _writeTextTermsCoreMethodInfo + = FuncInstanceMethodInfo1, PrimitiveDataViewType, TextWriter, bool>.Create(target => target.WriteTextTermsCore); + public KeyImpl(IHostEnvironment env, DataViewSchema schema, TermMap map, ColInfo[] infos, bool[] textMetadata, int iinfo) : base(env, schema, map, infos, textMetadata, iinfo) { @@ -1088,7 +1097,7 @@ public override void AddMetadata(DataViewSchema.Annotations.Builder builder) _schema.TryGetColumnIndex(_infos[_iinfo].InputColumnName, out int srcCol); VectorDataViewType srcMetaType = _schema[srcCol].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType; if (srcMetaType == null || srcMetaType.Size != TypedMap.ItemType.GetKeyCountAsInt32(_host) || - TypedMap.ItemType.GetKeyCountAsInt32(_host) == 0 || !Utils.MarshalInvoke(AddMetadataCore, srcMetaType.ItemType.RawType, srcMetaType.ItemType, builder)) + TypedMap.ItemType.GetKeyCountAsInt32(_host) == 0 || !Utils.MarshalInvoke(_addMetadataCoreMethodInfo, this, srcMetaType.ItemType.RawType, srcMetaType.ItemType, builder)) { // No valid input key-value metadata. Back off to the base implementation. base.AddMetadata(builder); @@ -1169,7 +1178,7 @@ public override void WriteTextTerms(TextWriter writer) _schema.TryGetColumnIndex(_infos[_iinfo].InputColumnName, out int srcCol); VectorDataViewType srcMetaType = _schema[srcCol].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType; if (srcMetaType == null || srcMetaType.Size != TypedMap.ItemType.GetKeyCountAsInt32(_host) || - TypedMap.ItemType.GetKeyCountAsInt32(_host) == 0 || !Utils.MarshalInvoke(WriteTextTermsCore, srcMetaType.ItemType.RawType, srcMetaType.ItemType, writer)) + TypedMap.ItemType.GetKeyCountAsInt32(_host) == 0 || !Utils.MarshalInvoke(_writeTextTermsCoreMethodInfo, this, srcMetaType.ItemType.RawType, srcMetaType.ItemType, writer)) { // No valid input key-value metadata. Back off to the base implementation. base.WriteTextTerms(writer); diff --git a/src/Microsoft.ML.Data/Utils/ApiUtils.cs b/src/Microsoft.ML.Data/Utils/ApiUtils.cs index 6b5e5531ad..baece31a27 100644 --- a/src/Microsoft.ML.Data/Utils/ApiUtils.cs +++ b/src/Microsoft.ML.Data/Utils/ApiUtils.cs @@ -19,7 +19,16 @@ namespace Microsoft.ML internal static class ApiUtils { - private static readonly FuncStaticMethodInfo3 _generatePokeMethodInfo + private static readonly FuncStaticMethodInfo3 _generatePeekFieldMethodInfo + = new FuncStaticMethodInfo3(GeneratePeek); + + private static readonly FuncStaticMethodInfo3 _generatePeekPropertyMethodInfo + = new FuncStaticMethodInfo3(GeneratePeek); + + private static readonly FuncStaticMethodInfo3 _generatePokeFieldMethodInfo + = new FuncStaticMethodInfo3(GeneratePoke); + + private static readonly FuncStaticMethodInfo3 _generatePokePropertyMethodInfo = new FuncStaticMethodInfo3(GeneratePoke); private static OpCode GetAssignmentOpCode(Type t, IEnumerable attributes) @@ -62,21 +71,13 @@ internal static Delegate GeneratePeek(InternalSchemaDefinition.Colum { case FieldInfo fieldInfo: Type fieldType = fieldInfo.FieldType; - var assignmentOpCode = GetAssignmentOpCode(fieldType, fieldInfo.GetCustomAttributes()); - Func func = GeneratePeek; - var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() - .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); - return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode }); + return Utils.MarshalInvoke(_generatePeekFieldMethodInfo, typeof(TOwn), typeof(TRow), fieldType, fieldInfo, assignmentOpCode); case PropertyInfo propertyInfo: Type propertyType = propertyInfo.PropertyType; - var assignmentOpCodeProp = GetAssignmentOpCode(propertyType, propertyInfo.GetCustomAttributes()); - Func funcProp = GeneratePeek; - var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition() - .MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType); - return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo, assignmentOpCodeProp }); + return Utils.MarshalInvoke(_generatePeekPropertyMethodInfo, typeof(TOwn), typeof(TRow), propertyType, propertyInfo, assignmentOpCodeProp); default: Contracts.Assert(false); @@ -138,18 +139,13 @@ internal static Delegate GeneratePoke(InternalSchemaDefinition.Colum { case FieldInfo fieldInfo: Type fieldType = fieldInfo.FieldType; - var assignmentOpCode = GetAssignmentOpCode(fieldType, fieldInfo.GetCustomAttributes()); - Func func = GeneratePoke; - var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() - .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); - return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode }); + return Utils.MarshalInvoke(_generatePokeFieldMethodInfo, typeof(TOwn), typeof(TRow), fieldType, fieldInfo, assignmentOpCode); case PropertyInfo propertyInfo: Type propertyType = propertyInfo.PropertyType; - var assignmentOpCodeProp = GetAssignmentOpCode(propertyType, propertyInfo.GetCustomAttributes()); - return Utils.MarshalInvoke(_generatePokeMethodInfo, typeof(TOwn), typeof(TRow), propertyType, propertyInfo); + return Utils.MarshalInvoke(_generatePokePropertyMethodInfo, typeof(TOwn), typeof(TRow), propertyType, propertyInfo); default: Contracts.Assert(false); diff --git a/src/Microsoft.ML.Featurizers/CategoricalImputer.cs b/src/Microsoft.ML.Featurizers/CategoricalImputer.cs index b765bf1758..a297ebf1e5 100644 --- a/src/Microsoft.ML.Featurizers/CategoricalImputer.cs +++ b/src/Microsoft.ML.Featurizers/CategoricalImputer.cs @@ -737,6 +737,8 @@ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffe private sealed class Mapper : MapperBase { + private static readonly FuncInstanceMethodInfo1 _makeGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.MakeGetter); #region Class data members @@ -776,7 +778,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func, inputType, input, iinfo); + return Utils.MarshalInvoke(_makeGetterMethodInfo, this, inputType, input, iinfo); } private protected override Func GetDependenciesCore(Func activeOutput) diff --git a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs index 734f8e58bf..014961104b 100644 --- a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs +++ b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs @@ -701,6 +701,8 @@ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffe private sealed class Mapper : MapperBase { + private static readonly FuncInstanceMethodInfo2 _makeGetterMethodInfo + = FuncInstanceMethodInfo2.Create(target => target.MakeGetter); #region Class data members private static readonly DateTime _unixEpoch = new DateTime(1970, 1, 1); @@ -825,7 +827,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func, new Type[] { input.Schema[_parent._column.Source].Type.RawType, ((DateTimeEstimator.ColumnsProduced)iinfo + 1).GetRawColumnType() }, input, iinfo); + return Utils.MarshalInvoke(_makeGetterMethodInfo, this, input.Schema[_parent._column.Source].Type.RawType, ((DateTimeEstimator.ColumnsProduced)iinfo + 1).GetRawColumnType(), input, iinfo); } diff --git a/src/Microsoft.ML.Featurizers/RobustScaler.cs b/src/Microsoft.ML.Featurizers/RobustScaler.cs index 33c5892eb9..7e59b8df28 100644 --- a/src/Microsoft.ML.Featurizers/RobustScaler.cs +++ b/src/Microsoft.ML.Featurizers/RobustScaler.cs @@ -1559,6 +1559,8 @@ public override Type ReturnType() private sealed class Mapper : MapperBase { + private static readonly FuncInstanceMethodInfo2 _makeGetterMethodInfo + = FuncInstanceMethodInfo2.Create(target => target.MakeGetter); #region Class data members @@ -1599,7 +1601,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func, new Type[] { inputType, outputType }, input, iinfo); + return Utils.MarshalInvoke(_makeGetterMethodInfo, this, inputType, outputType, input, iinfo); } private protected override Func GetDependenciesCore(Func activeOutput) diff --git a/src/Microsoft.ML.Featurizers/ToStringTransformer.cs b/src/Microsoft.ML.Featurizers/ToStringTransformer.cs index d892b8f9a2..87a275f094 100644 --- a/src/Microsoft.ML.Featurizers/ToStringTransformer.cs +++ b/src/Microsoft.ML.Featurizers/ToStringTransformer.cs @@ -1456,6 +1456,8 @@ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffe private sealed class Mapper : MapperBase { + private static readonly FuncInstanceMethodInfo1 _makeGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.MakeGetter); #region Class data members @@ -1495,7 +1497,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func, inputType, input, iinfo); + return Utils.MarshalInvoke(_makeGetterMethodInfo, this, inputType, input, iinfo); } private protected override Func GetDependenciesCore(Func activeOutput) diff --git a/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs b/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs index 2b7496663b..106bd27cb0 100644 --- a/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs +++ b/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs @@ -367,6 +367,9 @@ private sealed class Cursor : RootCursorBase private static readonly FuncInstanceMethodInfo1 _createSubGetterDelegateCoreMethodInfo = FuncInstanceMethodInfo1.Create(target => target.CreateSubGetterDelegateCore); + private static readonly FuncInstanceMethodInfo1 _createGetterDelegateCoreMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.CreateGetterDelegateCore); + private PartitionedFileLoader _parent; private readonly bool[] _active; @@ -589,7 +592,7 @@ private Delegate[] CreateGetters() else { int idx = i - SubColumnCount; - getters[i] = Utils.MarshalInvoke(CreateGetterDelegateCore, type.RawType, idx, type); + getters[i] = Utils.MarshalInvoke(_createGetterDelegateCoreMethodInfo, this, type.RawType, idx, type); } } diff --git a/src/Microsoft.ML.Transforms/GroupTransform.cs b/src/Microsoft.ML.Transforms/GroupTransform.cs index 1f37f5d59e..7301fa692d 100644 --- a/src/Microsoft.ML.Transforms/GroupTransform.cs +++ b/src/Microsoft.ML.Transforms/GroupTransform.cs @@ -392,6 +392,9 @@ private sealed class Cursor : RootCursorBase /// private sealed class GroupKeyColumnChecker { + private static readonly FuncStaticMethodInfo1> _makeSameCheckerMethodInfo + = new FuncStaticMethodInfo1>(MakeSameChecker); + public readonly Func IsSameKey; private static Func MakeSameChecker(DataViewRow row, int col) @@ -425,9 +428,7 @@ public GroupKeyColumnChecker(DataViewRow row, int col) Contracts.AssertValue(row); var type = row.Schema[col].Type; - Func> del = MakeSameChecker; - var mi = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.RawType); - IsSameKey = (Func)mi.Invoke(null, new object[] { row, col }); + IsSameKey = Utils.MarshalInvoke(_makeSameCheckerMethodInfo, type.RawType, row, col); } } diff --git a/src/Microsoft.ML.Transforms/HashJoiningTransform.cs b/src/Microsoft.ML.Transforms/HashJoiningTransform.cs index 28f49eaa20..7242b129b7 100644 --- a/src/Microsoft.ML.Transforms/HashJoiningTransform.cs +++ b/src/Microsoft.ML.Transforms/HashJoiningTransform.cs @@ -454,9 +454,14 @@ private void GetSlotNames(int iinfo, ref VBuffer> dst) private delegate uint HashDelegate(in TSrc value, uint seed); // generic method generators - private static MethodInfo _methGetterOneToOne; - private static MethodInfo _methGetterVecToVec; - private static MethodInfo _methGetterVecToOne; + private static readonly FuncInstanceMethodInfo1 _composeGetterOneToOneMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.ComposeGetterOneToOne); + + private static readonly FuncInstanceMethodInfo1 _composeGetterVecToVecMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.ComposeGetterVecToVec); + + private static readonly FuncInstanceMethodInfo1 _composeGetterVecToOneMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.ComposeGetterVecToOne); protected override Delegate GetGetterCore(IChannel ch, DataViewRow input, int iinfo, out Action disposer) { @@ -465,45 +470,27 @@ protected override Delegate GetGetterCore(IChannel ch, DataViewRow input, int ii Host.Assert(0 <= iinfo && iinfo < Infos.Length); disposer = null; - // Construct MethodInfos templates that we need for the generic methods. - if (_methGetterOneToOne == null) - { - Func> del = ComposeGetterOneToOne; - Interlocked.CompareExchange(ref _methGetterOneToOne, del.GetMethodInfo().GetGenericMethodDefinition(), null); - } - if (_methGetterVecToVec == null) - { - Func>> del = ComposeGetterVecToVec; - Interlocked.CompareExchange(ref _methGetterVecToVec, del.GetMethodInfo().GetGenericMethodDefinition(), null); - } - if (_methGetterVecToOne == null) - { - Func> del = ComposeGetterVecToOne; - Interlocked.CompareExchange(ref _methGetterVecToOne, del.GetMethodInfo().GetGenericMethodDefinition(), null); - } - // Magic code to generate a correct getter. // First, we take a method info for GetGetter // Then, we replace with correct type of the input // And then we generate a delegate using the generic delegate generator DataViewType itemType; - MethodInfo mi; + FuncInstanceMethodInfo1 mi; if (!(Infos[iinfo].TypeSrc is VectorDataViewType vectorType)) { itemType = Infos[iinfo].TypeSrc; - mi = _methGetterOneToOne; + mi = _composeGetterOneToOneMethodInfo; } else { itemType = vectorType.ItemType; if (_exes[iinfo].OutputValueCount == 1) - mi = _methGetterVecToOne; + mi = _composeGetterVecToOneMethodInfo; else - mi = _methGetterVecToVec; + mi = _composeGetterVecToVecMethodInfo; } - mi = mi.MakeGenericMethod(itemType.RawType); - return (Delegate)mi.Invoke(this, new object[] { input, iinfo }); + return Utils.MarshalInvoke(mi, this, itemType.RawType, input, iinfo); } /// diff --git a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs index de808b4843..73080dbc22 100644 --- a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs @@ -171,6 +171,9 @@ private protected override void SaveModel(ModelSaveContext ctx) private sealed class Mapper : OneToOneMapperBase { + private static readonly FuncInstanceMethodInfo1 _makeVecGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.MakeVecGetter); + private readonly MissingValueDroppingTransformer _parent; private readonly DataViewType[] _srcTypes; @@ -210,15 +213,14 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() } return result; } + protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); disposer = null; - Func>> del = MakeVecGetter; - var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(_srcTypes[iinfo].GetItemType().RawType); - return (Delegate)methodInfo.Invoke(this, new object[] { input, iinfo }); + return Utils.MarshalInvoke(_makeVecGetterMethodInfo, this, _srcTypes[iinfo].GetItemType().RawType, input, iinfo); } private ValueGetter> MakeVecGetter(DataViewRow input, int iinfo) diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs index 6ba7fafd48..dde8e81c06 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs @@ -146,6 +146,12 @@ private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx private static readonly FuncStaticMethodInfo1 _getIsNADelegateMethodInfo = new FuncStaticMethodInfo1(GetIsNADelegate); + private static readonly FuncInstanceMethodInfo1> _composeGetterOneMethodInfo + = FuncInstanceMethodInfo1>.Create(target => target.ComposeGetterOne); + + private static readonly FuncInstanceMethodInfo1>> _composeGetterVecMethodInfo + = FuncInstanceMethodInfo1>>.Create(target => target.ComposeGetterVec); + private readonly MissingValueIndicatorTransformer _parent; private readonly ColInfo[] _infos; @@ -241,7 +247,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func private ValueGetter ComposeGetterOne(DataViewRow input, int iinfo) - => Utils.MarshalInvoke(ComposeGetterOne, _infos[iinfo].InputType.RawType, input, iinfo); + => Utils.MarshalInvoke(_composeGetterOneMethodInfo, this, _infos[iinfo].InputType.RawType, input, iinfo); private ValueGetter ComposeGetterOne(DataViewRow input, int iinfo) { @@ -263,7 +269,7 @@ private ValueGetter ComposeGetterOne(DataViewRow input, int iinfo) /// Getter generator for vector valued inputs. /// private ValueGetter> ComposeGetterVec(DataViewRow input, int iinfo) - => Utils.MarshalInvoke(ComposeGetterVec, _infos[iinfo].InputType.GetItemType().RawType, input, iinfo); + => Utils.MarshalInvoke(_composeGetterVecMethodInfo, this, _infos[iinfo].InputType.GetItemType().RawType, input, iinfo); private ValueGetter> ComposeGetterVec(DataViewRow input, int iinfo) { diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs index 0759232564..56c8a857b7 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs @@ -38,6 +38,9 @@ namespace Microsoft.ML.Transforms // REVIEW: May make sense to implement the transform template interface. public sealed partial class MissingValueReplacingTransformer : OneToOneTransformerBase { + private static readonly FuncInstanceMethodInfo1 _computeDefaultSlotsMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.ComputeDefaultSlots); + internal enum ReplacementKind : byte { // REVIEW: What should the full list of options for this transform be? @@ -354,21 +357,20 @@ private void GetReplacementValues(IDataView input, MissingValueReplacingEstimato int slot = columnsToImpute[ii]; if (repValues[slot] is Array) { - Func func = ComputeDefaultSlots; - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(types[slot].GetItemType().RawType); - slotIsDefault[slot] = (BitArray)meth.Invoke(this, new object[] { types[slot], repValues[slot] }); + slotIsDefault[slot] = Utils.MarshalInvoke(_computeDefaultSlotsMethodInfo, this, types[slot].GetItemType().RawType, types[slot], (Array)repValues[slot]); } } } - private BitArray ComputeDefaultSlots(DataViewType type, T[] values) + private BitArray ComputeDefaultSlots(DataViewType type, Array values) { Host.Assert(values.Length == type.GetVectorSize()); BitArray defaultSlots = new BitArray(values.Length); InPredicate defaultPred = Data.Conversion.Conversions.Instance.GetIsDefaultPredicate(type.GetItemType()); + T[] typedValues = (T[])values; for (int slot = 0; slot < values.Length; slot++) { - if (defaultPred(in values[slot])) + if (defaultPred(in typedValues[slot])) defaultSlots[slot] = true; } return defaultSlots; @@ -536,6 +538,12 @@ public ColInfo(string outputColumnName, string inputColumnName, DataViewType typ } } + private static readonly FuncInstanceMethodInfo1 _composeGetterOneMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.ComposeGetterOne); + + private static readonly FuncInstanceMethodInfo1 _composeGetterVecMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.ComposeGetterVec); + private readonly MissingValueReplacingTransformer _parent; private readonly ColInfo[] _infos; private readonly DataViewType[] _types; @@ -624,7 +632,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func private Delegate ComposeGetterOne(DataViewRow input, int iinfo) - => Utils.MarshalInvoke(ComposeGetterOne, _infos[iinfo].TypeSrc.RawType, input, iinfo); + => Utils.MarshalInvoke(_composeGetterOneMethodInfo, this, _infos[iinfo].TypeSrc.RawType, input, iinfo); /// /// Replaces NA values for scalars. @@ -650,7 +658,7 @@ private Delegate ComposeGetterOne(DataViewRow input, int iinfo) /// Getter generator for vector valued inputs. /// private Delegate ComposeGetterVec(DataViewRow input, int iinfo) - => Utils.MarshalInvoke(ComposeGetterVec, _infos[iinfo].TypeSrc.GetItemType().RawType, input, iinfo); + => Utils.MarshalInvoke(_composeGetterVecMethodInfo, this, _infos[iinfo].TypeSrc.GetItemType().RawType, input, iinfo); /// /// Replaces NA values for vectors. diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs index a63dea2fbe..98b100c15f 100644 --- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs +++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs @@ -245,6 +245,9 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(OptionalColumnTransform).Assembly.FullName); } + private static readonly FuncInstanceMethodInfo1 _getSrcGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.GetSrcGetter); + private static readonly FuncInstanceMethodInfo1 _makeGetterOneMethodInfo = FuncInstanceMethodInfo1.Create(target => target.MakeGetterOne); @@ -391,9 +394,7 @@ protected override Delegate[] CreateGetters(DataViewRow input, IEnumerable> srcDel = GetSrcGetter; - var meth = srcDel.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(_bindings.ColumnTypes[iinfo].GetItemType().RawType); - getters[iinfo] = (Delegate)meth.Invoke(this, new object[] { input, iinfo }); + getters[iinfo] = Utils.MarshalInvoke(_getSrcGetterMethodInfo, this, _bindings.ColumnTypes[iinfo].GetItemType().RawType, input, iinfo); } } return getters;