Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.CpuMath/Avx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ public static void Scale(float a, float[] dst, int count)
unsafe
{
fixed (float* pd = &dst[0])
Thunk.ScaleU(a, pd, count);
Thunk.Scale(a, pd, count);
}
}

Expand Down
146 changes: 124 additions & 22 deletions src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,36 @@
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using nuint = System.UInt64;

namespace Microsoft.ML.Runtime.Internal.CpuMath
{
internal static class AvxIntrinsics
{
// Lookup table of 8 rows x 8 lanes (64 uints), laid out row-major so that row k starts at index k * 8.
// In row k the first k lanes are 0xFFFFFFFF (all bits set) and the remaining 8 - k lanes are zero,
// i.e. row 0 selects nothing and row 7 selects the first seven lanes.
// Each row is sized to be loaded as a single 256-bit vector (8 x 32-bit lanes) and used as a bitwise
// AND mask that keeps only the leading k elements of a vector.
public static readonly uint[] LeadingAlignmentMask = new uint[64]
{
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000,
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000,
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000,
};

// Lookup table of 8 rows x 8 lanes (64 uints), laid out row-major so that row k starts at index k * 8.
// In row k the last k lanes are 0xFFFFFFFF (all bits set) and the first 8 - k lanes are zero,
// i.e. row 0 selects nothing and row 7 selects the last seven lanes.
// Each row is sized to be loaded as a single 256-bit vector (8 x 32-bit lanes) and used as a bitwise
// AND mask that keeps only the trailing k elements of a vector.
public static readonly uint[] TrailingAlignmentMask = new uint[64]
{
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
};

private static readonly Vector256<float> _absMask256 = Avx.StaticCast<int, float>(Avx.SetAllVector256(0x7FFFFFFF));

private const int Vector256Alignment = 32;
Expand Down Expand Up @@ -451,45 +476,122 @@ public static unsafe void AddScalarU(float scalar, Span<float> dst)
}
}

public static unsafe void ScaleU(float scale, Span<float> dst)
public static unsafe void Scale(float scale, Span<float> dst)
{
fixed (float* pdst = dst)
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
fixed (float* pd = dst)
{
float* pDstCurrent = pdst;
float* pEnd = pdst + dst.Length;

float* pDstCurrent = pd;
int length = dst.Length;
Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);

while (pDstCurrent + 8 <= pEnd)
if (length < 8)
{
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
switch(length)
{
case 7: dst[6] *= scale; goto case 6;
case 6: dst[5] *= scale; goto case 5;
case 5: dst[4] *= scale; goto case 4;
case 4: dst[3] *= scale; goto case 3;
case 3: dst[2] *= scale; goto case 2;
case 2: dst[1] *= scale; goto case 1;
case 1: dst[0] *= scale; break;
}
return;
}

dstVector = Avx.Multiply(scaleVector256, dstVector);
Avx.Store(pDstCurrent, dstVector);
nuint address = (nuint)(pd);
int misalignment = (int)(address % 32);
int remainder = 0;

pDstCurrent += 8;
if ((misalignment & 3) != 0)
{
// Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
remainder = length % 8;

for (float* pEnd = pd + (length - remainder); pDstCurrent < pEnd; pDstCurrent += 8)
{
Vector256<float> temp = Avx.LoadVector256(pDstCurrent);
temp = Avx.Multiply(scaleVector256, temp);
Avx.Store(pDstCurrent, temp);
}
}
else
{
if (misalignment != 0)
{
// Handle cases where the data is not 256-bit aligned by doing an unaligned read and then
// masking any elements that will be included in the first aligned read

Vector128<float> scaleVector128 = Sse.SetAllVector128(scale);
misalignment >>= 2;
misalignment = 8 - misalignment;

if (pDstCurrent + 4 <= pEnd)
{
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
Vector256<float> result = Avx.LoadVector256(pDstCurrent);

dstVector = Sse.Multiply(scaleVector128, dstVector);
Sse.Store(pDstCurrent, dstVector);
Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (( 8 - misalignment) * 8));

pDstCurrent += 4;
Vector256<float> temp = Avx.And(result, leadingMask);
result = Avx.And(result, trailingMask);

temp = Avx.Multiply(scaleVector256, temp);
result = Avx.Or(temp, result);

Avx.Store(pDstCurrent, result);

pDstCurrent += misalignment;
length -= misalignment;
}

if (length > 7)
{
// Handle all the 256-bit blocks that we can now that we have offset to an aligned address

remainder = length % 8;

for (float* pEnd = pDstCurrent + (length - remainder); pDstCurrent < pEnd; pDstCurrent += 8)
{
// The JIT will only fold away unaligned loads due to the semantics behind
// the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since
// modern hardware has unaligned loads that are as fast as aligned loads,
// when it doesn't cross a cache-line/page boundary, we will just assert
// that the alignment is correct and allow for the more-efficient codegen.

Contracts.Assert(((nuint)(pDstCurrent) % 32) == 0);
Vector256<float> temp = Avx.LoadVector256(pDstCurrent);
temp = Avx.Multiply(scaleVector256, temp);
Avx.Store(pDstCurrent, temp);
}
}
else
{
// Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not
// 256-bit aligned. This means we can't do any aligned loads and will just end up doing two
// unaligned loads where we mask the input each time.
remainder = length;
}
}

while (pDstCurrent < pEnd)
if (remainder != 0)
{
Vector128<float> dstVector = Sse.LoadScalarVector128(pDstCurrent);
// Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next
// unaligned load will read to the end of the array and then mask out any elements already processed

dstVector = Sse.MultiplyScalar(scaleVector128, dstVector);
Sse.StoreScalar(pDstCurrent, dstVector);
pDstCurrent -= (8 - remainder);

pDstCurrent++;
Vector256<float> result = Avx.LoadVector256(pDstCurrent);

Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8));

Vector256<float> temp = Avx.And(result, trailingMask);
result = Avx.And(result, leadingMask);

temp = Avx.Multiply(scaleVector256, temp);
temp = Avx.Or(temp, result);

Avx.Store(pDstCurrent, temp);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,11 @@ private static void Scale(float a, Span<float> dst)
{
if (Avx.IsSupported)
{
AvxIntrinsics.ScaleU(a, dst);
AvxIntrinsics.Scale(a, dst);
}
else if (Sse.IsSupported)
{
SseIntrinsics.ScaleU(a, dst);
SseIntrinsics.Scale(a, dst);
}
else
{
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.CpuMath/Sse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ public static void Scale(float a, AlignedArray dst)
unsafe
{
fixed (float* pdst = &dst.Items[0])
Thunk.ScaleA(a, Ptr(dst, pdst), dst.Size);
Thunk.Scale(a, Ptr(dst, pdst), dst.Size);
}
}

Expand All @@ -618,7 +618,7 @@ public static void Scale(float a, float[] dst, int count)
unsafe
{
fixed (float* pd = &dst[0])
Thunk.ScaleU(a, pd, count);
Thunk.Scale(a, pd, count);
}
}

Expand All @@ -631,7 +631,7 @@ public static void Scale(float a, float[] dst, int offset, int count)
unsafe
{
fixed (float* pd = &dst[offset])
Thunk.ScaleU(a, pd, count);
Thunk.Scale(a, pd, count);
}
}

Expand Down
Loading