Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Vectorise BitArray #41896

Merged
merged 6 commits into from
Nov 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 217 additions & 18 deletions src/System.Collections/src/System/Collections/BitArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public BitArray(byte[] bytes)
_version = 0;
}

public BitArray(bool[] values)
public unsafe BitArray(bool[] values)
{
if (values == null)
{
Expand All @@ -123,7 +123,51 @@ public BitArray(bool[] values)
m_array = new int[GetInt32ArrayLengthFromBitLength(values.Length)];
m_length = values.Length;

for (int i = 0; i < values.Length; i++)
int i = 0;

if (values.Length < Vector256<byte>.Count)
{
goto LessThan32;
}

// Comparing with 1s would get rid of the final negation, but it would not work for some CLR bools:
// a bool is true for any non-zero value and false for 0, so bytes holding 2-255 would be misread as false.
// Instead, we compare with zeroes (== false) and then negate the result to ensure compatibility.

if (Avx2.IsSupported)
{
fixed (bool* ptr = values)
{
for (; (i + Vector256<byte>.Count) <= values.Length; i += Vector256<byte>.Count)
{
Vector256<byte> vector = Avx.LoadVector256((byte*)ptr + i);
Vector256<byte> isFalse = Avx2.CompareEqual(vector, Vector256<byte>.Zero);
int result = Avx2.MoveMask(isFalse);
m_array[i / 32] = ~result;
Gnbrkm41 marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
else if (Sse2.IsSupported)
{
fixed (bool* ptr = values)
{
for (; (i + Vector128<byte>.Count * 2) <= values.Length; i += Vector128<byte>.Count * 2)
{
Vector128<byte> lowerVector = Sse2.LoadVector128((byte*)ptr + i);
Vector128<byte> lowerIsFalse = Sse2.CompareEqual(lowerVector, Vector128<byte>.Zero);
int lowerPackedIsFalse = Sse2.MoveMask(lowerIsFalse);

Vector128<byte> upperVector = Sse2.LoadVector128((byte*)ptr + i + Vector128<byte>.Count);
Vector128<byte> upperIsFalse = Sse2.CompareEqual(upperVector, Vector128<byte>.Zero);
int upperPackedIsFalse = Sse2.MoveMask(upperIsFalse);

m_array[i / 32] = ~((upperPackedIsFalse << 16) | lowerPackedIsFalse);
}
}
}

LessThan32:
for (; i < values.Length; i++)
{
if (values[i])
{
Expand Down Expand Up @@ -241,13 +285,8 @@ public void Set(int index, bool value)
// Sets every bit covered by the current Length to the given value and bumps the version counter.
public void SetAll(bool value)
{
    // -1 has all 32 bits set and 0 has none, so one int write sets 32 bits at a time.
    int fillValue = value ? -1 : 0;

    // Fill only the leading ints reported by GetInt32ArrayLengthFromBitLength(Length),
    // using a single span fill instead of a hand-written loop.
    int arrayLength = GetInt32ArrayLengthFromBitLength(Length);
    m_array.AsSpan(0, arrayLength).Fill(fillValue);

    _version++;
}

Expand Down Expand Up @@ -275,16 +314,34 @@ public unsafe BitArray And(BitArray value)
if (Length != value.Length || (uint)count > (uint)thisArray.Length || (uint)count > (uint)valueArray.Length)
throw new ArgumentException(SR.Arg_ArrayLengthsDiffer);

// Unroll loop for count less than Vector256 size.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After the vectorized version there's a sequential loop to process the remaining elements. Why not jump to this switch instead of the loop?

(Of course, keep the loop if no Avx2 or Sse2 is available.)

switch (count)
{
case 7: thisArray[6] &= valueArray[6]; goto case 6;
case 6: thisArray[5] &= valueArray[5]; goto case 5;
case 5: thisArray[4] &= valueArray[4]; goto case 4;
case 4: thisArray[3] &= valueArray[3]; goto case 3;
case 3: thisArray[2] &= valueArray[2]; goto case 2;
case 2: thisArray[1] &= valueArray[1]; goto case 1;
case 1: thisArray[0] &= valueArray[0]; goto Done;
case 0: goto Done;
}

int i = 0;
if (Sse2.IsSupported)
if (Avx2.IsSupported)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
{
for (; i < count - (Vector256<int>.Count - 1); i += Vector256<int>.Count)
{
Vector256<int> leftVec = Avx.LoadVector256(leftPtr + i);
Vector256<int> rightVec = Avx.LoadVector256(rightPtr + i);
Avx.Store(leftPtr + i, Avx2.And(leftVec, rightVec));
}
}
}
else if (Sse2.IsSupported)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
Expand Down Expand Up @@ -330,16 +387,34 @@ public unsafe BitArray Or(BitArray value)
if (Length != value.Length || (uint)count > (uint)thisArray.Length || (uint)count > (uint)valueArray.Length)
throw new ArgumentException(SR.Arg_ArrayLengthsDiffer);

// Unroll loop for count less than Vector256 size.
switch (count)
{
case 7: thisArray[6] |= valueArray[6]; goto case 6;
case 6: thisArray[5] |= valueArray[5]; goto case 5;
case 5: thisArray[4] |= valueArray[4]; goto case 4;
case 4: thisArray[3] |= valueArray[3]; goto case 3;
case 3: thisArray[2] |= valueArray[2]; goto case 2;
case 2: thisArray[1] |= valueArray[1]; goto case 1;
case 1: thisArray[0] |= valueArray[0]; goto Done;
case 0: goto Done;
}

int i = 0;
if (Sse2.IsSupported)
if (Avx2.IsSupported)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
{
for (; i < count - (Vector256<int>.Count - 1); i += Vector256<int>.Count)
{
Vector256<int> leftVec = Avx.LoadVector256(leftPtr + i);
Vector256<int> rightVec = Avx.LoadVector256(rightPtr + i);
Avx.Store(leftPtr + i, Avx2.Or(leftVec, rightVec));
}
}
}
else if (Sse2.IsSupported)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
Expand Down Expand Up @@ -385,16 +460,34 @@ public unsafe BitArray Xor(BitArray value)
if (Length != value.Length || (uint)count > (uint)thisArray.Length || (uint)count > (uint)valueArray.Length)
throw new ArgumentException(SR.Arg_ArrayLengthsDiffer);

// Unroll loop for count less than Vector256 size.
switch (count)
{
case 7: thisArray[6] ^= valueArray[6]; goto case 6;
case 6: thisArray[5] ^= valueArray[5]; goto case 5;
case 5: thisArray[4] ^= valueArray[4]; goto case 4;
case 4: thisArray[3] ^= valueArray[3]; goto case 3;
case 3: thisArray[2] ^= valueArray[2]; goto case 2;
case 2: thisArray[1] ^= valueArray[1]; goto case 1;
case 1: thisArray[0] ^= valueArray[0]; goto Done;
case 0: goto Done;
}

int i = 0;
if (Sse2.IsSupported)
if (Avx2.IsSupported)
{
fixed (int* leftPtr = m_array)
fixed (int* rightPtr = value.m_array)
{
for (; i < count - (Vector256<int>.Count - 1); i += Vector256<int>.Count)
{
Vector256<int> leftVec = Avx.LoadVector256(leftPtr + i);
Vector256<int> rightVec = Avx.LoadVector256(rightPtr + i);
Avx.Store(leftPtr + i, Avx2.Xor(leftVec, rightVec));
}
}
}
else if (Sse2.IsSupported)
{
fixed (int* leftPtr = thisArray)
fixed (int* rightPtr = valueArray)
Expand All @@ -421,15 +514,60 @@ public unsafe BitArray Xor(BitArray value)
** off/false. Off/false bit values are turned on/true. The current instance
** is updated and returned.
=========================================================================*/
public unsafe BitArray Not()
{
    // Inverts, in place, every int that backs the current Length, then returns this instance.
    //
    // This method uses unsafe code to manipulate data in the BitArray. To avoid issues with
    // buggy code concurrently mutating this instance in a way that could cause memory corruption,
    // we snapshot the array then operate only on this snapshot. We don't care about such code
    // corrupting the BitArray data in a way that produces incorrect answers, since BitArray is not meant
    // to be thread-safe; we only care about avoiding buffer overruns.
    int[] thisArray = m_array;

    // NOTE(review): And/Or/Xor validate count against the snapshot's length before entering their
    // pointer loops; no such check is done here - confirm whether one is needed.
    int count = GetInt32ArrayLengthFromBitLength(Length);

    // Unroll loop for count less than Vector256 size.
    // Each case falls through to the next lower one, so exactly `count` ints are inverted.
    switch (count)
    {
        case 7: thisArray[6] = ~thisArray[6]; goto case 6;
        case 6: thisArray[5] = ~thisArray[5]; goto case 5;
        case 5: thisArray[4] = ~thisArray[4]; goto case 4;
        case 4: thisArray[3] = ~thisArray[3]; goto case 3;
        case 3: thisArray[2] = ~thisArray[2]; goto case 2;
        case 2: thisArray[1] = ~thisArray[1]; goto case 1;
        case 1: thisArray[0] = ~thisArray[0]; goto Done;
        case 0: goto Done;
    }

    int i = 0;
    if (Avx2.IsSupported)
    {
        // XOR with an all-ones vector flips every bit, i.e. it is a bitwise NOT.
        Vector256<int> ones = Vector256.Create(-1);
        fixed (int* ptr = thisArray)
        {
            // Process eight ints per iteration while a full vector still fits below count.
            for (; i < count - (Vector256<int>.Count - 1); i += Vector256<int>.Count)
            {
                Vector256<int> vec = Avx.LoadVector256(ptr + i);
                Avx.Store(ptr + i, Avx2.Xor(vec, ones));
            }
        }
    }
    else if (Sse2.IsSupported)
    {
        Vector128<int> ones = Vector128.Create(-1);
        fixed (int* ptr = thisArray)
        {
            // Same approach with four ints per iteration.
            for (; i < count - (Vector128<int>.Count - 1); i += Vector128<int>.Count)
            {
                Vector128<int> vec = Sse2.LoadVector128(ptr + i);
                Sse2.Store(ptr + i, Sse2.Xor(vec, ones));
            }
        }
    }

    // Scalar tail: whatever the vector loops left over (or everything when neither ISA is supported).
    for (; i < count; i++)
    {
        thisArray[i] = ~thisArray[i];
    }

Done:
    _version++;
    return this;
}
Expand Down Expand Up @@ -597,7 +735,13 @@ public int Length
}
}

public void CopyTo(Array array, int index)
// The masks used when shuffling a single int into a Vector128/256.
// On little-endian machines, the lowest 8 bits of the int are in the first byte, the next 8 bits in the second, and so on.
// Each source byte is copied into the eight lanes that correspond to its bits, so that only the relevant bit can be masked out later.
private static readonly Vector128<byte> s_lowerShuffleMask_CopyToBoolArray = Vector128.Create(0, 0x01010101_01010101).AsByte();
private static readonly Vector128<byte> s_upperShuffleMask_CopyToBoolArray = Vector128.Create(0x_02020202_02020202, 0x03030303_03030303).AsByte();

public unsafe void CopyTo(Array array, int index)
{
if (array == null)
throw new ArgumentNullException(nameof(array));
Expand Down Expand Up @@ -682,7 +826,62 @@ public void CopyTo(Array array, int index)
throw new ArgumentException(SR.Argument_InvalidOffLen);
}

for (int i = 0; i < m_length; i++)
int i = 0;

if (m_length < BitsPerInt32)
goto LessThan32;

if (Avx2.IsSupported)
{
Vector256<byte> shuffleMask = Vector256.Create(s_lowerShuffleMask_CopyToBoolArray, s_upperShuffleMask_CopyToBoolArray);
Vector256<byte> bitMask = Vector256.Create(0x80402010_08040201).AsByte();
Vector256<byte> ones = Vector256.Create((byte)1);

fixed (bool* destination = &boolArray[index])
{
for (; (i + Vector256<byte>.Count) <= m_length; i += Vector256<byte>.Count)
{
int bits = m_array[i / BitsPerInt32];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, you can load m_array as a Vector and spawn vectors for each integer

Vector256<int> scalar = Vector256.Create(bits);
Vector256<byte> shuffled = Avx2.Shuffle(scalar.AsByte(), shuffleMask);
Vector256<byte> extracted = Avx2.And(shuffled, bitMask);

// The extracted bits can be anywhere between 0 and 255, so we normalise the value to either 0 or 1
// to ensure compatibility with "C# bool" (0 for false, 1 for true, rest undefined)
Vector256<byte> normalized = Avx2.Min(extracted, ones);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems you don't do this kind of normalization for the BitArray(bool[]) constructor

Copy link
Author

@Gnbrkm41 Gnbrkm41 Oct 18, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is handled by comparing the bytes with zero (checking if the bytes are false) then negating the result: 72477e7#diff-e2f01cf03382b7d63fc3a67ad77fcedcR140-R142

Avx.Store((byte*)destination + i, normalized);
}
}
}
else if (Ssse3.IsSupported)
{
Vector128<byte> lowerShuffleMask = s_lowerShuffleMask_CopyToBoolArray;
Vector128<byte> upperShuffleMask = s_upperShuffleMask_CopyToBoolArray;
Vector128<byte> bitMask = Vector128.Create(0x80402010_08040201).AsByte(); ;
Vector128<byte> ones = Vector128.Create((byte)1);

fixed (bool* destination = &boolArray[index])
{
for (; (i + Vector128<byte>.Count * 2) <= m_length; i += Vector128<byte>.Count * 2)
{
int bits = m_array[i / BitsPerInt32];
Vector128<int> scalar = Vector128.CreateScalarUnsafe(bits);

Vector128<byte> shuffledLower = Ssse3.Shuffle(scalar.AsByte(), lowerShuffleMask);
Vector128<byte> extractedLower = Sse2.And(shuffledLower, bitMask);
Vector128<byte> normalizedLower = Sse2.Min(extractedLower, ones);
Sse2.Store((byte*)destination + i, normalizedLower);

Vector128<byte> shuffledHigher = Ssse3.Shuffle(scalar.AsByte(), upperShuffleMask);
Vector128<byte> extractedHigher = Sse2.And(shuffledHigher, bitMask);
Vector128<byte> normalizedHigher = Sse2.Min(extractedHigher, ones);
Sse2.Store((byte*)destination + i + Vector128<byte>.Count, normalizedHigher);
}
}
}

LessThan32:
Gnbrkm41 marked this conversation as resolved.
Show resolved Hide resolved
for (; i < m_length; i++)
{
int elementIndex = Div32Rem(i, out int extraBits);
boolArray[index + i] = ((m_array[elementIndex] >> extraBits) & 0x00000001) != 0;
Expand Down
12 changes: 11 additions & 1 deletion src/System.Collections/tests/BitArray/BitArray_CtorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,22 @@ public static void Ctor_Int_NegativeLength_ThrowsArgumentOutOfRangeException()

// Yields bool[] inputs of several sizes: an empty array, then for each size an all-true,
// an all-false, an alternating and a pseudo-random array.
public static IEnumerable<object[]> Ctor_BoolArray_TestData()
{
    // A fixed seed keeps the pseudo-random cases identical from run to run.
    Random rnd = new Random(0);

    yield return new object[] { new bool[0] };
    foreach (int size in new[] { 1, BitsPerByte, BitsPerByte * 2, BitsPerInt32, BitsPerInt32 * 2, BitsPerInt32 * 4, BitsPerInt32 * 8, BitsPerInt32 * 16 })
    {
        yield return new object[] { Enumerable.Repeat(true, size).ToArray() };
        yield return new object[] { Enumerable.Repeat(false, size).ToArray() };
        yield return new object[] { Enumerable.Range(0, size).Select(x => x % 2 == 0).ToArray() };

        bool[] random = new bool[size];
        for (int i = 0; i < random.Length; i++)
        {
            random[i] = rnd.Next(0, 2) == 0;
        }

        yield return new object[] { random };
    }
}

Expand Down
Loading