From 75094fdef6defe4ac4b977cfcb1ad995ae1c7729 Mon Sep 17 00:00:00 2001 From: Ben Adams Date: Fri, 25 Jan 2019 03:48:43 +0100 Subject: [PATCH] Speedup .SequenceCompareTo(byte, ...) (dotnet/coreclr#22127) * Speedup .SequenceCompareTo(byte, ...) * Rename jump location * Better annotations for clarity Signed-off-by: dotnet-bot --- .../shared/System/SpanHelpers.Byte.cs | 213 +++++++++++++++--- 1 file changed, 184 insertions(+), 29 deletions(-) diff --git a/src/System.Private.CoreLib/shared/System/SpanHelpers.Byte.cs b/src/System.Private.CoreLib/shared/System/SpanHelpers.Byte.cs index 63a564f0de7..3062a405b51 100644 --- a/src/System.Private.CoreLib/shared/System/SpanHelpers.Byte.cs +++ b/src/System.Private.CoreLib/shared/System/SpanHelpers.Byte.cs @@ -276,13 +276,16 @@ public static unsafe int IndexOf(ref byte searchSpace, byte value, int length) { Vector256 search = LoadVector256(ref searchSpace, offset); int matches = Avx2.MoveMask(Avx2.CompareEqual(values, search)); + // Note that MoveMask has converted the equal vector elements into a set of bit flags, + // So the bit position in 'matches' corresponds to the element offset. if (matches == 0) { + // Zero flags set so no matches offset += Vector256.Count; continue; } - // Find offset of first match + // Find bitflag offset of first match and add to current offset return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches); } while ((byte*)nLength > (byte*)offset); } @@ -293,14 +296,16 @@ public static unsafe int IndexOf(ref byte searchSpace, byte value, int length) Vector128 values = Vector128.Create(value); Vector128 search = LoadVector128(ref searchSpace, offset); + // Same method as above int matches = Sse2.MoveMask(Sse2.CompareEqual(values, search)); if (matches == 0) { + // Zero flags set so no matches offset += Vector128.Count; } else { - // Find offset of first match + // Find bitflag offset of first match and add to current offset return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches); } } @@ -323,14 +328,16 @@ public static unsafe int IndexOf(ref byte searchSpace, byte value, int length) { Vector128 search = LoadVector128(ref searchSpace, offset); + // Same method as above int matches = Sse2.MoveMask(Sse2.CompareEqual(values, search)); if (matches == 0) { + // Zero flags set so no matches offset += Vector128.Count; continue; } - // Find offset of first match + // Find bitflag offset of first match and add to current offset return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches); } @@ -358,7 +365,7 @@ public static unsafe int IndexOf(ref byte searchSpace, byte value, int length) continue; } - // Find offset of first match + // Find offset of first match and add to current offset return (int)(byte*)offset + LocateFirstFoundByte(matches); } @@ -499,7 +506,7 @@ public static unsafe int LastIndexOf(ref byte searchSpace, byte value, int lengt continue; } - // Find offset of first match + // Find offset of first match and add to current offset return (int)(offset) - Vector.Count + LocateLastFoundByte(matches); } if ((byte*)offset > (byte*)0) @@ -628,15 +635,19 @@ public static unsafe int IndexOfAny(ref byte searchSpace, byte value0, byte valu do { Vector256 search = LoadVector256(ref searchSpace, offset); + // Note that MoveMask has converted the equal vector elements into a set of bit flags, + // So the bit position in 'matches' corresponds to the element offset. int matches = Avx2.MoveMask(Avx2.CompareEqual(values0, search)); + // Bitwise Or to combine the flagged matches for the second value to our match flags matches |= Avx2.MoveMask(Avx2.CompareEqual(values1, search)); if (matches == 0) { + // Zero flags set so no matches offset += Vector256.Count; continue; } - // Find offset of first match + // Find bitflag offset of first match and add to current offset return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches); } while ((byte*)nLength > (byte*)offset); } @@ -648,15 +659,17 @@ public static unsafe int IndexOfAny(ref byte searchSpace, byte value0, byte valu Vector128 values1 = Vector128.Create(value1); Vector128 search = LoadVector128(ref searchSpace, offset); + // Same method as above int matches = Sse2.MoveMask(Sse2.CompareEqual(values0, search)); matches |= Sse2.MoveMask(Sse2.CompareEqual(values1, search)); if (matches == 0) { + // Zero flags set so no matches offset += Vector128.Count; } else { - // Find offset of first match + // Find bitflag offset of first match and add to current offset return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches); } } @@ -680,15 +693,17 @@ public static unsafe int IndexOfAny(ref byte searchSpace, byte value0, byte valu while ((byte*)nLength > (byte*)offset) { Vector128 search = LoadVector128(ref searchSpace, offset); + // Same method as above int matches = Sse2.MoveMask(Sse2.CompareEqual(values0, search)); matches |= Sse2.MoveMask(Sse2.CompareEqual(values1, search)); if (matches == 0) { + // Zero flags set so no matches offset += Vector128.Count; continue; } - // Find offset of first match + // Find bitflag offset of first match and add to current offset return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches); } @@ -720,7 +735,7 @@ public static unsafe int IndexOfAny(ref byte searchSpace, byte value0, byte valu continue; } - // Find offset of first match + // Find offset of first match and add to current offset return (int)(byte*)offset + LocateFirstFoundByte(matches); } @@ -755,8 +770,8 @@ public static unsafe int IndexOfAny(ref byte searchSpace, byte value0, byte valu Debug.Assert(length >= 0); uint uValue0 = value0; // Use uint for comparisons to avoid unnecessary 8->32 extensions - uint uValue1 = value1; // Use uint for comparisons to avoid unnecessary 8->32 extensions - uint uValue2 = value2; // Use uint for comparisons to avoid unnecessary 8->32 extensions + uint uValue1 = value1; + uint uValue2 = value2; IntPtr offset = (IntPtr)0; // Use IntPtr for arithmetic to avoid unnecessary 64->32->64 truncations IntPtr nLength = (IntPtr)length; @@ -853,16 +868,21 @@ public static unsafe int IndexOfAny(ref byte searchSpace, byte value0, byte valu do { Vector256 search = LoadVector256(ref searchSpace, offset); + // Note that MoveMask has converted the equal vector elements into a set of bit flags, + // So the bit position in 'matches' corresponds to the element offset. int matches = Avx2.MoveMask(Avx2.CompareEqual(values0, search)); + // Bitwise Or to combine the flagged matches for the second value to our match flags matches |= Avx2.MoveMask(Avx2.CompareEqual(values1, search)); + // Bitwise Or to combine the flagged matches for the third value to our match flags matches |= Avx2.MoveMask(Avx2.CompareEqual(values2, search)); if (matches == 0) { + // Zero flags set so no matches offset += Vector256.Count; continue; } - // Find offset of first match + // Find bitflag offset of first match and add to current offset return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches); } while ((byte*)nLength > (byte*)offset); } @@ -875,16 +895,18 @@ public static unsafe int IndexOfAny(ref byte searchSpace, byte value0, byte valu Vector128 values2 = Vector128.Create(value2); Vector128 search = LoadVector128(ref searchSpace, offset); + // Same method as above int matches = Sse2.MoveMask(Sse2.CompareEqual(values0, search)); matches |= Sse2.MoveMask(Sse2.CompareEqual(values1, search)); matches |= Sse2.MoveMask(Sse2.CompareEqual(values2, search)); if (matches == 0) { + // Zero flags set so no matches offset += Vector128.Count; } else { - // Find offset of first match + // Find bitflag offset of first match and add to current offset return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches); } } @@ -909,16 +931,18 @@ public static unsafe int IndexOfAny(ref byte searchSpace, byte value0, byte valu while ((byte*)nLength > (byte*)offset) { Vector128 search = LoadVector128(ref searchSpace, offset); + // Same method as above int matches = Sse2.MoveMask(Sse2.CompareEqual(values0, search)); matches |= Sse2.MoveMask(Sse2.CompareEqual(values1, search)); matches |= Sse2.MoveMask(Sse2.CompareEqual(values2, search)); if (matches == 0) { + // Zero flags set so no matches offset += Vector128.Count; continue; } - // Find offset of first match + // Find bitflag offset of first match and add to current offset return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches); } @@ -955,7 +979,7 @@ public static unsafe int IndexOfAny(ref byte searchSpace, byte value0, byte valu continue; } - // Find offset of first match + // Find offset of first match and add to current offset return (int)(byte*)offset + LocateFirstFoundByte(matches); } @@ -990,7 +1014,7 @@ public static unsafe int LastIndexOfAny(ref byte searchSpace, byte value0, byte Debug.Assert(length >= 0); uint uValue0 = value0; // Use uint for comparisons to avoid unnecessary 8->32 extensions - uint uValue1 = value1; // Use uint for comparisons to avoid unnecessary 8->32 extensions + uint uValue1 = value1; IntPtr offset = (IntPtr)length; // Use IntPtr for arithmetic to avoid unnecessary 64->32->64 truncations IntPtr nLength = (IntPtr)length; @@ -1080,7 +1104,7 @@ public static unsafe int LastIndexOfAny(ref byte searchSpace, byte value0, byte continue; } - // Find offset of first match + // Find offset of first match and add to current offset return (int)(offset) - Vector.Count + LocateLastFoundByte(matches); } @@ -1114,8 +1138,8 @@ public static unsafe int LastIndexOfAny(ref byte searchSpace, byte value0, byte Debug.Assert(length >= 0); uint uValue0 = value0; // Use uint for comparisons to avoid unnecessary 8->32 extensions - uint uValue1 = value1; // Use uint for comparisons to avoid unnecessary 8->32 extensions - uint uValue2 = value2; // Use uint for comparisons to avoid unnecessary 8->32 extensions + uint uValue1 = value1; + uint uValue2 = value2; IntPtr offset = (IntPtr)length; // Use IntPtr for arithmetic to avoid unnecessary 64->32->64 truncations IntPtr nLength = (IntPtr)length; @@ -1210,7 +1234,7 @@ public static unsafe int LastIndexOfAny(ref byte searchSpace, byte value0, byte continue; } - // Find offset of first match + // Find offset of first match and add to current offset return (int)(offset) - Vector.Count + LocateLastFoundByte(matches); } @@ -1324,18 +1348,149 @@ public static unsafe int SequenceCompareTo(ref byte first, int firstLength, ref IntPtr offset = (IntPtr)0; // Use IntPtr for arithmetic to avoid unnecessary 64->32->64 truncations IntPtr nLength = (IntPtr)(void*)minLength; - if (Vector.IsHardwareAccelerated && (byte*)nLength > (byte*)Vector.Count) + if (Avx2.IsSupported) { - nLength -= Vector.Count; - while ((byte*)nLength > (byte*)offset) + if ((byte*)nLength >= (byte*)Vector256.Count) { - if (LoadVector(ref first, offset) != LoadVector(ref second, offset)) + nLength -= Vector256.Count; + uint matches; + while ((byte*)nLength > (byte*)offset) { - goto NotEqual; + matches = (uint)Avx2.MoveMask(Avx2.CompareEqual(LoadVector256(ref first, offset), LoadVector256(ref second, offset))); + // Note that MoveMask has converted the equal vector elements into a set of bit flags, + // So the bit position in 'matches' corresponds to the element offset. + + // 32 elements in Vector256 so we compare to uint.MaxValue to check if everything matched + if (matches == uint.MaxValue) + { + // All matched + offset += Vector256.Count; + continue; + } + + goto Difference; } - offset += Vector.Count; + // Move to Vector length from end for final compare + offset = nLength; + // Same as method as above + matches = (uint)Avx2.MoveMask(Avx2.CompareEqual(LoadVector256(ref first, offset), LoadVector256(ref second, offset))); + if (matches == uint.MaxValue) + { + // All matched + goto Equal; + } + Difference: + // Invert matches to find differences + uint differences = ~matches; + // Find bitflag offset of first difference and add to current offset + offset = (IntPtr)((int)(byte*)offset + BitOps.TrailingZeroCount((int)differences)); + + int result = Unsafe.AddByteOffset(ref first, offset).CompareTo(Unsafe.AddByteOffset(ref second, offset)); + Debug.Assert(result != 0); + + return result; + } + + if ((byte*)nLength >= (byte*)Vector128.Count) + { + nLength -= Vector128.Count; + uint matches; + if ((byte*)nLength > (byte*)offset) + { + matches = (uint)Sse2.MoveMask(Sse2.CompareEqual(LoadVector128(ref first, offset), LoadVector128(ref second, offset))); + // Note that MoveMask has converted the equal vector elements into a set of bit flags, + // So the bit position in 'matches' corresponds to the element offset. + + // 16 elements in Vector128 so we compare to ushort.MaxValue to check if everything matched + if (matches == ushort.MaxValue) + { + // All matched + offset += Vector128.Count; + } + else + { + goto Difference; + } + } + // Move to Vector length from end for final compare + offset = nLength; + // Same as method as above + matches = (uint)Sse2.MoveMask(Sse2.CompareEqual(LoadVector128(ref first, offset), LoadVector128(ref second, offset))); + if (matches == ushort.MaxValue) + { + // All matched + goto Equal; + } + Difference: + // Invert matches to find differences + uint differences = ~matches; + // Find bitflag offset of first difference and add to current offset + offset = (IntPtr)((int)(byte*)offset + BitOps.TrailingZeroCount((int)differences)); + + int result = Unsafe.AddByteOffset(ref first, offset).CompareTo(Unsafe.AddByteOffset(ref second, offset)); + Debug.Assert(result != 0); + + return result; + } + } + else if (Sse2.IsSupported) + { + if ((byte*)nLength >= (byte*)Vector128.Count) + { + nLength -= Vector128.Count; + uint matches; + while ((byte*)nLength > (byte*)offset) + { + matches = (uint)Sse2.MoveMask(Sse2.CompareEqual(LoadVector128(ref first, offset), LoadVector128(ref second, offset))); + // Note that MoveMask has converted the equal vector elements into a set of bit flags, + // So the bit position in 'matches' corresponds to the element offset. + + // 16 elements in Vector128 so we compare to ushort.MaxValue to check if everything matched + if (matches == ushort.MaxValue) + { + // All matched + offset += Vector128.Count; + continue; + } + + goto Difference; + } + // Move to Vector length from end for final compare + offset = nLength; + // Same as method as above + matches = (uint)Sse2.MoveMask(Sse2.CompareEqual(LoadVector128(ref first, offset), LoadVector128(ref second, offset))); + if (matches == ushort.MaxValue) + { + // All matched + goto Equal; + } + Difference: + // Invert matches to find differences + uint differences = ~matches; + // Find bitflag offset of first difference and add to current offset + offset = (IntPtr)((int)(byte*)offset + BitOps.TrailingZeroCount((int)differences)); + + int result = Unsafe.AddByteOffset(ref first, offset).CompareTo(Unsafe.AddByteOffset(ref second, offset)); + Debug.Assert(result != 0); + + return result; + } + } + else if (Vector.IsHardwareAccelerated) + { + if ((byte*)nLength > (byte*)Vector.Count) + { + nLength -= Vector.Count; + while ((byte*)nLength > (byte*)offset) + { + if (LoadVector(ref first, offset) != LoadVector(ref second, offset)) + { + goto BytewiseCheck; + } + offset += Vector.Count; + } + goto BytewiseCheck; } - goto NotEqual; } if ((byte*)nLength > (byte*)sizeof(UIntPtr)) @@ -1345,13 +1500,13 @@ public static unsafe int SequenceCompareTo(ref byte first, int firstLength, ref { if (LoadUIntPtr(ref first, offset) != LoadUIntPtr(ref second, offset)) { - goto NotEqual; + goto BytewiseCheck; } offset += sizeof(UIntPtr); } } - NotEqual: // Workaround for https://github.com/dotnet/coreclr/issues/13549 + BytewiseCheck: // Workaround for https://github.com/dotnet/coreclr/issues/13549 while ((byte*)minLength > (byte*)offset) { int result = Unsafe.AddByteOffset(ref first, offset).CompareTo(Unsafe.AddByteOffset(ref second, offset));