Skip to content

Commit

Permalink
Begin using the xplat hardware intrinsics in BitArray (#63722)
Browse files Browse the repository at this point in the history
* Change the BitArray(bool[]) constructor to use the xplat intrinsics

* Change the And, Or, Xor, and Not methods to use the xplat intrinsics
  • Loading branch information
tannergooding committed Feb 2, 2022
1 parent 3a77a6d commit 358e28a
Showing 1 changed file with 67 additions and 187 deletions.
254 changes: 67 additions & 187 deletions src/libraries/System.Collections/src/System/Collections/BitArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
using System.Buffers.Binary;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using System.Runtime.Intrinsics.Arm;
using Internal.Runtime.CompilerServices;

namespace System.Collections
{
Expand Down Expand Up @@ -145,81 +147,32 @@ public unsafe BitArray(bool[] values)
// (true for any non-zero values, false for 0) - any values between 2-255 will be interpreted as false.
// Instead, We compare with zeroes (== false) then negate the result to ensure compatibility.

if (Avx2.IsSupported)
ref byte value = ref Unsafe.As<bool, byte>(ref MemoryMarshal.GetArrayDataReference<bool>(values));

if (Vector256.IsHardwareAccelerated)
{
// JIT does not support code hoisting for SIMD yet
Vector256<byte> zero = Vector256<byte>.Zero;
fixed (bool* ptr = values)
for (; (i + Vector256ByteCount) <= (uint)values.Length; i += Vector256ByteCount)
{
for (; (i + Vector256ByteCount) <= (uint)values.Length; i += Vector256ByteCount)
{
Vector256<byte> vector = Avx.LoadVector256((byte*)ptr + i);
Vector256<byte> isFalse = Avx2.CompareEqual(vector, zero);
int result = Avx2.MoveMask(isFalse);
m_array[i / 32u] = ~result;
}
Vector256<byte> vector = Vector256.LoadUnsafe(ref value, i);
Vector256<byte> isFalse = Vector256.Equals(vector, Vector256<byte>.Zero);

uint result = isFalse.ExtractMostSignificantBits();
m_array[i / 32u] = (int)(~result);
}
}
else if (Sse2.IsSupported)
else if (Vector128.IsHardwareAccelerated)
{
// JIT does not support code hoisting for SIMD yet
Vector128<byte> zero = Vector128<byte>.Zero;
fixed (bool* ptr = values)
for (; (i + Vector128ByteCount * 2u) <= (uint)values.Length; i += Vector128ByteCount * 2u)
{
for (; (i + Vector128ByteCount * 2u) <= (uint)values.Length; i += Vector128ByteCount * 2u)
{
Vector128<byte> lowerVector = Sse2.LoadVector128((byte*)ptr + i);
Vector128<byte> lowerIsFalse = Sse2.CompareEqual(lowerVector, zero);
int lowerPackedIsFalse = Sse2.MoveMask(lowerIsFalse);
Vector128<byte> lowerVector = Vector128.LoadUnsafe(ref value, i);
Vector128<byte> lowerIsFalse = Vector128.Equals(lowerVector, Vector128<byte>.Zero);
uint lowerResult = lowerIsFalse.ExtractMostSignificantBits();

Vector128<byte> upperVector = Sse2.LoadVector128((byte*)ptr + i + Vector128<byte>.Count);
Vector128<byte> upperIsFalse = Sse2.CompareEqual(upperVector, zero);
int upperPackedIsFalse = Sse2.MoveMask(upperIsFalse);
Vector128<byte> upperVector = Vector128.LoadUnsafe(ref value, i + Vector128ByteCount);
Vector128<byte> upperIsFalse = Vector128.Equals(upperVector, Vector128<byte>.Zero);
uint upperResult = upperIsFalse.ExtractMostSignificantBits();

m_array[i / 32u] = ~((upperPackedIsFalse << 16) | lowerPackedIsFalse);
}
}
}
else if (AdvSimd.Arm64.IsSupported)
{
// JIT does not support code hoisting for SIMD yet
// However comparison against zero can be replaced to cmeq against zero (vceqzq_s8)
// See dotnet/runtime#33972 for details
Vector128<byte> zero = Vector128<byte>.Zero;
Vector128<byte> bitMask128 = BitConverter.IsLittleEndian ?
Vector128.Create(0x80402010_08040201).AsByte() :
Vector128.Create(0x01020408_10204080).AsByte();

fixed (bool* ptr = values)
{
for (; (i + Vector128ByteCount * 2u) <= (uint)values.Length; i += Vector128ByteCount * 2u)
{
// Same logic as SSE2 path, however we lack MoveMask (equivalent) instruction
// As a workaround, mask out the relevant bit after comparison
// and combine by ORing all of them together (In this case, adding all of them does the same thing)
Vector128<byte> lowerVector = AdvSimd.LoadVector128((byte*)ptr + i);
Vector128<byte> lowerIsFalse = AdvSimd.CompareEqual(lowerVector, zero);
Vector128<byte> bitsExtracted1 = AdvSimd.And(lowerIsFalse, bitMask128);
bitsExtracted1 = AdvSimd.Arm64.AddPairwise(bitsExtracted1, bitsExtracted1);
bitsExtracted1 = AdvSimd.Arm64.AddPairwise(bitsExtracted1, bitsExtracted1);
bitsExtracted1 = AdvSimd.Arm64.AddPairwise(bitsExtracted1, bitsExtracted1);
Vector128<short> lowerPackedIsFalse = bitsExtracted1.AsInt16();

Vector128<byte> upperVector = AdvSimd.LoadVector128((byte*)ptr + i + Vector128<byte>.Count);
Vector128<byte> upperIsFalse = AdvSimd.CompareEqual(upperVector, zero);
Vector128<byte> bitsExtracted2 = AdvSimd.And(upperIsFalse, bitMask128);
bitsExtracted2 = AdvSimd.Arm64.AddPairwise(bitsExtracted2, bitsExtracted2);
bitsExtracted2 = AdvSimd.Arm64.AddPairwise(bitsExtracted2, bitsExtracted2);
bitsExtracted2 = AdvSimd.Arm64.AddPairwise(bitsExtracted2, bitsExtracted2);
Vector128<short> upperPackedIsFalse = bitsExtracted2.AsInt16();

int result = AdvSimd.Arm64.ZipLow(lowerPackedIsFalse, upperPackedIsFalse).AsInt32().ToScalar();
if (!BitConverter.IsLittleEndian)
{
result = BinaryPrimitives.ReverseEndianness(result);
}
m_array[i / 32u] = ~result;
}
m_array[i / 32u] = (int)(~((upperResult << 16) | lowerResult));
}
}

Expand Down Expand Up @@ -400,43 +353,24 @@ public unsafe BitArray And(BitArray value)
}

uint i = 0;
if (Avx2.IsSupported)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
{
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
Vector256<int> leftVec = Avx.LoadVector256(leftPtr + i);
Vector256<int> rightVec = Avx.LoadVector256(rightPtr + i);
Avx.Store(leftPtr + i, Avx2.And(leftVec, rightVec));
}
}
}
else if (Sse2.IsSupported)

ref int left = ref MemoryMarshal.GetArrayDataReference<int>(thisArray);
ref int right = ref MemoryMarshal.GetArrayDataReference<int>(valueArray);

if (Vector256.IsHardwareAccelerated)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> leftVec = Sse2.LoadVector128(leftPtr + i);
Vector128<int> rightVec = Sse2.LoadVector128(rightPtr + i);
Sse2.Store(leftPtr + i, Sse2.And(leftVec, rightVec));
}
Vector256<int> result = Vector256.LoadUnsafe(ref left, i) & Vector256.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}
else if (AdvSimd.IsSupported)
else if (Vector128.IsHardwareAccelerated)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> leftVec = AdvSimd.LoadVector128(leftPtr + i);
Vector128<int> rightVec = AdvSimd.LoadVector128(rightPtr + i);
AdvSimd.Store(leftPtr + i, AdvSimd.And(leftVec, rightVec));
}
Vector128<int> result = Vector128.LoadUnsafe(ref left, i) & Vector128.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}

Expand Down Expand Up @@ -486,43 +420,24 @@ public unsafe BitArray Or(BitArray value)
}

uint i = 0;
if (Avx2.IsSupported)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
{
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
Vector256<int> leftVec = Avx.LoadVector256(leftPtr + i);
Vector256<int> rightVec = Avx.LoadVector256(rightPtr + i);
Avx.Store(leftPtr + i, Avx2.Or(leftVec, rightVec));
}
}
}
else if (Sse2.IsSupported)

ref int left = ref MemoryMarshal.GetArrayDataReference<int>(thisArray);
ref int right = ref MemoryMarshal.GetArrayDataReference<int>(valueArray);

if (Vector256.IsHardwareAccelerated)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> leftVec = Sse2.LoadVector128(leftPtr + i);
Vector128<int> rightVec = Sse2.LoadVector128(rightPtr + i);
Sse2.Store(leftPtr + i, Sse2.Or(leftVec, rightVec));
}
Vector256<int> result = Vector256.LoadUnsafe(ref left, i) | Vector256.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}
else if (AdvSimd.IsSupported)
else if (Vector128.IsHardwareAccelerated)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> leftVec = AdvSimd.LoadVector128(leftPtr + i);
Vector128<int> rightVec = AdvSimd.LoadVector128(rightPtr + i);
AdvSimd.Store(leftPtr + i, AdvSimd.Or(leftVec, rightVec));
}
Vector128<int> result = Vector128.LoadUnsafe(ref left, i) | Vector128.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}

Expand Down Expand Up @@ -572,43 +487,24 @@ public unsafe BitArray Xor(BitArray value)
}

uint i = 0;
if (Avx2.IsSupported)

ref int left = ref MemoryMarshal.GetArrayDataReference<int>(thisArray);
ref int right = ref MemoryMarshal.GetArrayDataReference<int>(valueArray);

if (Vector256.IsHardwareAccelerated)
{
fixed (int* leftPtr = m_array)
fixed (int* rightPtr = value.m_array)
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
Vector256<int> leftVec = Avx.LoadVector256(leftPtr + i);
Vector256<int> rightVec = Avx.LoadVector256(rightPtr + i);
Avx.Store(leftPtr + i, Avx2.Xor(leftVec, rightVec));
}
Vector256<int> result = Vector256.LoadUnsafe(ref left, i) ^ Vector256.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}
else if (Sse2.IsSupported)
else if (Vector128.IsHardwareAccelerated)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> leftVec = Sse2.LoadVector128(leftPtr + i);
Vector128<int> rightVec = Sse2.LoadVector128(rightPtr + i);
Sse2.Store(leftPtr + i, Sse2.Xor(leftVec, rightVec));
}
}
}
else if (AdvSimd.IsSupported)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> leftVec = AdvSimd.LoadVector128(leftPtr + i);
Vector128<int> rightVec = AdvSimd.LoadVector128(rightPtr + i);
AdvSimd.Store(leftPtr + i, AdvSimd.Xor(leftVec, rightVec));
}
Vector128<int> result = Vector128.LoadUnsafe(ref left, i) ^ Vector128.LoadUnsafe(ref right, i);
result.StoreUnsafe(ref left, i);
}
}

Expand Down Expand Up @@ -650,39 +546,23 @@ public unsafe BitArray Not()
}

uint i = 0;
if (Avx2.IsSupported)
{
Vector256<int> ones = Vector256.Create(-1);
fixed (int* ptr = thisArray)
{
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
Vector256<int> vec = Avx.LoadVector256(ptr + i);
Avx.Store(ptr + i, Avx2.Xor(vec, ones));
}
}
}
else if (Sse2.IsSupported)

ref int value = ref MemoryMarshal.GetArrayDataReference<int>(thisArray);

if (Vector256.IsHardwareAccelerated)
{
Vector128<int> ones = Vector128.Create(-1);
fixed (int* ptr = thisArray)
for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> vec = Sse2.LoadVector128(ptr + i);
Sse2.Store(ptr + i, Sse2.Xor(vec, ones));
}
Vector256<int> result = ~Vector256.LoadUnsafe(ref value, i);
result.StoreUnsafe(ref value, i);
}
}
else if (AdvSimd.IsSupported)
else if (Vector128.IsHardwareAccelerated)
{
fixed (int* leftPtr = thisArray)
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount)
{
Vector128<int> leftVec = AdvSimd.LoadVector128(leftPtr + i);
AdvSimd.Store(leftPtr + i, AdvSimd.Not(leftVec));
}
Vector128<int> result = ~Vector128.LoadUnsafe(ref value, i);
result.StoreUnsafe(ref value, i);
}
}

Expand Down

0 comments on commit 358e28a

Please sign in to comment.