From 6889a7bb7712ec7fa32f37ce54ccb74f9e5ad4b8 Mon Sep 17 00:00:00 2001 From: Ben Adams Date: Sat, 30 Nov 2019 01:58:52 +0000 Subject: [PATCH] Intrinsicify SequenceCompareTo(char) --- .../src/System/SpanHelpers.Char.cs | 151 ++++++++++++++---- 1 file changed, 121 insertions(+), 30 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs index 25dc8f6fae4c6..5d6536e411fa9 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs @@ -66,68 +66,155 @@ public static int IndexOf(ref char searchSpace, int searchSpaceLength, ref char } [MethodImpl(MethodImplOptions.AggressiveOptimization)] - public static unsafe int SequenceCompareTo(ref char first, int firstLength, ref char second, int secondLength) + public static unsafe int SequenceCompareTo(ref char firstStart, int firstLength, ref char secondStart, int secondLength) { Debug.Assert(firstLength >= 0); Debug.Assert(secondLength >= 0); - int lengthDelta = firstLength - secondLength; - - if (Unsafe.AreSame(ref first, ref second)) + if (Unsafe.AreSame(ref firstStart, ref secondStart)) goto Equal; - IntPtr minLength = (IntPtr)((firstLength < secondLength) ? firstLength : secondLength); - IntPtr i = (IntPtr)0; // Use IntPtr for arithmetic to avoid unnecessary 64->32->64 truncations + int minLength = (firstLength < secondLength) ? firstLength : secondLength; + + int offset = 0; + int lengthToExamine = minLength; - if ((byte*)minLength >= (byte*)(sizeof(UIntPtr) / sizeof(char))) + if (Avx2.IsSupported) { - if (Vector.IsHardwareAccelerated && (byte*)minLength >= (byte*)Vector.Count) + // When we move into a Vectorized block, we process everything of Vector size; + // and then for any remainder we do a final compare of Vector size but starting at + // the end and forwards, which may overlap on an earlier compare. + if (lengthToExamine >= Vector256.Count) { - IntPtr nLength = minLength - Vector.Count; - do + lengthToExamine -= Vector256.Count; + uint matches; + while (lengthToExamine > offset) { - if (Unsafe.ReadUnaligned>(ref Unsafe.As(ref Unsafe.Add(ref first, i))) != - Unsafe.ReadUnaligned>(ref Unsafe.As(ref Unsafe.Add(ref second, i)))) + matches = (uint)Avx2.MoveMask(Avx2.CompareEqual(LoadVector256(ref firstStart, offset), LoadVector256(ref secondStart, offset)).AsByte()); + // Note that MoveMask has converted the equal vector elements into a set of bit flags, + // So the bit position in 'matches' corresponds to the element offset. + + // 32 elements in Vector256 so we compare to uint.MaxValue to check if everything matched + if (matches == uint.MaxValue) { - break; + // All matched + offset += Vector256.Count; + continue; } - i += Vector.Count; + + goto Difference; + } + // Move to Vector length from end for final compare + offset = lengthToExamine; + // Same as method as above + matches = (uint)Avx2.MoveMask(Avx2.CompareEqual(LoadVector256(ref firstStart, offset), LoadVector256(ref secondStart, offset)).AsByte()); + if (matches == uint.MaxValue) + { + // All matched + goto Equal; } - while ((byte*)nLength >= (byte*)i); + Difference: + // Invert matches to find differences + uint differences = ~matches; + // Find bitflag offset of first difference and add to current offset, + // flags are in bytes so divide for chars + offset += BitOperations.TrailingZeroCount((int)differences) / sizeof(char); + + int result = Unsafe.Add(ref firstStart, offset).CompareTo(Unsafe.Add(ref secondStart, offset)); + Debug.Assert(result != 0); + + return result; } + } - while ((byte*)minLength >= (byte*)(i + sizeof(UIntPtr) / sizeof(char))) + if (Sse2.IsSupported) + { + // When we move into a Vectorized block, we process everything of Vector size; + // and then for any remainder we do a final compare of Vector size but starting at + // the end and forwards, which may overlap on an earlier compare. + if (lengthToExamine >= Vector128.Count) { - if (Unsafe.ReadUnaligned(ref Unsafe.As(ref Unsafe.Add(ref first, i))) != - Unsafe.ReadUnaligned(ref Unsafe.As(ref Unsafe.Add(ref second, i)))) + lengthToExamine -= Vector128.Count; + uint matches; + while (lengthToExamine > offset) { - break; + matches = (uint)Sse2.MoveMask(Sse2.CompareEqual(LoadVector128(ref firstStart, offset), LoadVector128(ref secondStart, offset)).AsByte()); + // Note that MoveMask has converted the equal vector elements into a set of bit flags, + // So the bit position in 'matches' corresponds to the element offset. + + // 16 elements in Vector128 so we compare to ushort.MaxValue to check if everything matched + if (matches == ushort.MaxValue) + { + // All matched + offset += Vector128.Count; + continue; + } + + goto Difference; } - i += sizeof(UIntPtr) / sizeof(char); + // Move to Vector length from end for final compare + offset = lengthToExamine; + // Same as method as above + matches = (uint)Sse2.MoveMask(Sse2.CompareEqual(LoadVector128(ref firstStart, offset), LoadVector128(ref secondStart, offset)).AsByte()); + if (matches == ushort.MaxValue) + { + // All matched + goto Equal; + } + Difference: + // Invert matches to find differences + uint differences = ~matches; + // Find bitflag offset of first difference and add to current offset, + // flags are in bytes so divide for chars + offset += BitOperations.TrailingZeroCount((int)differences) / sizeof(char); + + int result = Unsafe.Add(ref firstStart, offset).CompareTo(Unsafe.Add(ref secondStart, offset)); + Debug.Assert(result != 0); + + return result; + } + } + else if (Vector.IsHardwareAccelerated) + { + if (lengthToExamine > Vector.Count) + { + lengthToExamine -= Vector.Count; + while (lengthToExamine > offset) + { + if (LoadVector(ref firstStart, offset) != LoadVector(ref secondStart, offset)) + { + goto CharwiseCheck; + } + offset += Vector.Count; + } + goto CharwiseCheck; } } -#if BIT64 - if ((byte*)minLength >= (byte*)(i + sizeof(int) / sizeof(char))) + if (lengthToExamine > sizeof(UIntPtr) / sizeof(char)) { - if (Unsafe.ReadUnaligned(ref Unsafe.As(ref Unsafe.Add(ref first, i))) == - Unsafe.ReadUnaligned(ref Unsafe.As(ref Unsafe.Add(ref second, i)))) + lengthToExamine -= sizeof(UIntPtr) / sizeof(char); + while (lengthToExamine > offset) { - i += sizeof(int) / sizeof(char); + if (LoadUIntPtr(ref firstStart, offset) != LoadUIntPtr(ref secondStart, offset)) + { + goto CharwiseCheck; + } + offset += sizeof(UIntPtr) / sizeof(char); } } -#endif - while ((byte*)i < (byte*)minLength) + CharwiseCheck: + while (minLength > offset) { - int result = Unsafe.Add(ref first, i).CompareTo(Unsafe.Add(ref second, i)); + int result = Unsafe.Add(ref firstStart, offset).CompareTo(Unsafe.Add(ref secondStart, offset)); if (result != 0) return result; - i += 1; + offset += 1; } Equal: - return lengthDelta; + return firstLength - secondLength; } // Adapted from IndexOf(...) @@ -1033,6 +1120,10 @@ private static int LocateLastFoundChar(ulong match) return 3 - (BitOperations.LeadingZeroCount(match) >> 4); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe UIntPtr LoadUIntPtr(ref char start, nint offset) + => Unsafe.ReadUnaligned(ref Unsafe.As(ref Unsafe.Add(ref start, (IntPtr)offset))); + [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector LoadVector(ref char start, nint offset) => Unsafe.ReadUnaligned>(ref Unsafe.As(ref Unsafe.Add(ref start, (IntPtr)offset)));