diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs index 93d15dec1ae2..7db0c81e9602 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs @@ -3320,37 +3320,61 @@ internal static int LastIndexOfAnyExceptInRangeUnsignedNumber(ref T searchSpa public static int CountValueType(ref T current, T value, int length) where T : struct, IEquatable? { int count = 0; - ref T end = ref Unsafe.Add(ref current, length); + if (Vector128.IsHardwareAccelerated && length >= Vector128.Count) { if (Vector256.IsHardwareAccelerated && length >= Vector256.Count) { Vector256 targetVector = Vector256.Create(value); - ref T oneVectorAwayFromEndMinus1 = ref Unsafe.Subtract(ref end, Vector256.Count - 1); + ref T oneVectorAwayFromEnd = ref Unsafe.Subtract(ref end, Vector256.Count); do { count += BitOperations.PopCount(Vector256.Equals(Vector256.LoadUnsafe(ref current), targetVector).ExtractMostSignificantBits()); current = ref Unsafe.Add(ref current, Vector256.Count); } - while (Unsafe.IsAddressLessThan(ref current, ref oneVectorAwayFromEndMinus1)); + while (!Unsafe.IsAddressGreaterThan(ref current, ref oneVectorAwayFromEnd)); - if (Unsafe.IsAddressLessThan(ref current, ref Unsafe.Subtract(ref end, Vector128.Count - 1))) + // If there are just a few elements remaining, then processing these elements by the scalar loop + // is cheaper than doing bitmask + popcount on the full last vector. To avoid complicated type + // based checks, other remainder-count based logic to determine the correct cut-off, for simplicity + // a half-vector size is chosen (based on benchmarks). + uint remaining = (uint)Unsafe.ByteOffset(ref current, ref end) / (uint)Unsafe.SizeOf(); + if (remaining > Vector256.Count / 2) { - count += BitOperations.PopCount(Vector128.Equals(Vector128.LoadUnsafe(ref current), Vector128.Create(value)).ExtractMostSignificantBits()); - current = ref Unsafe.Add(ref current, Vector128.Count); + uint mask = Vector256.Equals(Vector256.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits(); + + // The mask contains some elements that may be double-checked, so shift them away in order to get the correct pop-count. + uint overlaps = (uint)Vector256.Count - remaining; + mask >>= (int)overlaps; + count += BitOperations.PopCount(mask); + + return count; } } else { Vector128 targetVector = Vector128.Create(value); - ref T oneVectorAwayFromEndMinus1 = ref Unsafe.Subtract(ref end, Vector128.Count - 1); + ref T oneVectorAwayFromEnd = ref Unsafe.Subtract(ref end, Vector128.Count); do { count += BitOperations.PopCount(Vector128.Equals(Vector128.LoadUnsafe(ref current), targetVector).ExtractMostSignificantBits()); current = ref Unsafe.Add(ref current, Vector128.Count); } - while (Unsafe.IsAddressLessThan(ref current, ref oneVectorAwayFromEndMinus1)); + while (!Unsafe.IsAddressGreaterThan(ref current, ref oneVectorAwayFromEnd)); + + uint remaining = (uint)Unsafe.ByteOffset(ref current, ref end) / (uint)Unsafe.SizeOf(); + if (remaining > Vector128.Count / 2) + { + uint mask = Vector128.Equals(Vector128.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits(); + + // The mask contains some elements that may be double-checked, so shift them away in order to get the correct pop-count. + uint overlaps = (uint)Vector128.Count - remaining; + mask >>= (int)overlaps; + count += BitOperations.PopCount(mask); + + return count; + } } }