Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use FMA instruction in CpuMath for .NET Core 3 #1292

Merged
merged 4 commits into from Oct 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
105 changes: 51 additions & 54 deletions src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
Expand Up @@ -141,6 +141,34 @@ private static Vector256<float> GetNewDst256(in Vector256<float> xDst1, in Vecto
return Avx.And(Avx.Subtract(xDst1, x2), xCond);
}

// Loads a 256-bit vector of floats from psrc1 and returns (loaded * src2) + src3, lane by lane.
// Uses the fused multiply-add instruction when the hardware reports FMA support;
// otherwise falls back to a separate AVX multiply followed by an AVX add.
[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static unsafe Vector256<float> MultiplyAdd(float* psrc1, Vector256<float> src2, Vector256<float> src3)
{
    if (!Fma.IsSupported)
    {
        // Fallback path: multiply first, then accumulate into src3.
        return Avx.Add(Avx.Multiply(src2, Avx.LoadVector256(psrc1)), src3);
    }

    return Fma.MultiplyAdd(Avx.LoadVector256(psrc1), src2, src3);
}

// Returns (src1 * src2) + src3, lane by lane.
// Uses the fused multiply-add instruction when the hardware reports FMA support;
// otherwise falls back to a separate AVX multiply followed by an AVX add.
[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static Vector256<float> MultiplyAdd(Vector256<float> src1, Vector256<float> src2, Vector256<float> src3)
{
    if (!Fma.IsSupported)
    {
        // Fallback path: multiply first, then accumulate into src3.
        return Avx.Add(Avx.Multiply(src1, src2), src3);
    }

    return Fma.MultiplyAdd(src1, src2, src3);
}

// Multiplies a matrix (mat) by a vector (src) and writes the product into a vector (dst).
public static unsafe void MatMulX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
{
Expand Down Expand Up @@ -185,15 +213,10 @@ public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int cro
Vector256<float> vector = Avx.LoadVector256(pSrcCurrent);

float* pMatTemp = pMatCurrent;
Vector256<float> x01 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp));
Vector256<float> x11 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol));
Vector256<float> x21 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol));
Vector256<float> x31 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol));

res0 = Avx.Add(res0, x01);
res1 = Avx.Add(res1, x11);
res2 = Avx.Add(res2, x21);
res3 = Avx.Add(res3, x31);
res0 = MultiplyAdd(pMatTemp, vector, res0);
res1 = MultiplyAdd(pMatTemp += ccol, vector, res1);
res2 = MultiplyAdd(pMatTemp += ccol, vector, res2);
res3 = MultiplyAdd(pMatTemp += ccol, vector, res3);

pSrcCurrent += 8;
pMatCurrent += 8;
Expand Down Expand Up @@ -236,15 +259,10 @@ public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int cro
Vector256<float> vector = Avx.LoadVector256(pSrcCurrent);

float* pMatTemp = pMatCurrent;
Vector256<float> x01 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp));
Vector256<float> x11 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol));
Vector256<float> x21 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol));
Vector256<float> x31 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol));

res0 = Avx.Add(res0, x01);
res1 = Avx.Add(res1, x11);
res2 = Avx.Add(res2, x21);
res3 = Avx.Add(res3, x31);
res0 = MultiplyAdd(pMatTemp, vector, res0);
res1 = MultiplyAdd(pMatTemp += ccol, vector, res1);
res2 = MultiplyAdd(pMatTemp += ccol, vector, res2);
res3 = MultiplyAdd(pMatTemp += ccol, vector, res3);

pSrcCurrent += 8;
pMatCurrent += 8;
Expand All @@ -269,10 +287,10 @@ public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int cro
Vector256<float> x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol));
Vector256<float> vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent));

res0 = Avx.Add(res0, Avx.Multiply(x01, vector));
res1 = Avx.Add(res1, Avx.Multiply(x11, vector));
res2 = Avx.Add(res2, Avx.Multiply(x21, vector));
res3 = Avx.Add(res3, Avx.Multiply(x31, vector));
res0 = MultiplyAdd(x01, vector, res0);
res1 = MultiplyAdd(x11, vector, res1);
res2 = MultiplyAdd(x21, vector, res2);
res3 = MultiplyAdd(x31, vector, res3);

pMatCurrent += 8;
pSrcCurrent += 8;
Expand Down Expand Up @@ -335,8 +353,7 @@ public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int cro
Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
x2 = Avx.Multiply(x2, x1);
result = Avx.Add(result, x2);
result = MultiplyAdd(x2, x1, result);

ppos++;
}
Expand Down Expand Up @@ -921,11 +938,9 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<f

while (pDstCurrent + 8 <= pEnd)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);

srcVector = Avx.Multiply(srcVector, scaleVector256);
dstVector = Avx.Add(dstVector, srcVector);
dstVector = MultiplyAdd(pSrcCurrent, scaleVector256, dstVector);
Avx.Store(pDstCurrent, dstVector);

pSrcCurrent += 8;
Expand Down Expand Up @@ -977,10 +992,8 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, Re

while (pResCurrent + 8 <= pResEnd)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
srcVector = Avx.Multiply(srcVector, scaleVector256);
dstVector = Avx.Add(dstVector, srcVector);
dstVector = MultiplyAdd(pSrcCurrent, scaleVector256, dstVector);
Avx.Store(pResCurrent, dstVector);

pSrcCurrent += 8;
Expand Down Expand Up @@ -1033,11 +1046,8 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadO

while (pIdxCurrent + 8 <= pEnd)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
Vector256<float> dstVector = Load8(pDstCurrent, pIdxCurrent);

srcVector = Avx.Multiply(srcVector, scaleVector256);
dstVector = Avx.Add(dstVector, srcVector);
dstVector = MultiplyAdd(pSrcCurrent, scaleVector256, dstVector);
Store8(in dstVector, pDstCurrent, pIdxCurrent);

pIdxCurrent += 8;
Expand Down Expand Up @@ -1260,7 +1270,7 @@ public static unsafe float SumSqU(ReadOnlySpan<float> src)
while (pSrcCurrent + 8 <= pSrcEnd)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
result256 = Avx.Add(result256, Avx.Multiply(srcVector, srcVector));
result256 = MultiplyAdd(srcVector, srcVector, result256);

pSrcCurrent += 8;
}
Expand Down Expand Up @@ -1306,8 +1316,7 @@ public static unsafe float SumSqDiffU(float mean, ReadOnlySpan<float> src)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
srcVector = Avx.Subtract(srcVector, meanVector256);
result256 = Avx.Add(result256, Avx.Multiply(srcVector, srcVector));

result256 = MultiplyAdd(srcVector, srcVector, result256);
pSrcCurrent += 8;
}

Expand Down Expand Up @@ -1540,11 +1549,8 @@ public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst

while (pSrcCurrent + 8 <= pSrcEnd)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);

result256 = Avx.Add(result256, Avx.Multiply(srcVector, dstVector));

result256 = MultiplyAdd(pSrcCurrent, dstVector, result256);
pSrcCurrent += 8;
pDstCurrent += 8;
}
Expand Down Expand Up @@ -1598,10 +1604,7 @@ public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> ds
while (pIdxCurrent + 8 <= pIdxEnd)
{
Vector256<float> srcVector = Load8(pSrcCurrent, pIdxCurrent);
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);

result256 = Avx.Add(result256, Avx.Multiply(srcVector, dstVector));

result256 = MultiplyAdd(pDstCurrent, srcVector, result256);
pIdxCurrent += 8;
pDstCurrent += 8;
}
Expand Down Expand Up @@ -1654,9 +1657,7 @@ public static unsafe float Dist2(ReadOnlySpan<float> src, ReadOnlySpan<float> ds
{
Vector256<float> distanceVector = Avx.Subtract(Avx.LoadVector256(pSrcCurrent),
Avx.LoadVector256(pDstCurrent));
sqDistanceVector256 = Avx.Add(sqDistanceVector256,
Avx.Multiply(distanceVector, distanceVector));

sqDistanceVector256 = MultiplyAdd(distanceVector, distanceVector, sqDistanceVector256);
pSrcCurrent += 8;
pDstCurrent += 8;
}
Expand Down Expand Up @@ -1709,10 +1710,8 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlyS

while (pSrcCurrent + 8 <= pSrcEnd)
{
Vector256<float> xSrc = Avx.LoadVector256(pSrcCurrent);

Vector256<float> xDst1 = Avx.LoadVector256(pDst1Current);
xDst1 = Avx.Add(xDst1, Avx.Multiply(xSrc, xPrimal256));
xDst1 = MultiplyAdd(pSrcCurrent, xPrimal256, xDst1);
Vector256<float> xDst2 = GetNewDst256(xDst1, xThreshold256);

Avx.Store(pDst1Current, xDst1);
Expand Down Expand Up @@ -1771,10 +1770,8 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnly

while (pIdxCurrent + 8 <= pIdxEnd)
{
Vector256<float> xSrc = Avx.LoadVector256(pSrcCurrent);

Vector256<float> xDst1 = Load8(pdst1, pIdxCurrent);
xDst1 = Avx.Add(xDst1, Avx.Multiply(xSrc, xPrimal256));
xDst1 = MultiplyAdd(pSrcCurrent, xPrimal256, xDst1);
Vector256<float> xDst2 = GetNewDst256(xDst1, xThreshold);

Store8(in xDst1, pdst1, pIdxCurrent);
Expand Down
11 changes: 11 additions & 0 deletions test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs
Expand Up @@ -28,14 +28,17 @@ public void ScaleAddU()
=> AvxIntrinsics.ScaleAddU(DefaultScale, DefaultScale, new Span<float>(dst, 0, Length));

[Benchmark]
[BenchmarkCategory("Fma")]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea with adding the Fma category! 👍

public void AddScaleU()
=> AvxIntrinsics.AddScaleU(DefaultScale, src, dst, Length);

// Benchmark (category "Fma"): calls AvxIntrinsics.AddScaleSU with DefaultScale, src, idx, dst and IndexLength.
[Benchmark]
[BenchmarkCategory("Fma")]
public void AddScaleSU()
{
    AvxIntrinsics.AddScaleSU(DefaultScale, src, idx, dst, IndexLength);
}

// Benchmark (category "Fma"): calls AvxIntrinsics.AddScaleCopyU with DefaultScale, src, dst, result and Length.
[Benchmark]
[BenchmarkCategory("Fma")]
public void AddScaleCopyU()
{
    AvxIntrinsics.AddScaleCopyU(DefaultScale, src, dst, result, Length);
}

Expand All @@ -56,10 +59,12 @@ public float SumU()
=> AvxIntrinsics.SumU(new Span<float>(src, 0, Length));

// Benchmark (category "Fma"): calls AvxIntrinsics.SumSqU on a span over the first Length elements of src.
[Benchmark]
[BenchmarkCategory("Fma")]
public float SumSqU()
{
    Span<float> input = new Span<float>(src, 0, Length);
    return AvxIntrinsics.SumSqU(input);
}

// Benchmark (category "Fma"): calls AvxIntrinsics.SumSqDiffU with DefaultScale as the mean argument
// and a span over the first Length elements of src.
[Benchmark]
[BenchmarkCategory("Fma")]
public float SumSqDiffU()
{
    Span<float> input = new Span<float>(src, 0, Length);
    return AvxIntrinsics.SumSqDiffU(DefaultScale, input);
}

Expand All @@ -80,25 +85,31 @@ public float MaxAbsDiffU()
=> AvxIntrinsics.MaxAbsDiffU(DefaultScale, new Span<float>(src, 0, Length));

// Benchmark (category "Fma"): calls AvxIntrinsics.DotU with src, dst and Length.
[Benchmark]
[BenchmarkCategory("Fma")]
public float DotU()
{
    return AvxIntrinsics.DotU(src, dst, Length);
}

// Benchmark (category "Fma"): calls AvxIntrinsics.DotSU with src, dst, idx and IndexLength.
[Benchmark]
[BenchmarkCategory("Fma")]
public float DotSU()
{
    return AvxIntrinsics.DotSU(src, dst, idx, IndexLength);
}

// Benchmark (category "Fma"): calls AvxIntrinsics.Dist2 with src, dst and Length.
[Benchmark]
[BenchmarkCategory("Fma")]
public float Dist2()
{
    return AvxIntrinsics.Dist2(src, dst, Length);
}

// Benchmark (category "Fma"): calls AvxIntrinsics.SdcaL1UpdateU, passing DefaultScale for both
// scalar arguments along with Length, src, dst and result.
[Benchmark]
[BenchmarkCategory("Fma")]
public void SdcaL1UpdateU()
{
    AvxIntrinsics.SdcaL1UpdateU(DefaultScale, Length, src, DefaultScale, dst, result);
}

// Benchmark (category "Fma"): calls AvxIntrinsics.SdcaL1UpdateSU, passing DefaultScale for both
// scalar arguments along with IndexLength, src, idx, dst and result.
[Benchmark]
[BenchmarkCategory("Fma")]
public void SdcaL1UpdateSU()
{
    AvxIntrinsics.SdcaL1UpdateSU(DefaultScale, IndexLength, src, idx, DefaultScale, dst, result);
}
// Benchmark (category "Fma"): calls AvxIntrinsics.MatMulX with src, src1 and dst,
// passing 1000 for both integer arguments.
[Benchmark]
[BenchmarkCategory("Fma")]
public void MatMulX()
{
    AvxIntrinsics.MatMulX(src, src1, dst, 1000, 1000);
}

Expand Down