diff --git a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`3.cs b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`3.cs new file mode 100644 index 0000000000..9117a1b488 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`3.cs @@ -0,0 +1,90 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// The method is an instance method on an object of type . + /// One generic type argument. + /// A return value of . + /// + /// + /// The type of the receiver of the instance method. + /// The type of the parameter of the method. + /// The type of the return value of the method. + internal sealed class FuncInstanceMethodInfo1 : FuncMethodInfo1 + where TTarget : class + { + private static readonly string _targetTypeCheckMessage = $"Should have a target type of '{typeof(TTarget)}'"; + + public FuncInstanceMethodInfo1(Func function) + : this(function.Method) + { + } + + private FuncInstanceMethodInfo1(MethodInfo methodInfo) + : base(methodInfo) + { + Contracts.CheckParam(!GenericMethodDefinition.IsStatic, nameof(methodInfo), "Should be an instance method"); + Contracts.CheckParam(GenericMethodDefinition.DeclaringType == typeof(TTarget), nameof(methodInfo), _targetTypeCheckMessage); + } + + /// + /// Creates a representing the + /// for a generic instance method. This helper method allows the instance to be created prior to the creation of + /// any instances of the target type. The following example shows the creation of an instance representing the + /// method: + /// + /// + /// FuncInstanceMethodInfo1<object, object, int>.Create(obj => obj.Equals) + /// + /// + /// The expression which creates the delegate for an instance of the target type. + /// A representing the + /// for the generic instance method. + public static FuncInstanceMethodInfo1 Create(Expression>> expression) + { + if (!(expression is { Body: UnaryExpression { Operand: MethodCallExpression methodCallExpression } })) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + + // Verify that we are calling MethodInfo.CreateDelegate(Type, object) + Contracts.CheckParam(methodCallExpression.Method.DeclaringType == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.Name == nameof(MethodInfo.CreateDelegate), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters().Length == 2, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters()[0].ParameterType == typeof(Type), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters()[1].ParameterType == typeof(object), nameof(expression), "Unexpected expression form"); + + // Verify that we are creating a delegate of type Func + Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form"); + + // Check the MethodInfo + Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + + var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value; + Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form"); + + return new FuncInstanceMethodInfo1(methodInfo); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo3`3.cs b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo3`3.cs new file mode 100644 index 0000000000..91c1c0d747 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo3`3.cs @@ -0,0 +1,90 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// The method is an instance method on an object of type . + /// Three generic type arguments. + /// A return value of . + /// + /// + /// The type of the receiver of the instance method. + /// The type of the parameter of the method. + /// The type of the return value of the method. + internal sealed class FuncInstanceMethodInfo3 : FuncMethodInfo3 + where TTarget : class + { + private static readonly string _targetTypeCheckMessage = $"Should have a target type of '{typeof(TTarget)}'"; + + public FuncInstanceMethodInfo3(Func function) + : this(function.Method) + { + } + + private FuncInstanceMethodInfo3(MethodInfo methodInfo) + : base(methodInfo) + { + Contracts.CheckParam(!GenericMethodDefinition.IsStatic, nameof(methodInfo), "Should be an instance method"); + Contracts.CheckParam(GenericMethodDefinition.DeclaringType == typeof(TTarget), nameof(methodInfo), _targetTypeCheckMessage); + } + + /// + /// Creates a representing the + /// for a generic instance method. This helper method allows the instance to be created prior to the creation of + /// any instances of the target type. The following example shows the creation of an instance representing the + /// method: + /// + /// + /// FuncInstanceMethodInfo1<object, object, int>.Create(obj => obj.Equals) + /// + /// + /// The expression which creates the delegate for an instance of the target type. + /// A representing the + /// for the generic instance method. + public static FuncInstanceMethodInfo3 Create(Expression>> expression) + { + if (!(expression is { Body: UnaryExpression { Operand: MethodCallExpression methodCallExpression } })) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + + // Verify that we are calling MethodInfo.CreateDelegate(Type, object) + Contracts.CheckParam(methodCallExpression.Method.DeclaringType == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.Name == nameof(MethodInfo.CreateDelegate), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters().Length == 2, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters()[0].ParameterType == typeof(Type), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters()[1].ParameterType == typeof(object), nameof(expression), "Unexpected expression form"); + + // Verify that we are creating a delegate of type Func + Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form"); + + // Check the MethodInfo + Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + + var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value; + Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form"); + + return new FuncInstanceMethodInfo3(methodInfo); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncMethodInfo1`2.cs b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo1`2.cs new file mode 100644 index 0000000000..a93d3f8485 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo1`2.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Collections.Immutable; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// One generic type argument. + /// A return value of . + /// + /// + /// The type of the parameter of the method. + /// The type of the return value of the method. + internal abstract class FuncMethodInfo1 : FuncMethodInfo + { + private ImmutableDictionary _instanceMethodInfo; + + private protected FuncMethodInfo1(MethodInfo methodInfo) + : base(methodInfo) + { + _instanceMethodInfo = ImmutableDictionary.Empty; + + Contracts.CheckParam(GenericMethodDefinition.GetGenericArguments().Length == 1, nameof(methodInfo), + "Should have exactly one generic type parameter but does not"); + } + + public MethodInfo MakeGenericMethod(Type typeArg1) + { + return ImmutableInterlocked.GetOrAdd( + ref _instanceMethodInfo, + typeArg1, + (typeArg, methodInfo) => methodInfo.MakeGenericMethod(typeArg), + GenericMethodDefinition); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncMethodInfo3`2.cs b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo3`2.cs new file mode 100644 index 0000000000..eea8c6b824 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo3`2.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Collections.Immutable; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// Three generic type arguments. + /// A return value of . + /// + /// + /// The type of the parameter of the method. + /// The type of the return value of the method. + internal abstract class FuncMethodInfo3 : FuncMethodInfo + { + private ImmutableDictionary<(Type, Type, Type), MethodInfo> _instanceMethodInfo; + + private protected FuncMethodInfo3(MethodInfo methodInfo) + : base(methodInfo) + { + _instanceMethodInfo = ImmutableDictionary<(Type, Type, Type), MethodInfo>.Empty; + + Contracts.CheckParam(GenericMethodDefinition.GetGenericArguments().Length == 3, nameof(methodInfo), + "Should have exactly three generic type parameters but does not"); + } + + public MethodInfo MakeGenericMethod(Type typeArg1, Type typeArg2, Type typeArg3) + { + return ImmutableInterlocked.GetOrAdd( + ref _instanceMethodInfo, + (typeArg1, typeArg2, typeArg3), + (args, methodInfo) => methodInfo.MakeGenericMethod(args.Item1, args.Item2, args.Item3), + GenericMethodDefinition); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncMethodInfo`2.cs b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo`2.cs new file mode 100644 index 0000000000..117453d4d7 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo`2.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + internal abstract class FuncMethodInfo + { + private protected FuncMethodInfo(MethodInfo methodInfo) + { + Contracts.CheckValue(methodInfo, nameof(methodInfo)); + + Contracts.CheckParam(methodInfo.IsGenericMethod, nameof(methodInfo), "Should be generic but is not"); + + GenericMethodDefinition = methodInfo.GetGenericMethodDefinition(); + Contracts.CheckParam(typeof(TResult).IsAssignableFrom(GenericMethodDefinition.ReturnType), nameof(methodInfo), "Cannot be generic on return type"); + } + + protected MethodInfo GenericMethodDefinition { get; } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo1`2.cs b/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo1`2.cs new file mode 100644 index 0000000000..d04f6534c1 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo1`2.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// The method is static. + /// One generic type argument. + /// A return value of . + /// + /// + /// The type of the parameter of the method. + /// The type of the return value of the method. + internal sealed class FuncStaticMethodInfo1 : FuncMethodInfo1 + { + public FuncStaticMethodInfo1(Func function) + : base(function.Method) + { + Contracts.CheckParam(GenericMethodDefinition.IsStatic, nameof(function), "Should be a static method"); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo3`2.cs b/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo3`2.cs new file mode 100644 index 0000000000..7cb7e31779 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo3`2.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// The method is static. + /// Three generic type arguments. + /// A return value of . + /// + /// + /// The type of the parameter of the method. + /// The type of the return value of the method. + internal sealed class FuncStaticMethodInfo3 : FuncMethodInfo3 + { + public FuncStaticMethodInfo3(Func function) + : base(function.Method) + { + Contracts.CheckParam(GenericMethodDefinition.IsStatic, nameof(function), "Should be a static method"); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/Utils.cs b/src/Microsoft.ML.Core/Utilities/Utils.cs index 36ee1ffd89..4912d7f84d 100644 --- a/src/Microsoft.ML.Core/Utilities/Utils.cs +++ b/src/Microsoft.ML.Core/Utilities/Utils.cs @@ -1015,10 +1015,39 @@ public static TResult MarshalInvoke(FuncStaticMethodInfo1 func /// /// A one-argument version of . /// - public static TRet MarshalInvoke(Func func, Type genArg, TArg1 arg1) + public static TResult MarshalInvoke(FuncInstanceMethodInfo1 func, TTarget target, Type genArg, TArg1 arg1) + where TTarget : class { - var meth = MarshalInvokeCheckAndCreate(genArg, func); - return (TRet)meth.Invoke(func.Target, new object[] { arg1 }); + var meth = func.MakeGenericMethod(genArg); + return (TResult)meth.Invoke(target, new object[] { arg1 }); + } + + /// + /// A one-argument version of . + /// + public static TResult MarshalInvoke(FuncStaticMethodInfo1 func, Type genArg, TArg1 arg1) + { + var meth = func.MakeGenericMethod(genArg); + return (TResult)meth.Invoke(null, new object[] { arg1 }); + } + + /// + /// A one-argument, three-type-parameter version of . + /// + public static TResult MarshalInvoke(FuncInstanceMethodInfo3 func, TTarget target, Type genArg1, Type genArg2, Type genArg3, TArg1 arg1) + where TTarget : class + { + var meth = func.MakeGenericMethod(genArg1, genArg2, genArg3); + return (TResult)meth.Invoke(target, new object[] { arg1 }); + } + + /// + /// A one-argument, three-type-parameter version of . + /// + public static TResult MarshalInvoke(FuncStaticMethodInfo3 func, Type genArg1, Type genArg2, Type genArg3, TArg1 arg1) + { + var meth = func.MakeGenericMethod(genArg1, genArg2, genArg3); + return (TResult)meth.Invoke(null, new object[] { arg1 }); } /// @@ -1112,17 +1141,6 @@ public static TRet MarshalInvoke - /// A 1 argument and n type version of . - /// - public static TRet MarshalInvoke( - Func func, - Type[] genArgs, TArg1 arg1) - { - var meth = MarshalInvokeCheckAndCreate(genArgs, func); - return (TRet)meth.Invoke(func.Target, new object[] { arg1}); - } - /// /// A 2 argument and n type version of . /// diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index bdb450ac00..82de43bb48 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -48,6 +48,9 @@ namespace Microsoft.ML.Data.Conversion [BestFriend] internal sealed class Conversions { + private static readonly FuncInstanceMethodInfo1 _getKeyParseMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.GetKeyParse); + // REVIEW: Reconcile implementations with TypeUtils, and clarify the distinction. // Singleton pattern. @@ -546,9 +549,7 @@ private TryParseMapper GetKeyTryParse(KeyDataViewType key) private Delegate GetKeyParse(KeyDataViewType key) { - Func> del = GetKeyParse; - var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(key.RawType); - return (Delegate)meth.Invoke(this, new object[] { key }); + return Utils.MarshalInvoke(_getKeyParseMethodInfo, this, key.RawType, key); } private ValueMapper GetKeyParse(KeyDataViewType key) diff --git a/src/Microsoft.ML.Data/Data/DataViewUtils.cs b/src/Microsoft.ML.Data/Data/DataViewUtils.cs index 8e464ebd2e..6462680c95 100644 --- a/src/Microsoft.ML.Data/Data/DataViewUtils.cs +++ b/src/Microsoft.ML.Data/Data/DataViewUtils.cs @@ -1138,6 +1138,9 @@ public override ValueGetter GetGetter(DataViewSchema.Column colu /// internal sealed class SynchronousConsolidatingCursor : RootCursorBase { + private static readonly FuncInstanceMethodInfo1 _createGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.CreateGetter); + private readonly DataViewRowCursor[] _cursors; private readonly Delegate[] _getters; @@ -1145,7 +1148,6 @@ internal sealed class SynchronousConsolidatingCursor : RootCursorBase private readonly Heap _mins; private readonly int[] _activeToCol; private readonly int[] _colToActive; - private readonly MethodInfo _methInfo; // The batch number of the current input cursor, or -1 if this cursor is not in Good state. private long _batch; @@ -1182,9 +1184,6 @@ public SynchronousConsolidatingCursor(IChannelProvider provider, DataViewRowCurs Utils.BuildSubsetMaps(_schema, _cursors[0].IsColumnActive, out _activeToCol, out _colToActive); - Func func = CreateGetter; - _methInfo = func.GetMethodInfo().GetGenericMethodDefinition(); - _getters = new Delegate[_activeToCol.Length]; for (int i = 0; i < _activeToCol.Length; ++i) _getters[i] = CreateGetter(_activeToCol[i]); @@ -1238,8 +1237,7 @@ public override ValueGetter GetIdGetter() private Delegate CreateGetter(int col) { - var methInfo = _methInfo.MakeGenericMethod(Schema[col].Type.RawType); - return (Delegate)methInfo.Invoke(this, new object[] { col }); + return Utils.MarshalInvoke(_createGetterMethodInfo, this, Schema[col].Type.RawType, col); } private Delegate CreateGetter(int col) diff --git a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs index b8e7ea87e5..f19d96830f 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs @@ -13,6 +13,9 @@ public sealed partial class DatabaseLoader { private sealed class Cursor : RootCursorBase { + private static readonly FuncInstanceMethodInfo1 _createGetterDelegateMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.CreateGetterDelegate); + private readonly Bindings _bindings; private readonly bool[] _active; // Which columns are active. private readonly DatabaseSource _source; @@ -163,7 +166,7 @@ public override ValueGetter GetGetter(DataViewSchema.Column colu private Delegate CreateGetterDelegate(int col) { - return Utils.MarshalInvoke(CreateGetterDelegate, _bindings.Infos[col].ColType.RawType, col); + return Utils.MarshalInvoke(_createGetterDelegateMethodInfo, this, _bindings.Infos[col].ColType.RawType, col); } private Delegate CreateGetterDelegate(int col) diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs index 0a157268bc..dd45e8c240 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs @@ -24,6 +24,12 @@ public sealed partial class TextLoader /// private sealed class ValueCreatorCache { + private static readonly FuncInstanceMethodInfo1> _getCreatorOneCoreMethodInfo + = FuncInstanceMethodInfo1>.Create(target => target.GetCreatorOneCore); + + private static readonly FuncInstanceMethodInfo1> _getCreatorVecCoreMethodInfo + = FuncInstanceMethodInfo1>.Create(target => target.GetCreatorVecCore); + private static volatile ValueCreatorCache _instance; public static ValueCreatorCache Instance { @@ -36,8 +42,6 @@ public static ValueCreatorCache Instance } private readonly Conversions _conv; - private readonly MethodInfo _methOne; - private readonly MethodInfo _methVec; // Indexed by DataKind.ToIndex() private readonly Func[] _creatorsOne; @@ -46,10 +50,6 @@ public static ValueCreatorCache Instance private ValueCreatorCache() { _conv = Conversions.Instance; - _methOne = new Func>(GetCreatorOneCore) - .GetMethodInfo().GetGenericMethodDefinition(); - _methVec = new Func>(GetCreatorVecCore) - .GetMethodInfo().GetGenericMethodDefinition(); _creatorsOne = new Func[InternalDataKindExtensions.KindCount]; _creatorsVec = new Func[InternalDataKindExtensions.KindCount]; @@ -63,8 +63,7 @@ private ValueCreatorCache() private Func GetCreatorOneCore(PrimitiveDataViewType type) { - MethodInfo meth = _methOne.MakeGenericMethod(type.RawType); - return (Func)meth.Invoke(this, new object[] { type }); + return Utils.MarshalInvoke(_getCreatorOneCoreMethodInfo, this, type.RawType, type); } private Func GetCreatorOneCore(PrimitiveDataViewType type) @@ -77,8 +76,7 @@ private Func GetCreatorOneCore(PrimitiveDataViewType type private Func GetCreatorVecCore(PrimitiveDataViewType type) { - MethodInfo meth = _methVec.MakeGenericMethod(type.RawType); - return (Func)meth.Invoke(this, new object[] { type }); + return Utils.MarshalInvoke(_getCreatorVecCoreMethodInfo, this, type.RawType, type); } private Func GetCreatorVecCore(PrimitiveDataViewType type) @@ -92,15 +90,13 @@ private Func GetCreatorVecCore(PrimitiveDataViewType type public Func GetCreatorOne(KeyDataViewType key) { // Have to produce a specific one - can't use a cached one. - MethodInfo meth = _methOne.MakeGenericMethod(key.RawType); - return (Func)meth.Invoke(this, new object[] { key }); + return Utils.MarshalInvoke(_getCreatorOneCoreMethodInfo, this, key.RawType, key); } public Func GetCreatorVec(KeyDataViewType key) { // Have to produce a specific one - can't use a cached one. - MethodInfo meth = _methVec.MakeGenericMethod(key.RawType); - return (Func)meth.Invoke(this, new object[] { key }); + return Utils.MarshalInvoke(_getCreatorVecCoreMethodInfo, this, key.RawType, key); } public Func GetCreatorOne(InternalDataKind kind) diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs index 1c4b7140a0..6339333cef 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs @@ -317,6 +317,9 @@ protected override void VerifyView(IDataView view) } } + private static readonly FuncInstanceMethodInfo1 _getSlotCursorCoreMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.GetSlotCursorCore); + // Positive if explicit, otherwise let the sub-binary loader decide for themselves. private readonly int _threads; @@ -647,7 +650,7 @@ SlotCursor ITransposeDataView.GetSlotCursor(int col) DataViewRowCursor inputCursor = view.GetRowCursorForAllColumns(); try { - return Utils.MarshalInvoke(GetSlotCursorCore, cursorType.RawType, inputCursor); + return Utils.MarshalInvoke(_getSlotCursorCoreMethodInfo, this, cursorType.RawType, inputCursor); } catch (Exception) { diff --git a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs index 0ae6fbed77..0773507497 100644 --- a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs +++ b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs @@ -158,6 +158,9 @@ public DataViewRowCursor[] GetRowCursorSet(IEnumerable co private abstract class CursorBase : RootCursorBase { + private static readonly FuncInstanceMethodInfo1 _createTypedGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.CreateTypedGetter); + protected readonly IDataView[] Sources; protected readonly Delegate[] Getters; @@ -178,9 +181,7 @@ protected Delegate CreateGetter(int col) { DataViewType colType = Schema[col].Type; Ch.AssertValue(colType); - Func creator = CreateTypedGetter; - var typedCreator = creator.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(colType.RawType); - return (Delegate)typedCreator.Invoke(this, new object[] { col }); + return Utils.MarshalInvoke(_createTypedGetterMethodInfo, this, colType.RawType, col); } protected abstract ValueGetter CreateTypedGetter(int col); diff --git a/src/Microsoft.ML.Data/DataView/CacheDataView.cs b/src/Microsoft.ML.Data/DataView/CacheDataView.cs index e24ed9c59d..56d655ec10 100644 --- a/src/Microsoft.ML.Data/DataView/CacheDataView.cs +++ b/src/Microsoft.ML.Data/DataView/CacheDataView.cs @@ -1136,6 +1136,9 @@ public Wrapper(BlockRandomIndex index) private abstract class RowCursorSeekerBase : DataViewRowCursor { + private static readonly FuncInstanceMethodInfo1 _createGetterDelegateMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.CreateGetterDelegate); + protected readonly CacheDataView Parent; protected readonly IChannel Ch; protected long PositionCore; @@ -1213,7 +1216,7 @@ private Delegate CreateGetterDelegate(int col) { Ch.Assert(0 <= col && col < _colToActivesIndex.Length); Ch.Assert(_colToActivesIndex[col] >= 0); - return Utils.MarshalInvoke(CreateGetterDelegate, Schema[col].Type.RawType, col); + return Utils.MarshalInvoke(_createGetterDelegateMethodInfo, this, Schema[col].Type.RawType, col); } private Delegate CreateGetterDelegate(int col) diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index 797435d5da..9c5e09d2bc 100644 --- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -182,6 +182,15 @@ protected override TRow GetCurrentRowObject() public abstract class InputRowBase : DataViewRow where TRow : class { + private static readonly FuncInstanceMethodInfo1, Delegate, Delegate> _createDirectArrayGetterDelegateMethodInfo + = FuncInstanceMethodInfo1, Delegate, Delegate>.Create(target => target.CreateDirectArrayGetterDelegate); + + private static readonly FuncInstanceMethodInfo1, Delegate, Delegate> _createDirectVBufferGetterDelegateMethodInfo + = FuncInstanceMethodInfo1, Delegate, Delegate>.Create(target => target.CreateDirectVBufferGetterDelegate); + + private static readonly FuncInstanceMethodInfo1, Delegate, Delegate> _createDirectGetterDelegateMethodInfo + = FuncInstanceMethodInfo1, Delegate, Delegate>.Create(target => target.CreateDirectGetterDelegate); + private readonly int _colCount; private readonly Delegate[] _getters; protected readonly IHost Host; @@ -213,7 +222,7 @@ private Delegate CreateGetter(DataViewType colType, InternalSchemaDefinition.Col { var outputType = column.OutputType; var genericType = outputType; - Func del; + FuncInstanceMethodInfo1, Delegate, Delegate> del; if (outputType.IsArray) { @@ -232,7 +241,7 @@ private Delegate CreateGetter(DataViewType colType, InternalSchemaDefinition.Col Host.Assert(Nullable.GetUnderlyingType(outputType.GetElementType()) == vectorType.ItemType.RawType); else Host.Assert(outputType.GetElementType() == vectorType.ItemType.RawType); - del = CreateDirectArrayGetterDelegate; + del = _createDirectArrayGetterDelegateMethodInfo; genericType = outputType.GetElementType(); } else if (colType is VectorDataViewType vectorType) @@ -242,7 +251,7 @@ private Delegate CreateGetter(DataViewType colType, InternalSchemaDefinition.Col Host.Assert(outputType.IsGenericType); Host.Assert(outputType.GetGenericTypeDefinition() == typeof(VBuffer<>)); Host.Assert(outputType.GetGenericArguments()[0] == vectorType.ItemType.RawType); - del = CreateDirectVBufferGetterDelegate; + del = _createDirectVBufferGetterDelegateMethodInfo; genericType = vectorType.ItemType.RawType; } else if (colType is PrimitiveDataViewType) @@ -261,7 +270,7 @@ private Delegate CreateGetter(DataViewType colType, InternalSchemaDefinition.Col Host.Assert(colType.RawType == outputType); if (!(colType is KeyDataViewType keyType)) - del = CreateDirectGetterDelegate; + del = _createDirectGetterDelegateMethodInfo; else { var keyRawType = colType.RawType; @@ -271,14 +280,14 @@ private Delegate CreateGetter(DataViewType colType, InternalSchemaDefinition.Col } else if (DataViewTypeManager.Knows(colType)) { - del = CreateDirectGetterDelegate; + del = _createDirectGetterDelegateMethodInfo; } else { // REVIEW: Is this even possible? throw Host.ExceptNotSupp("Type '{0}' is not yet supported.", outputType.FullName); } - return Utils.MarshalInvoke(del, genericType, peek); + return Utils.MarshalInvoke(del, this, genericType, peek); } // REVIEW: The converting getter invokes a type conversion delegate on every call, so it's inherently slower diff --git a/src/Microsoft.ML.Data/DataView/Transposer.cs b/src/Microsoft.ML.Data/DataView/Transposer.cs index fb70b731f8..c424ec45be 100644 --- a/src/Microsoft.ML.Data/DataView/Transposer.cs +++ b/src/Microsoft.ML.Data/DataView/Transposer.cs @@ -23,6 +23,9 @@ namespace Microsoft.ML.Data [BestFriend] internal sealed class Transposer : ITransposeDataView, IDisposable { + private static readonly FuncInstanceMethodInfo1 _getSlotCursorCoreMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.GetSlotCursorCore); + private readonly IHost _host; // The input view. private readonly IDataView _view; @@ -244,7 +247,7 @@ public SlotCursor GetSlotCursor(int col) _host.Assert(0 <= tcol && tcol < _cols.Length); _host.Assert(_cols[tcol].Index == col); - return Utils.MarshalInvoke(GetSlotCursorCore, type, col); + return Utils.MarshalInvoke(_getSlotCursorCoreMethodInfo, this, type, col); } private SlotCursor GetSlotCursorCore(int col) @@ -1397,6 +1400,9 @@ private static DataViewRowCursor GetRowCursorShimCore(IChannelProvider provid /// public sealed class SlotDataView : IDataView { + private static readonly FuncInstanceMethodInfo1 _getRowCursorMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.GetRowCursor); + private readonly IHost _host; private readonly ITransposeDataView _data; private readonly int _col; @@ -1434,7 +1440,7 @@ public SlotDataView(IHostEnvironment env, ITransposeDataView data, int col) public DataViewRowCursor GetRowCursor(IEnumerable columnsNeeded, Random rand = null) { bool hasZero = columnsNeeded != null && columnsNeeded.Any(x => x.Index == 0); - return Utils.MarshalInvoke(GetRowCursor, _type.GetItemType().RawType, hasZero); + return Utils.MarshalInvoke(_getRowCursorMethodInfo, this, _type.GetItemType().RawType, hasZero); } private DataViewRowCursor GetRowCursor(bool active) diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs index 1a32b903b8..9a14583a68 100644 --- a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs +++ b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs @@ -37,6 +37,12 @@ public Attributes(ArgumentAttribute input, TlcModule.RangeAttribute range, bool } } + private static readonly FuncStaticMethodInfo1 _makeNullableMethodInfo + = new FuncStaticMethodInfo1(MakeNullable); + + private static readonly FuncStaticMethodInfo1 _makeOptionalMethodInfo + = new FuncStaticMethodInfo1(MakeOptional); + private readonly IExceptionContext _ectx; private readonly object _instance; private readonly Type _type; @@ -622,13 +628,13 @@ private static object MakeOptionalIfNeeded(IExceptionContext ectx, object innerV } bool isOptional = outerType.GetGenericTypeDefinition() == typeof(Optional<>); - Func creator; + FuncStaticMethodInfo1 creator; if (isOptional) - creator = MakeOptional; + creator = _makeOptionalMethodInfo; else { ectx.Assert(genericType == typeof(Nullable<>)); - creator = MakeNullable; + creator = _makeNullableMethodInfo; } return Utils.MarshalInvoke(creator, outerType.GetGenericArguments()[0], innerValue); diff --git a/src/Microsoft.ML.Data/Scorers/MulticlassClassificationScorer.cs b/src/Microsoft.ML.Data/Scorers/MulticlassClassificationScorer.cs index 05d61f2620..8563e1ffe4 100644 --- a/src/Microsoft.ML.Data/Scorers/MulticlassClassificationScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MulticlassClassificationScorer.cs @@ -66,6 +66,9 @@ private static VersionInfo GetVersionInfo() // less ridiculously verbose than this. public sealed class LabelNameBindableMapper : ISchemaBindableMapper, ICanSaveModel, IBindableCanSavePfa, IBindableCanSaveOnnx { + private static readonly FuncInstanceMethodInfo1 _decodeInitMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.DecodeInit); + public const string LoaderSignature = "LabelSlotNameMapper"; private const string _innerDir = "InnerMapper"; private readonly ISchemaBindableMapper _bindable; @@ -138,7 +141,7 @@ private LabelNameBindableMapper(IHost host, ModelLoadContext ctx) _type = type as VectorDataViewType; _host.CheckDecode(_type != null); _host.CheckDecode(value != null); - _getter = Utils.MarshalInvoke(DecodeInit, _type.ItemType.RawType, value); + _getter = Utils.MarshalInvoke(_decodeInitMethodInfo, this, _type.ItemType.RawType, value); _metadataKind = ctx.Header.ModelVerReadable >= VersionAddedMetadataKind ? ctx.LoadNonEmptyString() : AnnotationUtils.Kinds.SlotNames; } diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs index 64284a92c1..ba70941461 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs @@ -521,6 +521,12 @@ private BoundColumn MakeColumn(DataViewSchema inputSchema, int iinfo) /// private sealed class BoundColumn { + private static readonly FuncInstanceMethodInfo1 _makeIdentityGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.MakeIdentityGetter); + + private static readonly FuncInstanceMethodInfo1 _makeGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.MakeGetter); + public readonly int[] SrcIndices; private readonly ColumnOptions _columnOptions; @@ -669,9 +675,9 @@ private void GetSlotNames(ref VBuffer> dst) public Delegate MakeGetter(DataViewRow input) { if (_isIdentity) - return Utils.MarshalInvoke(MakeIdentityGetter, OutputType.RawType, input); + return Utils.MarshalInvoke(_makeIdentityGetterMethodInfo, this, OutputType.RawType, input); - return Utils.MarshalInvoke(MakeGetter, OutputType.ItemType.RawType, input); + return Utils.MarshalInvoke(_makeGetterMethodInfo, this, OutputType.ItemType.RawType, input); } private Delegate MakeIdentityGetter(DataViewRow input) diff --git a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs index 72984958ce..aec9c87ca3 100644 --- a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs @@ -453,6 +453,9 @@ public void Fetch(int idx, ref T value) protected abstract void Copy(in T src, ref T dst); } + private static readonly FuncInstanceMethodInfo1 _createGetterDelegateMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.CreateGetterDelegate); + // The number of examples to have in each synchronization block. This should be >= 1. private const int _blockSize = 16; // The number of spare blocks to keep the filler worker busy on. This should be >= 1. @@ -692,8 +695,7 @@ private Delegate CreateGetterDelegate(int col) { Ch.Assert(0 <= col && col < _colToActivesIndex.Length); Ch.Assert(_colToActivesIndex[col] >= 0); - Func createDel = CreateGetterDelegate; - return Utils.MarshalInvoke(createDel, Schema[col].Type.RawType, col); + return Utils.MarshalInvoke(_createGetterDelegateMethodInfo, this, Schema[col].Type.RawType, col); } private Delegate CreateGetterDelegate(int col) diff --git a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs index 0896d1077b..23f23e5cb4 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs @@ -825,6 +825,12 @@ private static ValueMap CreateValueMapInvoke(DataViewSchema.Column /// private class ValueMap : ValueMap { + private static readonly FuncStaticMethodInfo1 _getVectorMethodInfo + = new FuncStaticMethodInfo1(GetVector); + + private static readonly FuncStaticMethodInfo1 _getValueMethodInfo + = new FuncStaticMethodInfo1(GetValue); + private Dictionary _mapping; private TValue _missingValue; @@ -889,9 +895,9 @@ private TValue MapValue(TKey key) if (_mapping.ContainsKey(key)) { if (ValueColumn.Type is VectorDataViewType vectorType) - return Utils.MarshalInvoke(GetVector, vectorType.ItemType.RawType, _mapping[key]); + return Utils.MarshalInvoke(_getVectorMethodInfo, vectorType.ItemType.RawType, _mapping[key]); else - return Utils.MarshalInvoke(GetValue, ValueColumn.Type.RawType, _mapping[key]); + return Utils.MarshalInvoke(_getValueMethodInfo, ValueColumn.Type.RawType, _mapping[key]); } else return _missingValue; diff --git a/src/Microsoft.ML.Data/Utils/ApiUtils.cs b/src/Microsoft.ML.Data/Utils/ApiUtils.cs index cf24199647..6b5e5531ad 100644 --- a/src/Microsoft.ML.Data/Utils/ApiUtils.cs +++ b/src/Microsoft.ML.Data/Utils/ApiUtils.cs @@ -8,6 +8,7 @@ using System.Reflection; using System.Reflection.Emit; using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; namespace Microsoft.ML @@ -18,6 +19,9 @@ namespace Microsoft.ML internal static class ApiUtils { + private static readonly FuncStaticMethodInfo3 _generatePokeMethodInfo + = new FuncStaticMethodInfo3(GeneratePoke); + private static OpCode GetAssignmentOpCode(Type t, IEnumerable attributes) { // REVIEW: This should be a Dictionary based solution. @@ -145,10 +149,7 @@ internal static Delegate GeneratePoke(InternalSchemaDefinition.Colum Type propertyType = propertyInfo.PropertyType; var assignmentOpCodeProp = GetAssignmentOpCode(propertyType, propertyInfo.GetCustomAttributes()); - Func funcProp = GeneratePoke; - var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition() - .MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType); - return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo }); + return Utils.MarshalInvoke(_generatePokeMethodInfo, typeof(TOwn), typeof(TRow), propertyType, propertyInfo); default: Contracts.Assert(false); diff --git a/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs b/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs index cd8373716a..2b7496663b 100644 --- a/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs +++ b/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs @@ -364,6 +364,9 @@ private ILegacyDataLoader CreateLoaderFromBytes(byte[] loaderBytes, IMultiStream private sealed class Cursor : RootCursorBase { + private static readonly FuncInstanceMethodInfo1 _createSubGetterDelegateCoreMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.CreateSubGetterDelegateCore); + private PartitionedFileLoader _parent; private readonly bool[] _active; @@ -581,7 +584,7 @@ private Delegate[] CreateGetters() // Use sub-cursor for all sub-columns. if (IsSubColumn(i)) { - getters[i] = Utils.MarshalInvoke(CreateSubGetterDelegateCore, type.RawType, i); + getters[i] = Utils.MarshalInvoke(_createSubGetterDelegateCoreMethodInfo, this, type.RawType, i); } else { diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs index a5d099c424..6ba7fafd48 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs @@ -143,6 +143,9 @@ private protected override void SaveModel(ModelSaveContext ctx) private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx { + private static readonly FuncStaticMethodInfo1 _getIsNADelegateMethodInfo + = new FuncStaticMethodInfo1(GetIsNADelegate); + private readonly MissingValueIndicatorTransformer _parent; private readonly ColInfo[] _infos; @@ -215,8 +218,7 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() /// private static Delegate GetIsNADelegate(DataViewType type) { - Func func = GetIsNADelegate; - return Utils.MarshalInvoke(func, type.GetItemType().RawType, type); + return Utils.MarshalInvoke(_getIsNADelegateMethodInfo, type.GetItemType().RawType, type); } private static Delegate GetIsNADelegate(DataViewType type) diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs index 0e0d1a819b..0759232564 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs @@ -122,6 +122,12 @@ internal sealed class Options : TransformInputBase public bool ImputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot; } + private static readonly FuncStaticMethodInfo1 _testTypeMethodInfo + = new FuncStaticMethodInfo1(TestType); + + private static readonly FuncInstanceMethodInfo1 _getIsNADelegateMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.GetIsNADelegate); + internal const string LoadName = "NAReplaceTransform"; private static VersionInfo GetVersionInfo() @@ -146,9 +152,8 @@ private static VersionInfo GetVersionInfo() internal static string TestType(DataViewType type) { // Item type must have an NA value that exists and is not equal to its default value. - Func func = TestType; var itemType = type.GetItemType(); - return Utils.MarshalInvoke(func, itemType.RawType, itemType); + return Utils.MarshalInvoke(_testTypeMethodInfo, itemType.RawType, itemType); } private static string TestType(DataViewType type) @@ -383,8 +388,7 @@ private object GetDefault(DataViewType type) /// private Delegate GetIsNADelegate(DataViewType type) { - Func func = GetIsNADelegate; - return Utils.MarshalInvoke(func, type.GetItemType().RawType, type); + return Utils.MarshalInvoke(_getIsNADelegateMethodInfo, this, type.GetItemType().RawType, type); } private Delegate GetIsNADelegate(DataViewType type) diff --git a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs index 16c093935e..efecb8fee8 100644 --- a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs @@ -409,6 +409,9 @@ internal static bool IsValidColumnType(DataViewType type) private sealed class Impl { + private static readonly FuncStaticMethodInfo1 _makeKeyMapperMethodInfo + = new FuncStaticMethodInfo1(MakeKeyMapper); + private readonly IHost _host; private readonly BinFinderBase _binFinder; private int _numBins; @@ -627,11 +630,10 @@ private Single[] ComputeMutualInformation(Transposer trans, int col) } ulong keyCount = itemType.GetKeyCount(); Contracts.Assert(keyCount < Utils.ArrayMaxSize); - Func> del = MakeKeyMapper; - var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(itemType.RawType); + var mapper = Utils.MarshalInvoke(_makeKeyMapperMethodInfo, itemType.RawType, itemType); ComputeMutualInformationDelegate cmiDel = ComputeMutualInformation; var cmiMethodInfo = cmiDel.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(itemType.RawType); - return (Single[])cmiMethodInfo.Invoke(this, new object[] { trans, col, methodInfo.Invoke(null, new object[] { itemType }) }); + return (Single[])cmiMethodInfo.Invoke(this, new object[] { trans, col, mapper }); } private delegate float[] ComputeMutualInformationDelegate(Transposer trans, int col, Mapper mapper); diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs index b6bcbb09bd..a63dea2fbe 100644 --- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs +++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs @@ -248,6 +248,9 @@ private static VersionInfo GetVersionInfo() private static readonly FuncInstanceMethodInfo1 _makeGetterOneMethodInfo = FuncInstanceMethodInfo1.Create(target => target.MakeGetterOne); + private static readonly FuncInstanceMethodInfo1 _makeGetterVecMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.MakeGetterVec); + private readonly Bindings _bindings; private const string RegistrationName = "OptionalColumn"; @@ -406,7 +409,7 @@ private Delegate MakeGetter(int iinfo) { var columnType = _bindings.ColumnTypes[iinfo]; if (columnType is VectorDataViewType vectorType) - return Utils.MarshalInvoke(MakeGetterVec, vectorType.ItemType.RawType, vectorType.Size); + return Utils.MarshalInvoke(_makeGetterVecMethodInfo, this, vectorType.ItemType.RawType, vectorType.Size); return Utils.MarshalInvoke(_makeGetterOneMethodInfo, this, columnType.RawType); } @@ -426,6 +429,9 @@ private sealed class Cursor : SynchronizedCursorBase private static readonly FuncInstanceMethodInfo1 _makeGetterOneMethodInfo = FuncInstanceMethodInfo1.Create(target => target.MakeGetterOne); + private static readonly FuncInstanceMethodInfo1 _makeGetterVecMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.MakeGetterVec); + private readonly Bindings _bindings; private readonly bool[] _active; private readonly Delegate[] _getters; @@ -489,7 +495,7 @@ private Delegate MakeGetter(int iinfo) { var columnType = _bindings.ColumnTypes[iinfo]; if (columnType is VectorDataViewType vectorType) - return Utils.MarshalInvoke(MakeGetterVec, vectorType.ItemType.RawType, vectorType.Size); + return Utils.MarshalInvoke(_makeGetterVecMethodInfo, this, vectorType.ItemType.RawType, vectorType.Size); return Utils.MarshalInvoke(_makeGetterOneMethodInfo, this, columnType.RawType); } diff --git a/src/Microsoft.ML.Transforms/UngroupTransform.cs b/src/Microsoft.ML.Transforms/UngroupTransform.cs index 43aa5d5426..7b5ad5aa57 100644 --- a/src/Microsoft.ML.Transforms/UngroupTransform.cs +++ b/src/Microsoft.ML.Transforms/UngroupTransform.cs @@ -445,6 +445,9 @@ public int GetCommonPivotColumnSize() private sealed class Cursor : LinkedRootCursorBase { + private static readonly FuncInstanceMethodInfo1> _makeSizeGetterMethodInfo + = FuncInstanceMethodInfo1>.Create(target => target.MakeSizeGetter); + private readonly UngroupBinding _ungroupBinding; // The size of the pivot column in the current row. If the cursor is in good state, this is positive. @@ -501,9 +504,7 @@ public Cursor(IChannelProvider provider, DataViewRowCursor input, UngroupBinding // This will also create and cache a getter for the pivot column. // That's why MakeSizeGetter is an instance method. var rawItemType = info.ItemType.RawType; - Func> del = MakeSizeGetter; - var mi = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(rawItemType); - var sizeGetter = (Func)mi.Invoke(this, new object[] { info.Index }); + var sizeGetter = Utils.MarshalInvoke(_makeSizeGetterMethodInfo, this, rawItemType, info.Index); needed.Add(sizeGetter); } }