Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.

Commit

Permalink
Streamline default EqualityComparer and Comparer for Enums (#21604)
Browse files Browse the repository at this point in the history
This borrows the implementation strategy for these from CoreRT. It makes it both simpler (fewer types and lines of code) and faster in some cases since we always use the exact right underlying type.

E.g. The following micro-benchmark is 25% faster with this change:

```
enum MyEnum : byte { x, y };

var comparer = Comparer<MyEnum>.Default;

for (int i = 0; i < 100000000; i++)
{
    comparer.Compare(MyEnum.x, MyEnum.y);
    comparer.Compare(MyEnum.y, MyEnum.x);
}
```
  • Loading branch information
jkotas committed Dec 21, 2018
1 parent 63ab188 commit 82e02f3
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 251 deletions.
103 changes: 4 additions & 99 deletions src/System.Private.CoreLib/src/System/Collections/Generic/Comparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,21 +130,16 @@ public override int Compare(T x, T y)
// since we want to serialize as ObjectComparer for
// back-compat reasons (see below).
[Serializable]
internal sealed class Int32EnumComparer<T> : Comparer<T>, ISerializable where T : struct
internal sealed class EnumComparer<T> : Comparer<T>, ISerializable where T : struct, Enum
{
public Int32EnumComparer()
{
Debug.Assert(typeof(T).IsEnum);
}
internal EnumComparer() { }

// Used by the serialization engine.
private Int32EnumComparer(SerializationInfo info, StreamingContext context) { }
private EnumComparer(SerializationInfo info, StreamingContext context) { }

public override int Compare(T x, T y)
{
int ix = JitHelpers.UnsafeEnumCast(x);
int iy = JitHelpers.UnsafeEnumCast(y);
return ix.CompareTo(iy);
return System.Runtime.CompilerServices.JitHelpers.EnumCompareTo(x, y);
}

// Equals method for the comparer itself.
Expand All @@ -163,94 +158,4 @@ public void GetObjectData(SerializationInfo info, StreamingContext context)
info.SetType(typeof(ObjectComparer<T>));
}
}

[Serializable]
internal sealed class UInt32EnumComparer<T> : Comparer<T>, ISerializable where T : struct
{
public UInt32EnumComparer()
{
Debug.Assert(typeof(T).IsEnum);
}

// Used by the serialization engine.
private UInt32EnumComparer(SerializationInfo info, StreamingContext context) { }

public override int Compare(T x, T y)
{
uint ix = (uint)JitHelpers.UnsafeEnumCast(x);
uint iy = (uint)JitHelpers.UnsafeEnumCast(y);
return ix.CompareTo(iy);
}

// Equals method for the comparer itself.
public override bool Equals(object obj) =>
obj != null && GetType() == obj.GetType();

public override int GetHashCode() =>
GetType().GetHashCode();

public void GetObjectData(SerializationInfo info, StreamingContext context)
{
info.SetType(typeof(ObjectComparer<T>));
}
}

[Serializable]
internal sealed class Int64EnumComparer<T> : Comparer<T>, ISerializable where T : struct
{
public Int64EnumComparer()
{
Debug.Assert(typeof(T).IsEnum);
}

public override int Compare(T x, T y)
{
long lx = JitHelpers.UnsafeEnumCastLong(x);
long ly = JitHelpers.UnsafeEnumCastLong(y);
return lx.CompareTo(ly);
}

// Equals method for the comparer itself.
public override bool Equals(object obj) =>
obj != null && GetType() == obj.GetType();

public override int GetHashCode() =>
GetType().GetHashCode();

public void GetObjectData(SerializationInfo info, StreamingContext context)
{
info.SetType(typeof(ObjectComparer<T>));
}
}

[Serializable]
internal sealed class UInt64EnumComparer<T> : Comparer<T>, ISerializable where T : struct
{
public UInt64EnumComparer()
{
Debug.Assert(typeof(T).IsEnum);
}

// Used by the serialization engine.
private UInt64EnumComparer(SerializationInfo info, StreamingContext context) { }

public override int Compare(T x, T y)
{
ulong lx = (ulong)JitHelpers.UnsafeEnumCastLong(x);
ulong ly = (ulong)JitHelpers.UnsafeEnumCastLong(y);
return lx.CompareTo(ly);
}

// Equals method for the comparer itself.
public override bool Equals(object obj) =>
obj != null && GetType() == obj.GetType();

public override int GetHashCode() =>
GetType().GetHashCode();

public void GetObjectData(SerializationInfo info, StreamingContext context)
{
info.SetType(typeof(ObjectComparer<T>));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,12 @@ private static object TryCreateEnumComparer(RuntimeType enumType)
case TypeCode.SByte:
case TypeCode.Int16:
case TypeCode.Int32:
return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(Int32EnumComparer<int>), enumType);
case TypeCode.Byte:
case TypeCode.UInt16:
case TypeCode.UInt32:
return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(UInt32EnumComparer<uint>), enumType);
// 64-bit enums: Use `UnsafeEnumCastLong`
case TypeCode.Int64:
return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(Int64EnumComparer<long>), enumType);
case TypeCode.UInt64:
return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(UInt64EnumComparer<ulong>), enumType);
return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(EnumComparer<>), enumType);
}

return null;
Expand Down Expand Up @@ -194,11 +190,10 @@ private static object TryCreateEnumEqualityComparer(RuntimeType enumType)
case TypeCode.SByte:
case TypeCode.Byte:
case TypeCode.Int16:
case TypeCode.UInt16:
return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(EnumEqualityComparer<int>), enumType);
case TypeCode.Int64:
case TypeCode.UInt64:
return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(LongEnumEqualityComparer<long>), enumType);
case TypeCode.UInt16:
return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(EnumEqualityComparer<>), enumType);
}

return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ public override int GetHashCode() =>
[Serializable]
[System.Runtime.CompilerServices.TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")]
// Needs to be public to support binary serialization compatibility
public sealed class EnumEqualityComparer<T> : EqualityComparer<T>, ISerializable where T : struct
public sealed class EnumEqualityComparer<T> : EqualityComparer<T>, ISerializable where T : struct, Enum
{
internal EnumEqualityComparer() { }

Expand All @@ -333,9 +333,7 @@ public void GetObjectData(SerializationInfo info, StreamingContext context)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public override bool Equals(T x, T y)
{
int x_final = System.Runtime.CompilerServices.JitHelpers.UnsafeEnumCast(x);
int y_final = System.Runtime.CompilerServices.JitHelpers.UnsafeEnumCast(y);
return x_final == y_final;
return System.Runtime.CompilerServices.JitHelpers.EnumEquals(x, y);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand All @@ -353,85 +351,20 @@ public override int GetHashCode() =>

internal override int IndexOf(T[] array, T value, int startIndex, int count)
{
int toFind = JitHelpers.UnsafeEnumCast(value);
int endIndex = startIndex + count;
for (int i = startIndex; i < endIndex; i++)
{
int current = JitHelpers.UnsafeEnumCast(array[i]);
if (toFind == current) return i;
if (System.Runtime.CompilerServices.JitHelpers.EnumEquals(array[i], value)) return i;
}
return -1;
}

internal override int LastIndexOf(T[] array, T value, int startIndex, int count)
{
int toFind = JitHelpers.UnsafeEnumCast(value);
int endIndex = startIndex - count + 1;
for (int i = startIndex; i >= endIndex; i--)
{
int current = JitHelpers.UnsafeEnumCast(array[i]);
if (toFind == current) return i;
}
return -1;
}
}

[Serializable]
internal sealed class LongEnumEqualityComparer<T> : EqualityComparer<T>, ISerializable where T : struct
{
internal LongEnumEqualityComparer() { }

// This is used by the serialization engine.
private LongEnumEqualityComparer(SerializationInfo information, StreamingContext context) { }

public void GetObjectData(SerializationInfo info, StreamingContext context)
{
// The LongEnumEqualityComparer does not exist on 4.0 so we need to serialize this comparer as ObjectEqualityComparer
// to allow for roundtrip between 4.0 and 4.5.
info.SetType(typeof(ObjectEqualityComparer<T>));
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public override bool Equals(T x, T y)
{
long x_final = System.Runtime.CompilerServices.JitHelpers.UnsafeEnumCastLong(x);
long y_final = System.Runtime.CompilerServices.JitHelpers.UnsafeEnumCastLong(y);
return x_final == y_final;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public override int GetHashCode(T obj)
{
return obj.GetHashCode();
}

// Equals method for the comparer itself.
public override bool Equals(object obj) =>
obj != null && GetType() == obj.GetType();

public override int GetHashCode() =>
GetType().GetHashCode();

internal override int IndexOf(T[] array, T value, int startIndex, int count)
{
long toFind = JitHelpers.UnsafeEnumCastLong(value);
int endIndex = startIndex + count;
for (int i = startIndex; i < endIndex; i++)
{
long current = JitHelpers.UnsafeEnumCastLong(array[i]);
if (toFind == current) return i;
}
return -1;
}

internal override int LastIndexOf(T[] array, T value, int startIndex, int count)
{
long toFind = JitHelpers.UnsafeEnumCastLong(value);
int endIndex = startIndex - count + 1;
for (int i = startIndex; i >= endIndex; i--)
{
long current = JitHelpers.UnsafeEnumCastLong(array[i]);
if (toFind == current) return i;
if (System.Runtime.CompilerServices.JitHelpers.EnumEquals(array[i], value)) return i;
}
return -1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,59 +85,20 @@ internal static StackCrawlMarkHandle GetStackCrawlMarkHandle(ref StackCrawlMark
return new StackCrawlMarkHandle((IntPtr)Unsafe.AsPointer(ref stackMark));
}

#if DEBUG
internal static int UnsafeEnumCast<T>(T val) where T : struct // Actually T must be 4 byte (or less) enum
internal static bool EnumEquals<T>(T x, T y) where T : struct, Enum
{
Debug.Assert(typeof(T).IsEnum
&& (Enum.GetUnderlyingType(typeof(T)) == typeof(int)
|| Enum.GetUnderlyingType(typeof(T)) == typeof(uint)
|| Enum.GetUnderlyingType(typeof(T)) == typeof(short)
|| Enum.GetUnderlyingType(typeof(T)) == typeof(ushort)
|| Enum.GetUnderlyingType(typeof(T)) == typeof(byte)
|| Enum.GetUnderlyingType(typeof(T)) == typeof(sbyte)),
"Error, T must be an 4 byte (or less) enum JitHelpers.UnsafeEnumCast!");
return UnsafeEnumCastInternal<T>(val);
// The body of this function will be replaced by the EE with unsafe code
// See getILIntrinsicImplementation for how this happens.
return x.Equals(y);
}

private static int UnsafeEnumCastInternal<T>(T val) where T : struct // Actually T must be 4 (or less) byte enum
internal static int EnumCompareTo<T>(T x, T y) where T : struct, Enum
{
// should be return (int) val; but C# does not allow, runtime does this magically
// See getILIntrinsicImplementation for how this happens.
throw new InvalidOperationException();
// The body of this function will be replaced by the EE with unsafe code
// See getILIntrinsicImplementation for how this happens.
return x.CompareTo(y);
}

internal static long UnsafeEnumCastLong<T>(T val) where T : struct // Actually T must be 8 byte enum
{
Debug.Assert(typeof(T).IsEnum
&& (Enum.GetUnderlyingType(typeof(T)) == typeof(long)
|| Enum.GetUnderlyingType(typeof(T)) == typeof(ulong)),
"Error, T must be an 8 byte enum JitHelpers.UnsafeEnumCastLong!");
return UnsafeEnumCastLongInternal<T>(val);
}

private static long UnsafeEnumCastLongInternal<T>(T val) where T : struct // Actually T must be 8 byte enum
{
// should be return (int) val; but C# does not allow, runtime does this magically
// See getILIntrinsicImplementation for how this happens.
throw new InvalidOperationException();
}
#else // DEBUG

internal static int UnsafeEnumCast<T>(T val) where T : struct // Actually T must be 4 byte (or less) enum
{
// should be return (int) val; but C# does not allow, runtime does this magically
// See getILIntrinsicImplementation for how this happens.
throw new InvalidOperationException();
}

internal static long UnsafeEnumCastLong<T>(T val) where T : struct // Actually T must be 8 byte enum
{
// should be return (long) val; but C# does not allow, runtime does this magically
// See getILIntrinsicImplementation for how this happens.
throw new InvalidOperationException();
}
#endif // DEBUG

// Set the given element in the array without any type or range checks
[MethodImplAttribute(MethodImplOptions.InternalCall)]
internal static extern void UnsafeSetArrayElement(object[] target, int index, object element);
Expand Down
Loading

0 comments on commit 82e02f3

Please sign in to comment.