From 82e02f37564f83c3a7c90e5e28a652a39017701a Mon Sep 17 00:00:00 2001 From: Jan Kotas Date: Fri, 21 Dec 2018 09:25:24 -0800 Subject: [PATCH] Streamline default EqualityComparer and Comparer for Enums (#21604) This borrows the implementation strategy for these from CoreRT. It makes it both simpler (fewer types and lines of code) and faster in some cases since we always use the exact right underlying type. E.g. The following micro-benchmark is 25% faster with this change: ``` enum MyEnum : byte { x, y }; var comparer = Comparer.Default; for (int i = 0; i < 100000000; i++) { comparer.Compare(MyEnum.x, MyEnum.y); comparer.Compare(MyEnum.y, MyEnum.x); } ``` --- .../System/Collections/Generic/Comparer.cs | 103 +----------------- .../Collections/Generic/ComparerHelpers.cs | 11 +- .../Collections/Generic/EqualityComparer.cs | 75 +------------ .../Runtime/CompilerServices/jithelpers.cs | 55 ++-------- src/vm/jitinterface.cpp | 61 ++++++++--- src/vm/mscorlib.h | 12 +- 6 files changed, 66 insertions(+), 251 deletions(-) diff --git a/src/System.Private.CoreLib/src/System/Collections/Generic/Comparer.cs b/src/System.Private.CoreLib/src/System/Collections/Generic/Comparer.cs index d5d2e9165f9b..3bf7ebf2a72e 100644 --- a/src/System.Private.CoreLib/src/System/Collections/Generic/Comparer.cs +++ b/src/System.Private.CoreLib/src/System/Collections/Generic/Comparer.cs @@ -130,21 +130,16 @@ public override int Compare(T x, T y) // since we want to serialize as ObjectComparer for // back-compat reasons (see below). [Serializable] - internal sealed class Int32EnumComparer : Comparer, ISerializable where T : struct + internal sealed class EnumComparer : Comparer, ISerializable where T : struct, Enum { - public Int32EnumComparer() - { - Debug.Assert(typeof(T).IsEnum); - } + internal EnumComparer() { } // Used by the serialization engine. - private Int32EnumComparer(SerializationInfo info, StreamingContext context) { } + private EnumComparer(SerializationInfo info, StreamingContext context) { } public override int Compare(T x, T y) { - int ix = JitHelpers.UnsafeEnumCast(x); - int iy = JitHelpers.UnsafeEnumCast(y); - return ix.CompareTo(iy); + return System.Runtime.CompilerServices.JitHelpers.EnumCompareTo(x, y); } // Equals method for the comparer itself. @@ -163,94 +158,4 @@ public void GetObjectData(SerializationInfo info, StreamingContext context) info.SetType(typeof(ObjectComparer)); } } - - [Serializable] - internal sealed class UInt32EnumComparer : Comparer, ISerializable where T : struct - { - public UInt32EnumComparer() - { - Debug.Assert(typeof(T).IsEnum); - } - - // Used by the serialization engine. - private UInt32EnumComparer(SerializationInfo info, StreamingContext context) { } - - public override int Compare(T x, T y) - { - uint ix = (uint)JitHelpers.UnsafeEnumCast(x); - uint iy = (uint)JitHelpers.UnsafeEnumCast(y); - return ix.CompareTo(iy); - } - - // Equals method for the comparer itself. - public override bool Equals(object obj) => - obj != null && GetType() == obj.GetType(); - - public override int GetHashCode() => - GetType().GetHashCode(); - - public void GetObjectData(SerializationInfo info, StreamingContext context) - { - info.SetType(typeof(ObjectComparer)); - } - } - - [Serializable] - internal sealed class Int64EnumComparer : Comparer, ISerializable where T : struct - { - public Int64EnumComparer() - { - Debug.Assert(typeof(T).IsEnum); - } - - public override int Compare(T x, T y) - { - long lx = JitHelpers.UnsafeEnumCastLong(x); - long ly = JitHelpers.UnsafeEnumCastLong(y); - return lx.CompareTo(ly); - } - - // Equals method for the comparer itself. - public override bool Equals(object obj) => - obj != null && GetType() == obj.GetType(); - - public override int GetHashCode() => - GetType().GetHashCode(); - - public void GetObjectData(SerializationInfo info, StreamingContext context) - { - info.SetType(typeof(ObjectComparer)); - } - } - - [Serializable] - internal sealed class UInt64EnumComparer : Comparer, ISerializable where T : struct - { - public UInt64EnumComparer() - { - Debug.Assert(typeof(T).IsEnum); - } - - // Used by the serialization engine. - private UInt64EnumComparer(SerializationInfo info, StreamingContext context) { } - - public override int Compare(T x, T y) - { - ulong lx = (ulong)JitHelpers.UnsafeEnumCastLong(x); - ulong ly = (ulong)JitHelpers.UnsafeEnumCastLong(y); - return lx.CompareTo(ly); - } - - // Equals method for the comparer itself. - public override bool Equals(object obj) => - obj != null && GetType() == obj.GetType(); - - public override int GetHashCode() => - GetType().GetHashCode(); - - public void GetObjectData(SerializationInfo info, StreamingContext context) - { - info.SetType(typeof(ObjectComparer)); - } - } } diff --git a/src/System.Private.CoreLib/src/System/Collections/Generic/ComparerHelpers.cs b/src/System.Private.CoreLib/src/System/Collections/Generic/ComparerHelpers.cs index 2575d4e34687..4563bc155c56 100644 --- a/src/System.Private.CoreLib/src/System/Collections/Generic/ComparerHelpers.cs +++ b/src/System.Private.CoreLib/src/System/Collections/Generic/ComparerHelpers.cs @@ -96,16 +96,12 @@ private static object TryCreateEnumComparer(RuntimeType enumType) case TypeCode.SByte: case TypeCode.Int16: case TypeCode.Int32: - return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(Int32EnumComparer), enumType); case TypeCode.Byte: case TypeCode.UInt16: case TypeCode.UInt32: - return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(UInt32EnumComparer), enumType); - // 64-bit enums: Use `UnsafeEnumCastLong` case TypeCode.Int64: - return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(Int64EnumComparer), enumType); case TypeCode.UInt64: - return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(UInt64EnumComparer), enumType); + return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(EnumComparer<>), enumType); } return null; @@ -194,11 +190,10 @@ private static object TryCreateEnumEqualityComparer(RuntimeType enumType) case TypeCode.SByte: case TypeCode.Byte: case TypeCode.Int16: - case TypeCode.UInt16: - return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(EnumEqualityComparer), enumType); case TypeCode.Int64: case TypeCode.UInt64: - return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(LongEnumEqualityComparer), enumType); + case TypeCode.UInt16: + return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(EnumEqualityComparer<>), enumType); } return null; diff --git a/src/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.cs b/src/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.cs index 5e778fdb281f..82051affc084 100644 --- a/src/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.cs +++ b/src/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.cs @@ -315,7 +315,7 @@ public override int GetHashCode() => [Serializable] [System.Runtime.CompilerServices.TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")] // Needs to be public to support binary serialization compatibility - public sealed class EnumEqualityComparer : EqualityComparer, ISerializable where T : struct + public sealed class EnumEqualityComparer : EqualityComparer, ISerializable where T : struct, Enum { internal EnumEqualityComparer() { } @@ -333,9 +333,7 @@ public void GetObjectData(SerializationInfo info, StreamingContext context) [MethodImpl(MethodImplOptions.AggressiveInlining)] public override bool Equals(T x, T y) { - int x_final = System.Runtime.CompilerServices.JitHelpers.UnsafeEnumCast(x); - int y_final = System.Runtime.CompilerServices.JitHelpers.UnsafeEnumCast(y); - return x_final == y_final; + return System.Runtime.CompilerServices.JitHelpers.EnumEquals(x, y); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -353,85 +351,20 @@ public override int GetHashCode() => internal override int IndexOf(T[] array, T value, int startIndex, int count) { - int toFind = JitHelpers.UnsafeEnumCast(value); int endIndex = startIndex + count; for (int i = startIndex; i < endIndex; i++) { - int current = JitHelpers.UnsafeEnumCast(array[i]); - if (toFind == current) return i; + if (System.Runtime.CompilerServices.JitHelpers.EnumEquals(array[i], value)) return i; } return -1; } internal override int LastIndexOf(T[] array, T value, int startIndex, int count) { - int toFind = JitHelpers.UnsafeEnumCast(value); int endIndex = startIndex - count + 1; for (int i = startIndex; i >= endIndex; i--) { - int current = JitHelpers.UnsafeEnumCast(array[i]); - if (toFind == current) return i; - } - return -1; - } - } - - [Serializable] - internal sealed class LongEnumEqualityComparer : EqualityComparer, ISerializable where T : struct - { - internal LongEnumEqualityComparer() { } - - // This is used by the serialization engine. - private LongEnumEqualityComparer(SerializationInfo information, StreamingContext context) { } - - public void GetObjectData(SerializationInfo info, StreamingContext context) - { - // The LongEnumEqualityComparer does not exist on 4.0 so we need to serialize this comparer as ObjectEqualityComparer - // to allow for roundtrip between 4.0 and 4.5. - info.SetType(typeof(ObjectEqualityComparer)); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public override bool Equals(T x, T y) - { - long x_final = System.Runtime.CompilerServices.JitHelpers.UnsafeEnumCastLong(x); - long y_final = System.Runtime.CompilerServices.JitHelpers.UnsafeEnumCastLong(y); - return x_final == y_final; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public override int GetHashCode(T obj) - { - return obj.GetHashCode(); - } - - // Equals method for the comparer itself. - public override bool Equals(object obj) => - obj != null && GetType() == obj.GetType(); - - public override int GetHashCode() => - GetType().GetHashCode(); - - internal override int IndexOf(T[] array, T value, int startIndex, int count) - { - long toFind = JitHelpers.UnsafeEnumCastLong(value); - int endIndex = startIndex + count; - for (int i = startIndex; i < endIndex; i++) - { - long current = JitHelpers.UnsafeEnumCastLong(array[i]); - if (toFind == current) return i; - } - return -1; - } - - internal override int LastIndexOf(T[] array, T value, int startIndex, int count) - { - long toFind = JitHelpers.UnsafeEnumCastLong(value); - int endIndex = startIndex - count + 1; - for (int i = startIndex; i >= endIndex; i--) - { - long current = JitHelpers.UnsafeEnumCastLong(array[i]); - if (toFind == current) return i; + if (System.Runtime.CompilerServices.JitHelpers.EnumEquals(array[i], value)) return i; } return -1; } diff --git a/src/System.Private.CoreLib/src/System/Runtime/CompilerServices/jithelpers.cs b/src/System.Private.CoreLib/src/System/Runtime/CompilerServices/jithelpers.cs index f765e18c4596..b8bb7c498184 100644 --- a/src/System.Private.CoreLib/src/System/Runtime/CompilerServices/jithelpers.cs +++ b/src/System.Private.CoreLib/src/System/Runtime/CompilerServices/jithelpers.cs @@ -85,59 +85,20 @@ internal static StackCrawlMarkHandle GetStackCrawlMarkHandle(ref StackCrawlMark return new StackCrawlMarkHandle((IntPtr)Unsafe.AsPointer(ref stackMark)); } -#if DEBUG - internal static int UnsafeEnumCast(T val) where T : struct // Actually T must be 4 byte (or less) enum + internal static bool EnumEquals(T x, T y) where T : struct, Enum { - Debug.Assert(typeof(T).IsEnum - && (Enum.GetUnderlyingType(typeof(T)) == typeof(int) - || Enum.GetUnderlyingType(typeof(T)) == typeof(uint) - || Enum.GetUnderlyingType(typeof(T)) == typeof(short) - || Enum.GetUnderlyingType(typeof(T)) == typeof(ushort) - || Enum.GetUnderlyingType(typeof(T)) == typeof(byte) - || Enum.GetUnderlyingType(typeof(T)) == typeof(sbyte)), - "Error, T must be an 4 byte (or less) enum JitHelpers.UnsafeEnumCast!"); - return UnsafeEnumCastInternal(val); + // The body of this function will be replaced by the EE with unsafe code + // See getILIntrinsicImplementation for how this happens. + return x.Equals(y); } - private static int UnsafeEnumCastInternal(T val) where T : struct // Actually T must be 4 (or less) byte enum + internal static int EnumCompareTo(T x, T y) where T : struct, Enum { - // should be return (int) val; but C# does not allow, runtime does this magically - // See getILIntrinsicImplementation for how this happens. - throw new InvalidOperationException(); + // The body of this function will be replaced by the EE with unsafe code + // See getILIntrinsicImplementation for how this happens. + return x.CompareTo(y); } - internal static long UnsafeEnumCastLong(T val) where T : struct // Actually T must be 8 byte enum - { - Debug.Assert(typeof(T).IsEnum - && (Enum.GetUnderlyingType(typeof(T)) == typeof(long) - || Enum.GetUnderlyingType(typeof(T)) == typeof(ulong)), - "Error, T must be an 8 byte enum JitHelpers.UnsafeEnumCastLong!"); - return UnsafeEnumCastLongInternal(val); - } - - private static long UnsafeEnumCastLongInternal(T val) where T : struct // Actually T must be 8 byte enum - { - // should be return (int) val; but C# does not allow, runtime does this magically - // See getILIntrinsicImplementation for how this happens. - throw new InvalidOperationException(); - } -#else // DEBUG - - internal static int UnsafeEnumCast(T val) where T : struct // Actually T must be 4 byte (or less) enum - { - // should be return (int) val; but C# does not allow, runtime does this magically - // See getILIntrinsicImplementation for how this happens. - throw new InvalidOperationException(); - } - - internal static long UnsafeEnumCastLong(T val) where T : struct // Actually T must be 8 byte enum - { - // should be return (long) val; but C# does not allow, runtime does this magically - // See getILIntrinsicImplementation for how this happens. - throw new InvalidOperationException(); - } -#endif // DEBUG - // Set the given element in the array without any type or range checks [MethodImplAttribute(MethodImplOptions.InternalCall)] internal static extern void UnsafeSetArrayElement(object[] target, int index, object element); diff --git a/src/vm/jitinterface.cpp b/src/vm/jitinterface.cpp index 2d87c276c89b..3e1765a0d795 100644 --- a/src/vm/jitinterface.cpp +++ b/src/vm/jitinterface.cpp @@ -6973,7 +6973,7 @@ bool getILIntrinsicImplementation(MethodDesc * ftn, // Compare tokens to cover all generic instantiations // The body of the first method is simply ret Arg0. The second one first casts the arg to I4. - if (tk == MscorlibBinder::GetMethod(METHOD__JIT_HELPERS__UNSAFE_ENUM_CAST)->GetMemberDef()) + if (tk == MscorlibBinder::GetMethod(METHOD__JIT_HELPERS__ENUM_EQUALS)->GetMemberDef()) { // Normally we would follow the above pattern and unconditionally replace the IL, // relying on generic type constraints to guarantee that it will only ever be instantiated @@ -7000,19 +7000,20 @@ bool getILIntrinsicImplementation(MethodDesc * ftn, et == ELEMENT_TYPE_I2 || et == ELEMENT_TYPE_U2 || et == ELEMENT_TYPE_I1 || - et == ELEMENT_TYPE_U1) + et == ELEMENT_TYPE_U1 || + et == ELEMENT_TYPE_I8 || + et == ELEMENT_TYPE_U8) { - // Cast to I4 and return the argument that was passed in. - static const BYTE ilcode[] = { CEE_LDARG_0, CEE_CONV_I4, CEE_RET }; + static const BYTE ilcode[] = { CEE_LDARG_0, CEE_LDARG_1, CEE_PREFIX1, (CEE_CEQ & 0xFF), CEE_RET }; methInfo->ILCode = const_cast(ilcode); methInfo->ILCodeSize = sizeof(ilcode); - methInfo->maxStack = 1; + methInfo->maxStack = 2; methInfo->EHcount = 0; methInfo->options = (CorInfoOptions)0; return true; } } - else if (tk == MscorlibBinder::GetMethod(METHOD__JIT_HELPERS__UNSAFE_ENUM_CAST_LONG)->GetMemberDef()) + else if (tk == MscorlibBinder::GetMethod(METHOD__JIT_HELPERS__ENUM_COMPARE_TO)->GetMemberDef()) { // The the comment above on why this is is not an unconditional replacement. This case handles // Enums backed by 8 byte values. @@ -7022,14 +7023,43 @@ bool getILIntrinsicImplementation(MethodDesc * ftn, _ASSERTE(inst.GetNumArgs() == 1); CorElementType et = inst[0].GetVerifierCorElementType(); - if (et == ELEMENT_TYPE_I8 || + if (et == ELEMENT_TYPE_I4 || + et == ELEMENT_TYPE_U4 || + et == ELEMENT_TYPE_I2 || + et == ELEMENT_TYPE_U2 || + et == ELEMENT_TYPE_I1 || + et == ELEMENT_TYPE_U1 || + et == ELEMENT_TYPE_I8 || et == ELEMENT_TYPE_U8) { - // Cast to I8 and return the argument that was passed in. - static const BYTE ilcode[] = { CEE_LDARG_0, CEE_CONV_I8, CEE_RET }; - methInfo->ILCode = const_cast(ilcode); - methInfo->ILCodeSize = sizeof(ilcode); - methInfo->maxStack = 1; + static BYTE ilcode[8][9]; + + TypeHandle thUnderlyingType = MscorlibBinder::GetElementType(et); + + TypeHandle thIComparable = TypeHandle(MscorlibBinder::GetClass(CLASS__ICOMPARABLEGENERIC)).Instantiate(Instantiation(&thUnderlyingType, 1)); + + MethodDesc * pCompareToMD = thUnderlyingType.AsMethodTable()->GetMethodDescForInterfaceMethod( + thIComparable, MscorlibBinder::GetMethod(METHOD__ICOMPARABLEGENERIC__COMPARE_TO), TRUE /* throwOnConflict */); + + // Call CompareTo method on the primitive type + int tokCompareTo = pCompareToMD->GetMemberDef(); + + int index = (et - ELEMENT_TYPE_I1); + _ASSERTE(index < _countof(ilcode)); + + ilcode[index][0] = CEE_LDARGA_S; + ilcode[index][1] = 0; + ilcode[index][2] = CEE_LDARG_1; + ilcode[index][3] = CEE_CALL; + ilcode[index][4] = (BYTE)(tokCompareTo); + ilcode[index][5] = (BYTE)(tokCompareTo >> 8); + ilcode[index][6] = (BYTE)(tokCompareTo >> 16); + ilcode[index][7] = (BYTE)(tokCompareTo >> 24); + ilcode[index][8] = CEE_RET; + + methInfo->ILCode = const_cast(ilcode[index]); + methInfo->ILCodeSize = sizeof(ilcode[index]); + methInfo->maxStack = 2; methInfo->EHcount = 0; methInfo->options = (CorInfoOptions)0; return true; @@ -9122,15 +9152,10 @@ CORINFO_CLASS_HANDLE CEEInfo::getDefaultEqualityComparerClassHelper(CORINFO_CLAS case ELEMENT_TYPE_U2: case ELEMENT_TYPE_I4: case ELEMENT_TYPE_U4: - { - targetClass = MscorlibBinder::GetClass(CLASS__ENUM_EQUALITYCOMPARER); - break; - } - case ELEMENT_TYPE_I8: case ELEMENT_TYPE_U8: { - targetClass = MscorlibBinder::GetClass(CLASS__LONG_ENUM_EQUALITYCOMPARER); + targetClass = MscorlibBinder::GetClass(CLASS__ENUM_EQUALITYCOMPARER); break; } diff --git a/src/vm/mscorlib.h b/src/vm/mscorlib.h index 48f3a5e5808a..9a5fcca78b46 100644 --- a/src/vm/mscorlib.h +++ b/src/vm/mscorlib.h @@ -679,13 +679,8 @@ DEFINE_METHOD(RUNTIME_HELPERS, EXECUTE_BACKOUT_CODE_HELPER, ExecuteBackoutC DEFINE_METHOD(RUNTIME_HELPERS, IS_REFERENCE_OR_CONTAINS_REFERENCES, IsReferenceOrContainsReferences, NoSig) DEFINE_CLASS(JIT_HELPERS, CompilerServices, JitHelpers) -#ifdef _DEBUG -DEFINE_METHOD(JIT_HELPERS, UNSAFE_ENUM_CAST, UnsafeEnumCastInternal, NoSig) -DEFINE_METHOD(JIT_HELPERS, UNSAFE_ENUM_CAST_LONG, UnsafeEnumCastLongInternal, NoSig) -#else // _DEBUG -DEFINE_METHOD(JIT_HELPERS, UNSAFE_ENUM_CAST, UnsafeEnumCast, NoSig) -DEFINE_METHOD(JIT_HELPERS, UNSAFE_ENUM_CAST_LONG, UnsafeEnumCastLong, NoSig) -#endif // _DEBUG +DEFINE_METHOD(JIT_HELPERS, ENUM_EQUALS, EnumEquals, NoSig) +DEFINE_METHOD(JIT_HELPERS, ENUM_COMPARE_TO, EnumCompareTo, NoSig) DEFINE_METHOD(JIT_HELPERS, GET_RAW_SZ_ARRAY_DATA, GetRawSzArrayData, NoSig) DEFINE_CLASS(UNSAFE, InternalCompilerServices, Unsafe) @@ -1330,6 +1325,8 @@ DEFINE_CLASS(IDICTIONARYGENERIC, CollectionsGeneric, IDictionary`2) DEFINE_CLASS(KEYVALUEPAIRGENERIC, CollectionsGeneric, KeyValuePair`2) DEFINE_CLASS(ICOMPARABLEGENERIC, System, IComparable`1) +DEFINE_METHOD(ICOMPARABLEGENERIC, COMPARE_TO, CompareTo, NoSig) + DEFINE_CLASS(IEQUATABLEGENERIC, System, IEquatable`1) DEFINE_CLASS_U(Reflection, LoaderAllocator, LoaderAllocatorObject) @@ -1383,7 +1380,6 @@ DEFINE_METHOD(UTF8BUFFERMARSHALER, CONVERT_TO_MANAGED, ConvertToManaged, NoSig) DEFINE_CLASS(BYTE_EQUALITYCOMPARER, CollectionsGeneric, ByteEqualityComparer) DEFINE_CLASS(ENUM_EQUALITYCOMPARER, CollectionsGeneric, EnumEqualityComparer`1) -DEFINE_CLASS(LONG_ENUM_EQUALITYCOMPARER, CollectionsGeneric, LongEnumEqualityComparer`1) DEFINE_CLASS(NULLABLE_EQUALITYCOMPARER, CollectionsGeneric, NullableEqualityComparer`1) DEFINE_CLASS(GENERIC_EQUALITYCOMPARER, CollectionsGeneric, GenericEqualityComparer`1) DEFINE_CLASS(OBJECT_EQUALITYCOMPARER, CollectionsGeneric, ObjectEqualityComparer`1)