Begin using the xplat hardware intrinsics in BitArray #63722

Merged
merged 2 commits into from
Feb 2, 2022
Merged
254 changes: 67 additions & 187 deletions src/libraries/System.Collections/src/System/Collections/BitArray.cs
@@ -4,9 +4,11 @@
using System.Buffers.Binary;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using System.Runtime.Intrinsics.Arm;
using Internal.Runtime.CompilerServices;

namespace System.Collections
{
@@ -145,81 +147,32 @@ public unsafe BitArray(bool[] values)
// (true for any non-zero values, false for 0) - any values between 2-255 will be interpreted as false.
// Instead, we compare with zeroes (== false) then negate the result to ensure compatibility.

if (Avx2.IsSupported)
ref byte value = ref Unsafe.As<bool, byte>(ref MemoryMarshal.GetArrayDataReference<bool>(values));

if (Vector256.IsHardwareAccelerated)
Member Author:

I've preserved the Vector256 path given that it was already here and I presume it has undergone the necessary checks to ensure it is worth doing on x86/x64.

Arm64 doesn't support V256 and so will only go down the V128 codepath.
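As a quick illustration (not part of this PR), the acceleration flags can be probed directly; exact output depends on the machine, but an AVX2-capable x64 box typically prints True for both, while Arm64 with AdvSimd prints True only for Vector128:

using System;
using System.Runtime.Intrinsics;

// Illustrative probe of the xplat acceleration flags; output varies by hardware.
Console.WriteLine($"Vector256 accelerated: {Vector256.IsHardwareAccelerated}");
Console.WriteLine($"Vector128 accelerated: {Vector128.IsHardwareAccelerated}");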

{
// JIT does not support code hoisting for SIMD yet
Vector256<byte> zero = Vector256<byte>.Zero;
fixed (bool* ptr = values)
Member Author:
We now expose helper intrinsics that directly operate on ref: LoadUnsafe(ref T source, nuint elementOffset).

This helps avoid pinning, which can have measurable overhead for small counts and which can hinder the GC in the case of long inputs.

It likewise improves readability compared to the pattern we already use in parts of the BCL, where we combine Unsafe.ReadUnaligned + Unsafe.Add + Unsafe.As.
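To illustrate the readability point, here is a rough standalone sketch (not code from this PR; the class and method names are mine) showing the same 16-byte load written three ways:

using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

static class LoadPatterns
{
    // Old pattern: pin the array and load through a pointer (x86-specific here).
    public static unsafe Vector128<byte> LoadPinned(byte[] data, int i)
    {
        fixed (byte* ptr = data)
        {
            return Sse2.LoadVector128(ptr + i);
        }
    }

    // Existing BCL workaround: ref arithmetic plus an unaligned read, no pinning.
    public static Vector128<byte> LoadViaUnsafe(byte[] data, int i) =>
        Unsafe.ReadUnaligned<Vector128<byte>>(
            ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(data), i));

    // New xplat helper: operates directly on a ref plus an element offset.
    public static Vector128<byte> LoadViaHelper(byte[] data, int i) =>
        Vector128.LoadUnsafe(ref MemoryMarshal.GetArrayDataReference(data), (nuint)i);
}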

for (; (i + Vector256ByteCount) <= (uint)values.Length; i += Vector256ByteCount)
{
for (; (i + Vector256ByteCount) <= (uint)values.Length; i += Vector256ByteCount)
{
Vector256<byte> vector = Avx.LoadVector256((byte*)ptr + i);
Vector256<byte> isFalse = Avx2.CompareEqual(vector, zero);
int result = Avx2.MoveMask(isFalse);
m_array[i / 32u] = ~result;
}
Vector256<byte> vector = Vector256.LoadUnsafe(ref value, i);
Vector256<byte> isFalse = Vector256.Equals(vector, Vector256<byte>.Zero);

uint result = isFalse.ExtractMostSignificantBits();
Member Author:
ExtractMostSignificantBits behaves just like MoveMask on x86/x64. WASM exposes the same operation as bitmask.
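For intuition, a small standalone sketch (not from the PR) of what the compare-plus-extract sequence produces for byte lanes:

using System;
using System.Runtime.Intrinsics;

// Lanes 0 and 3 are zero ("false"); every other lane is non-zero ("true").
Vector128<byte> values = Vector128.Create((byte)0, 1, 2, 0, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);

Vector128<byte> isFalse = Vector128.Equals(values, Vector128<byte>.Zero); // 0xFF per "false" lane
uint mask = isFalse.ExtractMostSignificantBits();                         // one bit per byte lane

Console.WriteLine(Convert.ToString(mask, 2));           // 1001 -> bits 0 and 3 set (the false lanes)
Console.WriteLine(Convert.ToString(~mask & 0xFFFF, 2)); // 1111111111110110 -> set bits mark the true lanes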

m_array[i / 32u] = (int)(~result);
}
}
else if (Sse2.IsSupported)
else if (Vector128.IsHardwareAccelerated)
Member asked:
In what situation would we also want a Vector64 code path?

Member Author replied:

Vector64 can be beneficial for cases where you know the inputs are going to be commonly small and for handling the "trailing" elements (rather than falling back to a scalar loop or manually unrolled loop).

We aren't currently taking advantage of this anywhere and it would need some more work/profiling to show the extra complexity is worthwhile.

  • The extra complexity isn't from using Vector64<T> itself, but rather from changing the fallback from for (; index < length; index++) to Vector64<T> or Vector128<T> with appropriate backtracking and masking (a rough sketch of the backtracking idea follows below).
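To make the backtracking idea concrete, here is a rough standalone sketch (not part of this PR; the class, method, and names are hypothetical) of an element-wise AND over int spans where the tail is handled by reprocessing the last full vector instead of running a scalar loop. This is only valid because re-ANDing already-processed elements is idempotent.

using System;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;

static class VectorTail
{
    // Assumes left.Length == right.Length.
    public static void And(Span<int> left, ReadOnlySpan<int> right)
    {
        nuint count = (nuint)left.Length;
        nuint step = (nuint)Vector128<int>.Count;

        if (!Vector128.IsHardwareAccelerated || count < step)
        {
            for (int j = 0; j < left.Length; j++) left[j] &= right[j];
            return;
        }

        ref int l = ref MemoryMarshal.GetReference(left);
        ref int r = ref MemoryMarshal.GetReference(right);

        nuint i = 0;
        for (; i <= count - step; i += step)
        {
            (Vector128.LoadUnsafe(ref l, i) & Vector128.LoadUnsafe(ref r, i)).StoreUnsafe(ref l, i);
        }

        if (i != count)
        {
            // Backtrack: redo the final full vector, overlapping elements already handled above.
            nuint last = count - step;
            (Vector128.LoadUnsafe(ref l, last) & Vector128.LoadUnsafe(ref r, last)).StoreUnsafe(ref l, last);
        }
    }
}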

{
// JIT does not support code hoisting for SIMD yet
Vector128<byte> zero = Vector128<byte>.Zero;
fixed (bool* ptr = values)
for (; (i + Vector128ByteCount * 2u) <= (uint)values.Length; i += Vector128ByteCount * 2u)
{
for (; (i + Vector128ByteCount * 2u) <= (uint)values.Length; i += Vector128ByteCount * 2u)
{
Vector128<byte> lowerVector = Sse2.LoadVector128((byte*)ptr + i);
Vector128<byte> lowerIsFalse = Sse2.CompareEqual(lowerVector, zero);
int lowerPackedIsFalse = Sse2.MoveMask(lowerIsFalse);
Vector128<byte> lowerVector = Vector128.LoadUnsafe(ref value, i);
Vector128<byte> lowerIsFalse = Vector128.Equals(lowerVector, Vector128<byte>.Zero);
uint lowerResult = lowerIsFalse.ExtractMostSignificantBits();

Vector128<byte> upperVector = Sse2.LoadVector128((byte*)ptr + i + Vector128<byte>.Count);
Vector128<byte> upperIsFalse = Sse2.CompareEqual(upperVector, zero);
int upperPackedIsFalse = Sse2.MoveMask(upperIsFalse);
Vector128<byte> upperVector = Vector128.LoadUnsafe(ref value, i + Vector128ByteCount);
Vector128<byte> upperIsFalse = Vector128.Equals(upperVector, Vector128<byte>.Zero);
uint upperResult = upperIsFalse.ExtractMostSignificantBits();

m_array[i / 32u] = ~((upperPackedIsFalse << 16) | lowerPackedIsFalse);
}
}
}
else if (AdvSimd.Arm64.IsSupported)
{
// JIT does not support code hoisting for SIMD yet
// However comparison against zero can be replaced to cmeq against zero (vceqzq_s8)
// See dotnet/runtime#33972 for details
Vector128<byte> zero = Vector128<byte>.Zero;
Vector128<byte> bitMask128 = BitConverter.IsLittleEndian ?
Vector128.Create(0x80402010_08040201).AsByte() :
Vector128.Create(0x01020408_10204080).AsByte();

fixed (bool* ptr = values)
{
for (; (i + Vector128ByteCount * 2u) <= (uint)values.Length; i += Vector128ByteCount * 2u)
{
// Same logic as SSE2 path, however we lack MoveMask (equivalent) instruction
// As a workaround, mask out the relevant bit after comparison
// and combine by ORing all of them together (In this case, adding all of them does the same thing)
Vector128<byte> lowerVector = AdvSimd.LoadVector128((byte*)ptr + i);
Vector128<byte> lowerIsFalse = AdvSimd.CompareEqual(lowerVector, zero);
Vector128<byte> bitsExtracted1 = AdvSimd.And(lowerIsFalse, bitMask128);
bitsExtracted1 = AdvSimd.Arm64.AddPairwise(bitsExtracted1, bitsExtracted1);
bitsExtracted1 = AdvSimd.Arm64.AddPairwise(bitsExtracted1, bitsExtracted1);
bitsExtracted1 = AdvSimd.Arm64.AddPairwise(bitsExtracted1, bitsExtracted1);
Vector128<short> lowerPackedIsFalse = bitsExtracted1.AsInt16();

Vector128<byte> upperVector = AdvSimd.LoadVector128((byte*)ptr + i + Vector128<byte>.Count);
Vector128<byte> upperIsFalse = AdvSimd.CompareEqual(upperVector, zero);
Vector128<byte> bitsExtracted2 = AdvSimd.And(upperIsFalse, bitMask128);
bitsExtracted2 = AdvSimd.Arm64.AddPairwise(bitsExtracted2, bitsExtracted2);
bitsExtracted2 = AdvSimd.Arm64.AddPairwise(bitsExtracted2, bitsExtracted2);
bitsExtracted2 = AdvSimd.Arm64.AddPairwise(bitsExtracted2, bitsExtracted2);
Vector128<short> upperPackedIsFalse = bitsExtracted2.AsInt16();

int result = AdvSimd.Arm64.ZipLow(lowerPackedIsFalse, upperPackedIsFalse).AsInt32().ToScalar();
if (!BitConverter.IsLittleEndian)
{
result = BinaryPrimitives.ReverseEndianness(result);
}
m_array[i / 32u] = ~result;
}
m_array[i / 32u] = (int)(~((upperResult << 16) | lowerResult));
}
}

@@ -400,43 +353,24 @@ public unsafe BitArray And(BitArray value)
}

uint i = 0;
if (Avx2.IsSupported)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
{
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
Vector256<int> leftVec = Avx.LoadVector256(leftPtr + i);
Vector256<int> rightVec = Avx.LoadVector256(rightPtr + i);
Avx.Store(leftPtr + i, Avx2.And(leftVec, rightVec));
}
}
}
else if (Sse2.IsSupported)

ref int left = ref MemoryMarshal.GetArrayDataReference<int>(thisArray);
ref int right = ref MemoryMarshal.GetArrayDataReference<int>(valueArray);

if (Vector256.IsHardwareAccelerated)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> leftVec = Sse2.LoadVector128(leftPtr + i);
Vector128<int> rightVec = Sse2.LoadVector128(rightPtr + i);
Sse2.Store(leftPtr + i, Sse2.And(leftVec, rightVec));
}
Vector256<int> result = Vector256.LoadUnsafe(ref left, i) & Vector256.LoadUnsafe(ref right, i);
Member Author:
The xplat helper intrinsics support operators and so we can make this "more readable" by just using x & y.
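For example (illustrative only, not from the diff), the two forms below produce the same vector:

using System;
using System.Runtime.Intrinsics;

Vector128<int> a = Vector128.Create(1), b = Vector128.Create(3);

Vector128<int> viaMethod = Vector128.BitwiseAnd(a, b); // explicit helper call
Vector128<int> viaOperator = a & b;                    // same operation via the & operator

Console.WriteLine(viaMethod == viaOperator); // True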

result.StoreUnsafe(ref left, i);
Member Author:
Storing a vector likewise no longer requires pinning or complex Unsafe logic.
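A minimal before/after sketch of the store side (illustrative only; the class and method names are mine):

using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

static class StorePatterns
{
    // Old pattern: pin the destination and store through a pointer.
    public static unsafe void StorePinned(int[] dest, int i, Vector128<int> vec)
    {
        fixed (int* ptr = dest)
        {
            Sse2.Store(ptr + i, vec);
        }
    }

    // New xplat helper: store through a ref plus an element offset, no pinning.
    public static void StoreViaHelper(int[] dest, int i, Vector128<int> vec) =>
        vec.StoreUnsafe(ref MemoryMarshal.GetArrayDataReference(dest), (nuint)i);
}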

}
}
else if (AdvSimd.IsSupported)
else if (Vector128.IsHardwareAccelerated)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> leftVec = AdvSimd.LoadVector128(leftPtr + i);
Vector128<int> rightVec = AdvSimd.LoadVector128(rightPtr + i);
AdvSimd.Store(leftPtr + i, AdvSimd.And(leftVec, rightVec));
}
Vector128<int> result = Vector128.LoadUnsafe(ref left, i) & Vector128.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}

@@ -486,43 +420,24 @@ public unsafe BitArray Or(BitArray value)
}

uint i = 0;
if (Avx2.IsSupported)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
{
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
Vector256<int> leftVec = Avx.LoadVector256(leftPtr + i);
Vector256<int> rightVec = Avx.LoadVector256(rightPtr + i);
Avx.Store(leftPtr + i, Avx2.Or(leftVec, rightVec));
}
}
}
else if (Sse2.IsSupported)

ref int left = ref MemoryMarshal.GetArrayDataReference<int>(thisArray);
ref int right = ref MemoryMarshal.GetArrayDataReference<int>(valueArray);

if (Vector256.IsHardwareAccelerated)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> leftVec = Sse2.LoadVector128(leftPtr + i);
Vector128<int> rightVec = Sse2.LoadVector128(rightPtr + i);
Sse2.Store(leftPtr + i, Sse2.Or(leftVec, rightVec));
}
Vector256<int> result = Vector256.LoadUnsafe(ref left, i) | Vector256.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}
else if (AdvSimd.IsSupported)
else if (Vector128.IsHardwareAccelerated)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> leftVec = AdvSimd.LoadVector128(leftPtr + i);
Vector128<int> rightVec = AdvSimd.LoadVector128(rightPtr + i);
AdvSimd.Store(leftPtr + i, AdvSimd.Or(leftVec, rightVec));
}
Vector128<int> result = Vector128.LoadUnsafe(ref left, i) | Vector128.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}

@@ -572,43 +487,24 @@ public unsafe BitArray Xor(BitArray value)
}

uint i = 0;
if (Avx2.IsSupported)

ref int left = ref MemoryMarshal.GetArrayDataReference<int>(thisArray);
ref int right = ref MemoryMarshal.GetArrayDataReference<int>(valueArray);

if (Vector256.IsHardwareAccelerated)
{
fixed (int* leftPtr = m_array)
fixed (int* rightPtr = value.m_array)
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
Vector256<int> leftVec = Avx.LoadVector256(leftPtr + i);
Vector256<int> rightVec = Avx.LoadVector256(rightPtr + i);
Avx.Store(leftPtr + i, Avx2.Xor(leftVec, rightVec));
}
Vector256<int> result = Vector256.LoadUnsafe(ref left, i) ^ Vector256.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}
else if (Sse2.IsSupported)
else if (Vector128.IsHardwareAccelerated)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> leftVec = Sse2.LoadVector128(leftPtr + i);
Vector128<int> rightVec = Sse2.LoadVector128(rightPtr + i);
Sse2.Store(leftPtr + i, Sse2.Xor(leftVec, rightVec));
}
}
}
else if (AdvSimd.IsSupported)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> leftVec = AdvSimd.LoadVector128(leftPtr + i);
Vector128<int> rightVec = AdvSimd.LoadVector128(rightPtr + i);
AdvSimd.Store(leftPtr + i, AdvSimd.Xor(leftVec, rightVec));
}
Vector128<int> result = Vector128.LoadUnsafe(ref left, i) ^ Vector128.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}

@@ -650,39 +546,23 @@ public unsafe BitArray Not()
}

uint i = 0;
if (Avx2.IsSupported)
{
Vector256<int> ones = Vector256.Create(-1);
fixed (int* ptr = thisArray)
{
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
Vector256<int> vec = Avx.LoadVector256(ptr + i);
Avx.Store(ptr + i, Avx2.Xor(vec, ones));
}
}
}
else if (Sse2.IsSupported)

ref int value = ref MemoryMarshal.GetArrayDataReference<int>(thisArray);

if (Vector256.IsHardwareAccelerated)
{
Vector128<int> ones = Vector128.Create(-1);
fixed (int* ptr = thisArray)
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> vec = Sse2.LoadVector128(ptr + i);
Sse2.Store(ptr + i, Sse2.Xor(vec, ones));
}
Vector256<int> result = ~Vector256.LoadUnsafe(ref value, i);
result.StoreUnsafe(ref value, i);
}
}
else if (AdvSimd.IsSupported)
else if (Vector128.IsHardwareAccelerated)
{
fixed (int* leftPtr = thisArray)
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> leftVec = AdvSimd.LoadVector128(leftPtr + i);
AdvSimd.Store(leftPtr + i, AdvSimd.Not(leftVec));
}
Vector128<int> result = ~Vector128.LoadUnsafe(ref value, i);
result.StoreUnsafe(ref value, i);
}
}
