diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index 5e7845007c..1b19b46949 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -47,6 +47,25 @@ internal static class AvxIntrinsics private static readonly Vector256 _absMask256 = Avx.StaticCast(Avx.SetAllVector256(0x7FFFFFFF)); + private const int Vector256Alignment = 32; + + [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] + private static bool HasCompatibleAlignment(AlignedArray alignedArray) + { + Contracts.AssertValue(alignedArray); + Contracts.Assert(alignedArray.Size > 0); + return (alignedArray.CbAlign % Vector256Alignment) == 0; + } + + [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] + private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase) + { + Contracts.AssertValue(alignedArray); + float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase); + Contracts.Assert(((long)alignedBase % Vector256Alignment) == 0); + return alignedBase; + } + [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] private static Vector128 GetHigh(in Vector256 x) => Avx.ExtractVector128(x, 1); @@ -154,17 +173,18 @@ private static Vector256 MultiplyAdd(Vector256 src1, Vector256 mat, ReadOnlySpan src, Span dst, int crow, int ccol) - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - fixed (float* pmat = &MemoryMarshal.GetReference(mat)) - fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) - fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) + fixed (float* pSrcStart = &src.Items[0]) + fixed (float* pDstStart = &dst.Items[0]) + fixed (float* pMatStart = &mat.Items[0]) { + float* psrc = GetAlignedBase(src, pSrcStart); + float* pdst = GetAlignedBase(dst, pDstStart); + float* pmat = GetAlignedBase(mat, pMatStart); + float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pDstCurrent = pdst; @@ -173,118 +193,36 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr while (pDstCurrent < pDstEnd) { Vector256 res0 = Avx.SetZeroVector256(); - Vector256 res1 = Avx.SetZeroVector256(); - Vector256 res2 = Avx.SetZeroVector256(); - Vector256 res3 = Avx.SetZeroVector256(); + Vector256 res1 = res0; + Vector256 res2 = res0; + Vector256 res3 = res0; - int length = ccol; float* pSrcCurrent = psrc; - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 32); - int remainder = 0; - - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pSrcCurrent < pSrcEnd) - { - Vector256 vector = Avx.LoadVector256(pSrcCurrent); - - float* pMatTemp = pMatCurrent; - res0 = MultiplyAdd(pMatTemp, vector, res0); - res1 = MultiplyAdd(pMatTemp += ccol, vector, res1); - res2 = MultiplyAdd(pMatTemp += ccol, vector, res2); - res3 = MultiplyAdd(pMatTemp += ccol, vector, res3); - - pSrcCurrent += 8; - pMatCurrent += 8; - } - } - else + while (pSrcCurrent < pSrcEnd) { - if (misalignment != 0) - { - // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 8 - misalignment; - - Vector256 mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - - // We only align pMat since it has significantly more reads. 
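// A minimal sketch (not part of this patch) of what the new GetAlignedBase helper above
// relies on: AlignedArray.GetBase is assumed to return the number of float elements to
// skip so that the pinned base pointer lands on the requested byte boundary. The helper
// name below is hypothetical and only illustrates that computation.
private static unsafe int ElementsToSkipForAlignment(float* p, int alignmentInBytes)
{
    // Bytes past the previous alignment boundary, then bytes needed to reach the next one.
    long misalignment = (long)p % alignmentInBytes;
    long bytesToSkip = (alignmentInBytes - misalignment) % alignmentInBytes;

    // AlignedArray buffers are float[] allocations, so the distance is a whole number of floats.
    return (int)(bytesToSkip / sizeof(float));
}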
- float* pMatTemp = pMatCurrent; - Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); - Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); - - res0 = Avx.Multiply(x01, vector); - res1 = Avx.Multiply(x11, vector); - res2 = Avx.Multiply(x21, vector); - res3 = Avx.Multiply(x31, vector); - - pMatCurrent += misalignment; - pSrcCurrent += misalignment; - length -= misalignment; - } - - if (length > 7) - { - // Handle all the 256-bit blocks that we can now that we have offset to an aligned address - remainder = length % 8; - - while (pSrcCurrent + 8 <= pSrcEnd) - { - // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads - // (due to semantics of the legacy encoding). - // We don't need an assert, since the instruction will throw for unaligned inputs. - Vector256 vector = Avx.LoadVector256(pSrcCurrent); - - float* pMatTemp = pMatCurrent; - res0 = MultiplyAdd(pMatTemp, vector, res0); - res1 = MultiplyAdd(pMatTemp += ccol, vector, res1); - res2 = MultiplyAdd(pMatTemp += ccol, vector, res2); - res3 = MultiplyAdd(pMatTemp += ccol, vector, res3); - - pSrcCurrent += 8; - pMatCurrent += 8; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not - // 256-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. - remainder = length; - } - - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pMatCurrent -= (8 - remainder); - pSrcCurrent -= (8 - remainder); - - Vector256 mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - - float* pMatTemp = pMatCurrent; - Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); - Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); - Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); - - res0 = MultiplyAdd(x01, vector, res0); - res1 = MultiplyAdd(x11, vector, res1); - res2 = MultiplyAdd(x21, vector, res2); - res3 = MultiplyAdd(x31, vector, res3); - - pMatCurrent += 8; - pSrcCurrent += 8; - } + float* pMatTemp = pMatCurrent; + Contracts.Assert(((nuint)(pMatTemp) % 32) == 0); + Contracts.Assert(((nuint)(pSrcCurrent) % 32) == 0); + + // The JIT will only fold away unaligned loads due to the semantics behind + // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since + // modern hardware has unaligned loads that are as fast as aligned loads, + // when it doesn't cross a cache-line/page boundary, we will just assert + // that the alignment is correct and allow for the more-efficient codegen. 
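// A minimal sketch of the pattern the comment above justifies: assert the expected
// alignment, then still issue the unaligned-load intrinsic so the JIT can fold the load
// into the consuming VEX instruction. Written against the current System.Runtime.Intrinsics
// surface (the patch itself targets the older experimental API); assumes
// using System.Diagnostics; using System.Runtime.Intrinsics; using System.Runtime.Intrinsics.X86;
static unsafe Vector256<float> LoadRow32ByteAligned(float* p)
{
    // We expect a 32-byte aligned address, but emit vmovups (Avx.LoadVector256) so the load
    // stays foldable; Avx.LoadAlignedVector256 (vmovaps) would fault if the expectation broke.
    Debug.Assert(((long)p & 31) == 0);
    return Avx.LoadVector256(p);
}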
+ Vector256 x01 = Avx.LoadVector256(pMatTemp); + Vector256 x11 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x21 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x31 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x02 = Avx.LoadVector256(pSrcCurrent); + + res0 = MultiplyAdd(x01, x02, res0); + res1 = MultiplyAdd(x11, x02, res1); + res2 = MultiplyAdd(x21, x02, res2); + res3 = MultiplyAdd(x31, x02, res3); + + pSrcCurrent += 8; + pMatCurrent += 8; } // Add up the entries of each, with the 4 results in res0 @@ -293,7 +231,7 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr res0 = Avx.HorizontalAdd(res0, res2); Vector128 sum = Sse.Add(Avx.GetLowerHalf(res0), GetHigh(in res0)); - Sse.Store(pDstCurrent, sum); + Sse.StoreAligned(pDstCurrent, sum); pDstCurrent += 4; pMatCurrent += 3 * ccol; @@ -304,22 +242,22 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr // Partial sparse source vector. public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan rgposSrc, AlignedArray src, int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) - { - MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol); - } - - public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgposSrc, ReadOnlySpan src, - int posMin, int iposMin, int iposEnd, Span dst, int crow, int ccol) { // REVIEW: For extremely sparse inputs, interchanging the loops would // likely be more efficient. - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - fixed (float* pmat = &MemoryMarshal.GetReference(mat)) - fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc)) - fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) - fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) + Contracts.Assert(HasCompatibleAlignment(mat)); + Contracts.Assert(HasCompatibleAlignment(src)); + Contracts.Assert(HasCompatibleAlignment(dst)); + + fixed (float* pSrcStart = &src.Items[0]) + fixed (float* pDstStart = &dst.Items[0]) + fixed (float* pMatStart = &mat.Items[0]) + fixed (int* pposSrc = &rgposSrc[0]) { + float* psrc = GetAlignedBase(src, pSrcStart); + float* pdst = GetAlignedBase(dst, pDstStart); + float* pmat = GetAlignedBase(mat, pMatStart); + int* pposMin = pposSrc + iposMin; int* pposEnd = pposSrc + iposEnd; float* pDstEnd = pdst + crow; @@ -327,116 +265,7 @@ public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgp float* pSrcCurrent = psrc - posMin; float* pDstCurrent = pdst; - nuint address = (nuint)(pDstCurrent); - int misalignment = (int)(address % 32); - int length = crow; - int remainder = 0; - - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pDstCurrent < pDstEnd) - { - Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow()); - pDstCurrent += 8; - pm0 += 8 * ccol; - } - } - else - { - if (misalignment != 0) - { - // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 8 - misalignment; - - Vector256 mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - - float* pm1 = pm0 + ccol; - float* pm2 = pm1 + ccol; - float* pm3 = pm2 + ccol; - Vector256 result = Avx.SetZeroVector256(); - - int* ppos = pposMin; - - while (ppos < pposEnd) - { - int col1 = *ppos; - int col2 = col1 + 4 * 
ccol; - Vector256 x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2], - pm3[col1], pm2[col1], pm1[col1], pm0[col1]); - - x1 = Avx.And(mask, x1); - Vector256 x2 = Avx.SetAllVector256(pSrcCurrent[col1]); - result = MultiplyAdd(x2, x1, result); - ppos++; - } - - Avx.Store(pDstCurrent, result); - pDstCurrent += misalignment; - pm0 += misalignment * ccol; - length -= misalignment; - } - - if (length > 7) - { - // Handle all the 256-bit blocks that we can now that we have offset to an aligned address - remainder = length % 8; - while (pDstCurrent < pDstEnd) - { - Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow()); - pDstCurrent += 8; - pm0 += 8 * ccol; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not - // 256-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. - remainder = length; - } - - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pDstCurrent -= (8 - remainder); - pm0 -= (8 - remainder) * ccol; - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); - - float* pm1 = pm0 + ccol; - float* pm2 = pm1 + ccol; - float* pm3 = pm2 + ccol; - Vector256 result = Avx.SetZeroVector256(); - - int* ppos = pposMin; - - while (ppos < pposEnd) - { - int col1 = *ppos; - int col2 = col1 + 4 * ccol; - Vector256 x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2], - pm3[col1], pm2[col1], pm1[col1], pm0[col1]); - x1 = Avx.And(x1, trailingMask); - - Vector256 x2 = Avx.SetAllVector256(pSrcCurrent[col1]); - result = MultiplyAdd(x2, x1, result); - ppos++; - } - - result = Avx.Add(result, Avx.And(leadingMask, Avx.LoadVector256(pDstCurrent))); - - Avx.Store(pDstCurrent, result); - pDstCurrent += 8; - pm0 += 8 * ccol; - } - } - - Vector256 SparseMultiplicationAcrossRow() + while (pDstCurrent < pDstEnd) { float* pm1 = pm0 + ccol; float* pm2 = pm1 + ccol; @@ -450,326 +279,133 @@ Vector256 SparseMultiplicationAcrossRow() int col1 = *ppos; int col2 = col1 + 4 * ccol; Vector256 x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2], - pm3[col1], pm2[col1], pm1[col1], pm0[col1]); + pm3[col1], pm2[col1], pm1[col1], pm0[col1]); Vector256 x2 = Avx.SetAllVector256(pSrcCurrent[col1]); - result = MultiplyAdd(x2, x1, result); + x2 = Avx.Multiply(x2, x1); + result = Avx.Add(result, x2); + ppos++; } - return result; + Avx.StoreAligned(pDstCurrent, result); + pDstCurrent += 8; + pm0 += 8 * ccol; } } } public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol); - } + Contracts.Assert(HasCompatibleAlignment(mat)); + Contracts.Assert(HasCompatibleAlignment(src)); + Contracts.Assert(HasCompatibleAlignment(dst)); - public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crow, int ccol) - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - fixed (float* pmat = &MemoryMarshal.GetReference(mat)) - fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) - fixed (uint* pTrailingAlignmentMask = 
&TrailingAlignmentMask[0]) + fixed (float* pSrcStart = &src.Items[0]) + fixed (float* pDstStart = &dst.Items[0]) + fixed (float* pMatStart = &mat.Items[0]) { + float* psrc = GetAlignedBase(src, pSrcStart); + float* pdst = GetAlignedBase(dst, pDstStart); + float* pmat = GetAlignedBase(mat, pMatStart); + float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pSrcCurrent = psrc; float* pMatCurrent = pmat; - // The reason behind adding the if condtion instead of boolean flag - // is to avoid branching in codegen. - if (pSrcCurrent < pSrcEnd) - { - Vector128 h01 = Sse.LoadVector128(pSrcCurrent); - // Replicate each slot of h01 (ABCD) into its own register. - Vector128 h11 = Avx.Permute(h01, 0x55); // B - Vector128 h21 = Avx.Permute(h01, 0xAA); // C - Vector128 h31 = Avx.Permute(h01, 0xFF); // D - h01 = Avx.Permute(h01, 0x00); // A + // We do 4-way unrolling + Vector128 h01 = Sse.LoadAlignedVector128(pSrcCurrent); + // Replicate each slot of h01 (ABCD) into its own register. + Vector128 h11 = Sse.Shuffle(h01, h01, 0x55); // B + Vector128 h21 = Sse.Shuffle(h01, h01, 0xAA); // C + Vector128 h31 = Sse.Shuffle(h01, h01, 0xFF); // D + h01 = Sse.Shuffle(h01, h01, 0x00); // A - Vector256 x01 = Avx.SetHighLow(h01, h01); - Vector256 x11 = Avx.SetHighLow(h11, h11); - Vector256 x21 = Avx.SetHighLow(h21, h21); - Vector256 x31 = Avx.SetHighLow(h31, h31); + Vector256 x01 = Avx.SetHighLow(h01, h01); + Vector256 x11 = Avx.SetHighLow(h11, h11); + Vector256 x21 = Avx.SetHighLow(h21, h21); + Vector256 x31 = Avx.SetHighLow(h31, h31); - int length = crow; - float* pDstCurrent = pdst; + pSrcCurrent += 4; - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 32); + float* pDstCurrent = pdst; - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } - } - else - { - int remainder = 0; - if (misalignment != 0) - { - // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 8 - misalignment; - - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - - // We only align pMat since it has significantly more reads. 
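// A sketch (not part of this patch) of the broadcast used by the 4-way unrolling above:
// each of four source floats is replicated into its own 256-bit register. Written with the
// current intrinsics surface, where Vector256.Create(lower, upper) replaces the experimental
// Avx.SetHighLow; assumes using System.Runtime.Intrinsics; using System.Runtime.Intrinsics.X86;
static unsafe void BroadcastFour(float* pSrc,
    out Vector256<float> a, out Vector256<float> b,
    out Vector256<float> c, out Vector256<float> d)
{
    Vector128<float> h = Sse.LoadVector128(pSrc);   // A B C D
    Vector128<float> hA = Sse.Shuffle(h, h, 0x00);  // A A A A
    Vector128<float> hB = Sse.Shuffle(h, h, 0x55);  // B B B B
    Vector128<float> hC = Sse.Shuffle(h, h, 0xAA);  // C C C C
    Vector128<float> hD = Sse.Shuffle(h, h, 0xFF);  // D D D D

    // Duplicate each 128-bit value into both halves of a 256-bit register.
    a = Vector256.Create(hA, hA);
    b = Vector256.Create(hB, hB);
    c = Vector256.Create(hC, hC);
    d = Vector256.Create(hD, hD);
}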
- float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Or(x02, Avx.And(x3, trailingMask)); - - Avx.Store(pDstCurrent, x02); - pMatCurrent += misalignment; - pDstCurrent += misalignment; - length -= misalignment; - } - if (length > 7) - { - // Handle all the 256-bit blocks that we can now that we have offset to an aligned address - remainder = length % 8; - - while (pDstCurrent + 8 <= pDstEnd) - { - // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads - // (due to semantics of the legacy encoding). - // We don't need an assert, since the instruction will throw for unaligned inputs. - float* pMatTemp = pMatCurrent; - - Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not - // 256-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. - remainder = length; - } + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + Contracts.Assert(((nuint)(pMatTemp) % 32) == 0); - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pMatCurrent -= (8 - remainder); - pDstCurrent -= (8 - remainder); - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Or(x02, Avx.And(x3, leadingMask)); - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } - } + // The JIT will only fold away unaligned loads due to the semantics behind + // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. 
Since + // modern hardware has unaligned loads that are as fast as aligned loads, + // when it doesn't cross a cache-line/page boundary, we will just assert + // that the alignment is correct and allow for the more-efficient codegen. + Vector256 x02 = Avx.LoadVector256(pMatTemp); + Vector256 x12 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x22 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x32 = Avx.LoadVector256(pMatTemp += crow); - pMatCurrent += 3 * crow; - pSrcCurrent += 4; + x02 = Avx.Multiply(x01, x02); + x02 = MultiplyAdd(x11, x12, x02); + + x22 = Avx.Multiply(x21, x22); + x22 = MultiplyAdd(x31, x32, x22); + + x02 = Avx.Add(x02, x22); + Avx.StoreAligned(pDstCurrent, x02); + + pDstCurrent += 8; + pMatCurrent += 8; } - // We do 4-way unrolling + pMatCurrent += 3 * crow; + while (pSrcCurrent < pSrcEnd) { - Vector128 h01 = Sse.LoadVector128(pSrcCurrent); + h01 = Sse.LoadAlignedVector128(pSrcCurrent); // Replicate each slot of h01 (ABCD) into its own register. - Vector128 h11 = Avx.Permute(h01, 0x55); // B - Vector128 h21 = Avx.Permute(h01, 0xAA); // C - Vector128 h31 = Avx.Permute(h01, 0xFF); // D - h01 = Avx.Permute(h01, 0x00); // A - - Vector256 x01 = Avx.SetHighLow(h01, h01); - Vector256 x11 = Avx.SetHighLow(h11, h11); - Vector256 x21 = Avx.SetHighLow(h21, h21); - Vector256 x31 = Avx.SetHighLow(h31, h31); + h11 = Sse.Shuffle(h01, h01, 0x55); // B + h21 = Sse.Shuffle(h01, h01, 0xAA); // C + h31 = Sse.Shuffle(h01, h01, 0xFF); // D + h01 = Sse.Shuffle(h01, h01, 0x00); // A - int length = crow; - float* pDstCurrent = pdst; + x01 = Avx.SetHighLow(h01, h01); + x11 = Avx.SetHighLow(h11, h11); + x21 = Avx.SetHighLow(h21, h21); + x31 = Avx.SetHighLow(h31, h31); - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 32); + pDstCurrent = pdst; - if ((misalignment & 3) != 0) + while (pDstCurrent < pDstEnd) { - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); + float* pMatTemp = pMatCurrent; - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); + Contracts.Assert(((nuint)(pMatTemp) % 32) == 0); + Contracts.Assert(((nuint)(pDstCurrent) % 32) == 0); - x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); + // The JIT will only fold away unaligned loads due to the semantics behind + // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since + // modern hardware has unaligned loads that are as fast as aligned loads, + // when it doesn't cross a cache-line/page boundary, we will just assert + // that the alignment is correct and allow for the more-efficient codegen. 
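// The MultiplyAdd helper used throughout these loops is not shown in full by this hunk;
// presumably it has roughly the following shape, preferring FMA when the hardware supports
// it. This is a hedged sketch, not the file's verbatim implementation; assumes
// using System.Runtime.Intrinsics; using System.Runtime.Intrinsics.X86;
static Vector256<float> MultiplyAddSketch(Vector256<float> a, Vector256<float> b, Vector256<float> acc)
{
    return Fma.IsSupported
        ? Fma.MultiplyAdd(a, b, acc)        // acc + (a * b) as a single fused instruction
        : Avx.Add(acc, Avx.Multiply(a, b)); // two-instruction fallback
}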
+ Vector256 x02 = Avx.LoadVector256(pMatTemp); + Vector256 x12 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x22 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x32 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } - } - else - { - int remainder = 0; - if (misalignment != 0) - { - // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 8 - misalignment; - - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - - // We only align pMat since it has significantly more reads. - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Or(x02, Avx.And(x3, trailingMask)); - - x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); - - Avx.Store(pDstCurrent, x02); - pMatCurrent += misalignment; - pDstCurrent += misalignment; - length -= misalignment; - } - if (length > 7) - { - remainder = length % 8; - while (pDstCurrent + 8 <= pDstEnd) - { - float* pMatTemp = pMatCurrent; - - Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } - } - else - { - remainder = length; - } + x02 = Avx.Multiply(x01, x02); + x02 = MultiplyAdd(x11, x12, x02); - if (remainder != 0) - { - pMatCurrent -= (8 - remainder); - pDstCurrent -= (8 - remainder); - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); - Vector256 x12 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x22 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - Vector256 x32 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - x02 = Avx.Or(x02, Avx.And(x3, leadingMask)); - - x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } + x22 = 
Avx.Multiply(x21, x22); + x22 = MultiplyAdd(x31, x32, x22); + + x02 = Avx.Add(x02, x22); + x3 = Avx.Add(x02, x3); + Avx.StoreAligned(pDstCurrent, x3); + + pDstCurrent += 8; + pMatCurrent += 8; } pMatCurrent += 3 * crow; diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs index 5ecbc62be1..a046bbba98 100644 --- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs +++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime.Internal.CpuMath.Core; using System; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; namespace Microsoft.ML.Runtime.Internal.CpuMath { @@ -18,57 +19,411 @@ internal static partial class CpuMathUtils public static int GetVectorAlignment() => Vector128Alignment; - public static void MatrixTimesSource(bool transpose, AlignedArray matrix, AlignedArray source, AlignedArray destination, int stride) => SseUtils.MatTimesSrc(transpose, matrix, source, destination, stride); + private static bool Compat(AlignedArray a) + { + Contracts.AssertValue(a); + Contracts.Assert(a.Size > 0); + return a.CbAlign == Vector128Alignment; + } - public static void MatrixTimesSource(AlignedArray matrix, ReadOnlySpan rgposSrc, AlignedArray sourceValues, - int posMin, int iposMin, int iposLimit, AlignedArray destination, int stride) => SseUtils.MatTimesSrc(matrix, rgposSrc, sourceValues, posMin, iposMin, iposLimit, destination, stride); + private static unsafe float* Ptr(AlignedArray a, float* p) + { + Contracts.AssertValue(a); + float* q = p + a.GetBase((long)p); + Contracts.Assert(((long)q & (Vector128Alignment - 1)) == 0); + return q; + } - public static void Add(float value, Span destination) => SseUtils.Add(value, destination); + public static void MatrixTimesSource(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) + { + Contracts.Assert(Compat(mat)); + Contracts.Assert(Compat(src)); + Contracts.Assert(Compat(dst)); + Contracts.Assert(mat.Size == dst.Size * src.Size); - public static void Scale(float value, Span destination) => SseUtils.Scale(value, destination); + unsafe + { + fixed (float* pmat = &mat.Items[0]) + fixed (float* psrc = &src.Items[0]) + fixed (float* pdst = &dst.Items[0]) + { + if (!tran) + { + Contracts.Assert(0 <= crun && crun <= dst.Size); + Thunk.MatMul(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); + } + else + { + Contracts.Assert(0 <= crun && crun <= src.Size); + Thunk.MatMulTran(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); + } + } + } + } - public static void Scale(float value, ReadOnlySpan source, Span destination, int count) => SseUtils.Scale(value, source, destination, count); + public static void MatrixTimesSource(AlignedArray mat, ReadOnlySpan rgposSrc, AlignedArray srcValues, + int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) + { + Contracts.Assert(Compat(mat)); + Contracts.Assert(Compat(srcValues)); + Contracts.Assert(Compat(dst)); + Contracts.Assert(0 <= iposMin && iposMin <= iposLim && iposLim <= rgposSrc.Length); + Contracts.Assert(mat.Size == dst.Size * srcValues.Size); - public static void ScaleAdd(float value, float addend, Span destination) => SseUtils.ScaleAdd(value, addend, destination); + if (iposMin >= iposLim) + { + dst.ZeroItems(); + return; + } + Contracts.AssertNonEmpty(rgposSrc); + unsafe + { + fixed (float* pdst = &dst.Items[0]) + fixed (float* pmat = &mat.Items[0]) + fixed (float* psrc = &srcValues.Items[0]) + fixed (int* ppossrc = 
&rgposSrc[0]) + { + Contracts.Assert(0 <= crun && crun <= dst.Size); + Thunk.MatMulP(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); + } + } + } - public static void AddScale(float value, ReadOnlySpan source, Span destination, int count) => SseUtils.AddScale(value, source, destination, count); + // dst += a + public static void Add(float a, Span dst) + { + Contracts.AssertNonEmpty(dst); - public static void AddScale(float value, ReadOnlySpan source, ReadOnlySpan indices, Span destination, int count) => SseUtils.AddScale(value, source, indices, destination, count); + unsafe + { + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + Thunk.AddScalarU(a, pdst, dst.Length); + } + } - public static void AddScaleCopy(float value, ReadOnlySpan source, ReadOnlySpan destination, Span res, int count) => SseUtils.AddScaleCopy(value, source, destination, res, count); + public static void Scale(float a, Span dst) + { + Contracts.AssertNonEmpty(dst); - public static void Add(ReadOnlySpan source, Span destination, int count) => SseUtils.Add(source, destination, count); + unsafe + { + fixed (float* pd = &MemoryMarshal.GetReference(dst)) + Thunk.Scale(a, pd, dst.Length); + } + } - public static void Add(ReadOnlySpan source, ReadOnlySpan indices, Span destination, int count) => SseUtils.Add(source, indices, destination, count); + // dst = a * src + public static void Scale(float a, ReadOnlySpan src, Span dst, int count) + { + Contracts.AssertNonEmpty(src); + Contracts.Assert(0 < count && count <= src.Length); + Contracts.AssertNonEmpty(dst); + Contracts.Assert(count <= dst.Length); - public static void MulElementWise(ReadOnlySpan left, ReadOnlySpan right, Span destination, int count) => SseUtils.MulElementWise(left, right, destination, count); + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + { + Thunk.ScaleSrcU(a, psrc, pdst, count); + } + } + } - public static float Sum(ReadOnlySpan source) => SseUtils.Sum(source); + // dst[i] = a * (dst[i] + b) + public static void ScaleAdd(float a, float b, Span dst) + { + Contracts.AssertNonEmpty(dst); - public static float SumSq(ReadOnlySpan source) => SseUtils.SumSq(source); + unsafe + { + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + Thunk.ScaleAddU(a, b, pdst, dst.Length); + } + } - public static float SumSq(float mean, ReadOnlySpan source) => SseUtils.SumSq(mean, source); + public static void AddScale(float a, ReadOnlySpan src, Span dst, int count) + { + Contracts.AssertNonEmpty(src); + Contracts.Assert(0 < count && count <= src.Length); + Contracts.AssertNonEmpty(dst); + Contracts.Assert(count <= dst.Length); - public static float SumAbs(ReadOnlySpan source) => SseUtils.SumAbs(source); + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + Thunk.AddScaleU(a, psrc, pdst, count); + } + } - public static float SumAbs(float mean, ReadOnlySpan source) => SseUtils.SumAbs(mean, source); + public static void AddScale(float a, ReadOnlySpan src, ReadOnlySpan indices, Span dst, int count) + { + Contracts.AssertNonEmpty(src); + Contracts.Assert(0 < count && count <= src.Length); + Contracts.AssertNonEmpty(indices); + Contracts.Assert(count <= indices.Length); + Contracts.AssertNonEmpty(dst); + Contracts.Assert(count < dst.Length); - public static float MaxAbs(ReadOnlySpan source) => SseUtils.MaxAbs(source); + unsafe + { + fixed (float* psrc = 
&MemoryMarshal.GetReference(src)) + fixed (int* pi = &MemoryMarshal.GetReference(indices)) + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + Thunk.AddScaleSU(a, psrc, pi, pdst, count); + } + } - public static float MaxAbsDiff(float mean, ReadOnlySpan source) => SseUtils.MaxAbsDiff(mean, source); + public static void AddScaleCopy(float a, ReadOnlySpan src, ReadOnlySpan dst, Span res, int count) + { + Contracts.AssertNonEmpty(dst); + Contracts.Assert(0 < count && count <= dst.Length); + Contracts.AssertNonEmpty(src); + Contracts.Assert(count <= src.Length); + Contracts.AssertNonEmpty(res); + Contracts.Assert(count <= res.Length); - public static float DotProductDense(ReadOnlySpan left, ReadOnlySpan right, int count) => SseUtils.DotProductDense(left, right, count); + unsafe + { + fixed (float* pdst = &MemoryMarshal.GetReference(dst)) + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (float* pres = &MemoryMarshal.GetReference(res)) + Thunk.AddScaleCopyU(a, psrc, pdst, pres, count); + } + } - public static float DotProductSparse(ReadOnlySpan left, ReadOnlySpan right, ReadOnlySpan indices, int count) => SseUtils.DotProductSparse(left, right, indices, count); + public static void Add(ReadOnlySpan src, Span dst, int count) + { + Contracts.AssertNonEmpty(src); + Contracts.Assert(0 < count && count <= src.Length); + Contracts.AssertNonEmpty(dst); + Contracts.Assert(count <= dst.Length); - public static float L2DistSquared(ReadOnlySpan left, ReadOnlySpan right, int count) => SseUtils.L2DistSquared(left, right, count); + unsafe + { + fixed (float* ps = &MemoryMarshal.GetReference(src)) + fixed (float* pd = &MemoryMarshal.GetReference(dst)) + Thunk.AddU(ps, pd, count); + } + } - public static void ZeroMatrixItems(AlignedArray destination, int ccol, int cfltRow, int[] indices) => SseUtils.ZeroMatrixItems(destination, ccol, cfltRow, indices); + public static void Add(ReadOnlySpan src, ReadOnlySpan indices, Span dst, int count) + { + Contracts.AssertNonEmpty(src); + Contracts.Assert(0 < count && count <= src.Length); + Contracts.AssertNonEmpty(indices); + Contracts.Assert(count <= indices.Length); + Contracts.AssertNonEmpty(dst); + Contracts.Assert(count < dst.Length); - public static void SdcaL1UpdateDense(float primalUpdate, int count, ReadOnlySpan source, float threshold, Span v, Span w) - => SseUtils.SdcaL1UpdateDense(primalUpdate, count, source, threshold, v, w); + unsafe + { + fixed (float* ps = &MemoryMarshal.GetReference(src)) + fixed (int* pi = &MemoryMarshal.GetReference(indices)) + fixed (float* pd = &MemoryMarshal.GetReference(dst)) + Thunk.AddSU(ps, pi, pd, count); + } + } + + public static void MulElementWise(ReadOnlySpan src1, ReadOnlySpan src2, Span dst, int count) + { + Contracts.AssertNonEmpty(src1); + Contracts.Assert(0 < count && count <= src1.Length); + Contracts.AssertNonEmpty(src2); + Contracts.Assert(0 < count && count <= src2.Length); + Contracts.AssertNonEmpty(dst); + unsafe + { + fixed (float* ps1 = &MemoryMarshal.GetReference(src1)) + fixed (float* ps2 = &MemoryMarshal.GetReference(src2)) + fixed (float* pd = &MemoryMarshal.GetReference(dst)) + Thunk.MulElementWiseU(ps1, ps2, pd, count); + } + } + + public static float Sum(ReadOnlySpan src) + { + Contracts.AssertNonEmpty(src); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + return Thunk.Sum(psrc, src.Length); + } + } + + public static float SumSq(ReadOnlySpan src) + { + Contracts.AssertNonEmpty(src); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + 
return Thunk.SumSqU(psrc, src.Length); + } + } + + public static float SumSq(float mean, ReadOnlySpan src) + { + Contracts.AssertNonEmpty(src); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + return (mean == 0 ? Thunk.SumSqU(psrc, src.Length) : Thunk.SumSqDiffU(mean, psrc, src.Length)); + } + } + + public static float SumAbs(ReadOnlySpan src) + { + Contracts.AssertNonEmpty(src); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + return Thunk.SumAbsU(psrc, src.Length); + } + } + + public static float SumAbs(float mean, ReadOnlySpan src) + { + Contracts.AssertNonEmpty(src); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + return (mean == 0 ? Thunk.SumAbsU(psrc, src.Length) : Thunk.SumAbsDiffU(mean, psrc, src.Length)); + } + } + + public static float MaxAbs(ReadOnlySpan src) + { + Contracts.AssertNonEmpty(src); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + return Thunk.MaxAbsU(psrc, src.Length); + } + } + + public static float MaxAbsDiff(float mean, ReadOnlySpan src) + { + Contracts.AssertNonEmpty(src); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + return Thunk.MaxAbsDiffU(mean, psrc, src.Length); + } + } + + public static float DotProductDense(ReadOnlySpan a, ReadOnlySpan b, int count) + { + Contracts.AssertNonEmpty(a); + Contracts.AssertNonEmpty(b); + Contracts.Assert(0 < count); + Contracts.Assert(a.Length >= count); + Contracts.Assert(b.Length >= count); + + unsafe + { + fixed (float* pa = &MemoryMarshal.GetReference(a)) + fixed (float* pb = &MemoryMarshal.GetReference(b)) + return Thunk.DotU(pa, pb, count); + } + } + + public static float DotProductSparse(ReadOnlySpan a, ReadOnlySpan b, ReadOnlySpan indices, int count) + { + Contracts.AssertNonEmpty(a); + Contracts.AssertNonEmpty(b); + Contracts.Assert(0 < count); + Contracts.Assert(count < a.Length); + Contracts.Assert(count <= b.Length); + Contracts.Assert(count <= indices.Length); + + unsafe + { + fixed (float* pa = &MemoryMarshal.GetReference(a)) + fixed (float* pb = &MemoryMarshal.GetReference(b)) + fixed (int* pi = &MemoryMarshal.GetReference(indices)) + return Thunk.DotSU(pa, pb, pi, count); + } + } + + public static float L2DistSquared(ReadOnlySpan a, ReadOnlySpan b, int count) + { + Contracts.AssertNonEmpty(a); + Contracts.AssertNonEmpty(b); + Contracts.Assert(0 < count && count <= a.Length); + Contracts.Assert(count <= b.Length); + + unsafe + { + fixed (float* pa = &MemoryMarshal.GetReference(a)) + fixed (float* pb = &MemoryMarshal.GetReference(b)) + return Thunk.Dist2(pa, pb, count); + } + } + + public static void ZeroMatrixItems(AlignedArray dst, int ccol, int cfltRow, int[] indices) + { + Contracts.Assert(0 < ccol && ccol <= cfltRow); + + unsafe + { + fixed (float* pdst = &dst.Items[0]) + fixed (int* pi = &indices[0]) + { + if (ccol == cfltRow) + Thunk.ZeroItemsU(Ptr(dst, pdst), dst.Size, pi, indices.Length); + else + Thunk.ZeroMatrixItemsCore(Ptr(dst, pdst), dst.Size, ccol, cfltRow, pi, indices.Length); + } + } + } + + public static void SdcaL1UpdateDense(float primalUpdate, int count, ReadOnlySpan src, float threshold, Span v, Span w) + { + Contracts.AssertNonEmpty(src); + Contracts.Assert(count <= src.Length); + Contracts.AssertNonEmpty(v); + Contracts.Assert(count <= v.Length); + Contracts.AssertNonEmpty(w); + Contracts.Assert(count <= w.Length); + Contracts.Assert(count > 0); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(src)) + fixed (float* pd1 = 
&MemoryMarshal.GetReference(v)) + fixed (float* pd2 = &MemoryMarshal.GetReference(w)) + Thunk.SdcaL1UpdateU(primalUpdate, psrc, threshold, pd1, pd2, count); + } + } public static void SdcaL1UpdateSparse(float primalUpdate, int count, ReadOnlySpan source, ReadOnlySpan indices, float threshold, Span v, Span w) - => SseUtils.SdcaL1UpdateSparse(primalUpdate, count, source, indices, threshold, v, w); + { + Contracts.AssertNonEmpty(source); + Contracts.Assert(count <= source.Length); + Contracts.AssertNonEmpty(indices); + Contracts.Assert(count <= indices.Length); + Contracts.AssertNonEmpty(v); + Contracts.Assert(count <= v.Length); + Contracts.AssertNonEmpty(w); + Contracts.Assert(count <= w.Length); + Contracts.Assert(count > 0); + + unsafe + { + fixed (float* psrc = &MemoryMarshal.GetReference(source)) + fixed (int* pi = &MemoryMarshal.GetReference(indices)) + fixed (float* pd1 = &MemoryMarshal.GetReference(v)) + fixed (float* pd2 = &MemoryMarshal.GetReference(w)) + Thunk.SdcaL1UpdateSU(primalUpdate, psrc, pi, threshold, pd1, pd2, count); + } + } } -} +} \ No newline at end of file diff --git a/src/Microsoft.ML.CpuMath/Sse.cs b/src/Microsoft.ML.CpuMath/Sse.cs deleted file mode 100644 index 8b1c4da70f..0000000000 --- a/src/Microsoft.ML.CpuMath/Sse.cs +++ /dev/null @@ -1,427 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Runtime.Internal.CpuMath.Core; -using System; -using System.Runtime.InteropServices; - -namespace Microsoft.ML.Runtime.Internal.CpuMath -{ - /// - /// Keep Sse.cs in sync with Avx.cs. When making changes to one, use BeyondCompare or a similar tool - /// to view diffs and propagate appropriate changes to the other. 
- /// - [BestFriend] - internal static class SseUtils - { - public const int CbAlign = 16; - - private static bool Compat(AlignedArray a) - { - Contracts.AssertValue(a); - Contracts.Assert(a.Size > 0); - return a.CbAlign == CbAlign; - } - - private static unsafe float* Ptr(AlignedArray a, float* p) - { - Contracts.AssertValue(a); - float* q = p + a.GetBase((long)p); - Contracts.Assert(((long)q & (CbAlign - 1)) == 0); - return q; - } - - public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) - { - Contracts.Assert(Compat(mat)); - Contracts.Assert(Compat(src)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(mat.Size == dst.Size * src.Size); - - unsafe - { - fixed (float* pmat = &mat.Items[0]) - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - { - if (!tran) - { - Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMul(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); - } - else - { - Contracts.Assert(0 <= crun && crun <= src.Size); - Thunk.MatMulTran(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); - } - } - } - } - - public static void MatTimesSrc(AlignedArray mat, ReadOnlySpan rgposSrc, AlignedArray srcValues, - int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) - { - Contracts.Assert(Compat(mat)); - Contracts.Assert(Compat(srcValues)); - Contracts.Assert(Compat(dst)); - Contracts.Assert(0 <= iposMin && iposMin <= iposLim && iposLim <= rgposSrc.Length); - Contracts.Assert(mat.Size == dst.Size * srcValues.Size); - - if (iposMin >= iposLim) - { - dst.ZeroItems(); - return; - } - Contracts.AssertNonEmpty(rgposSrc); - unsafe - { - fixed (float* pdst = &dst.Items[0]) - fixed (float* pmat = &mat.Items[0]) - fixed (float* psrc = &srcValues.Items[0]) - fixed (int* ppossrc = &rgposSrc[0]) - { - Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulP(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); - } - } - } - - // dst += a - public static void Add(float a, Span dst) - { - Contracts.AssertNonEmpty(dst); - - unsafe - { - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - Thunk.AddScalarU(a, pdst, dst.Length); - } - } - - public static void Scale(float a, Span dst) - { - Contracts.AssertNonEmpty(dst); - - unsafe - { - fixed (float* pd = &MemoryMarshal.GetReference(dst)) - Thunk.Scale(a, pd, dst.Length); - } - } - - // dst = a * src - public static void Scale(float a, ReadOnlySpan src, Span dst, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - Contracts.AssertNonEmpty(dst); - Contracts.Assert(count <= dst.Length); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - { - Thunk.ScaleSrcU(a, psrc, pdst, count); - } - } - } - - // dst[i] = a * (dst[i] + b) - public static void ScaleAdd(float a, float b, Span dst) - { - Contracts.AssertNonEmpty(dst); - - unsafe - { - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - Thunk.ScaleAddU(a, b, pdst, dst.Length); - } - } - - public static void AddScale(float a, ReadOnlySpan src, Span dst, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - Contracts.AssertNonEmpty(dst); - Contracts.Assert(count <= dst.Length); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - Thunk.AddScaleU(a, psrc, 
pdst, count); - } - } - - public static void AddScale(float a, ReadOnlySpan src, ReadOnlySpan indices, Span dst, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(count <= indices.Length); - Contracts.AssertNonEmpty(dst); - Contracts.Assert(count < dst.Length); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (int* pi = &MemoryMarshal.GetReference(indices)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - Thunk.AddScaleSU(a, psrc, pi, pdst, count); - } - } - - public static void AddScaleCopy(float a, ReadOnlySpan src, ReadOnlySpan dst, Span res, int count) - { - Contracts.AssertNonEmpty(dst); - Contracts.Assert(0 < count && count <= dst.Length); - Contracts.AssertNonEmpty(src); - Contracts.Assert(count <= src.Length); - Contracts.AssertNonEmpty(res); - Contracts.Assert(count <= res.Length); - - unsafe - { - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pres = &MemoryMarshal.GetReference(res)) - Thunk.AddScaleCopyU(a, psrc, pdst, pres, count); - } - } - - public static void Add(ReadOnlySpan src, Span dst, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - Contracts.AssertNonEmpty(dst); - Contracts.Assert(count <= dst.Length); - - unsafe - { - fixed (float* ps = &MemoryMarshal.GetReference(src)) - fixed (float* pd = &MemoryMarshal.GetReference(dst)) - Thunk.AddU(ps, pd, count); - } - } - - public static void Add(ReadOnlySpan src, ReadOnlySpan indices, Span dst, int count) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(0 < count && count <= src.Length); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(count <= indices.Length); - Contracts.AssertNonEmpty(dst); - Contracts.Assert(count < dst.Length); - - unsafe - { - fixed (float* ps = &MemoryMarshal.GetReference(src)) - fixed (int* pi = &MemoryMarshal.GetReference(indices)) - fixed (float* pd = &MemoryMarshal.GetReference(dst)) - Thunk.AddSU(ps, pi, pd, count); - } - } - - public static void MulElementWise(ReadOnlySpan src1, ReadOnlySpan src2, Span dst, int count) - { - Contracts.AssertNonEmpty(src1); - Contracts.Assert(0 < count && count <= src1.Length); - Contracts.AssertNonEmpty(src2); - Contracts.Assert(0 < count && count <= src2.Length); - Contracts.AssertNonEmpty(dst); - unsafe - { - fixed (float* ps1 = &MemoryMarshal.GetReference(src1)) - fixed (float* ps2 = &MemoryMarshal.GetReference(src2)) - fixed (float* pd = &MemoryMarshal.GetReference(dst)) - Thunk.MulElementWiseU(ps1, ps2, pd, count); - } - } - - public static float Sum(ReadOnlySpan src) - { - Contracts.AssertNonEmpty(src); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - return Thunk.Sum(psrc, src.Length); - } - } - - public static float SumSq(ReadOnlySpan src) - { - Contracts.AssertNonEmpty(src); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - return Thunk.SumSqU(psrc, src.Length); - } - } - - public static float SumSq(float mean, ReadOnlySpan src) - { - Contracts.AssertNonEmpty(src); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - return (mean == 0 ? 
Thunk.SumSqU(psrc, src.Length) : Thunk.SumSqDiffU(mean, psrc, src.Length)); - } - } - - public static float SumAbs(ReadOnlySpan src) - { - Contracts.AssertNonEmpty(src); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - return Thunk.SumAbsU(psrc, src.Length); - } - } - - public static float SumAbs(float mean, ReadOnlySpan src) - { - Contracts.AssertNonEmpty(src); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - return (mean == 0 ? Thunk.SumAbsU(psrc, src.Length) : Thunk.SumAbsDiffU(mean, psrc, src.Length)); - } - } - - public static float MaxAbs(ReadOnlySpan src) - { - Contracts.AssertNonEmpty(src); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - return Thunk.MaxAbsU(psrc, src.Length); - } - } - - public static float MaxAbsDiff(float mean, ReadOnlySpan src) - { - Contracts.AssertNonEmpty(src); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - return Thunk.MaxAbsDiffU(mean, psrc, src.Length); - } - } - - public static float DotProductDense(ReadOnlySpan a, ReadOnlySpan b, int count) - { - Contracts.AssertNonEmpty(a); - Contracts.AssertNonEmpty(b); - Contracts.Assert(0 < count); - Contracts.Assert(a.Length >= count); - Contracts.Assert(b.Length >= count); - - unsafe - { - fixed (float* pa = &MemoryMarshal.GetReference(a)) - fixed (float* pb = &MemoryMarshal.GetReference(b)) - return Thunk.DotU(pa, pb, count); - } - } - - public static float DotProductSparse(ReadOnlySpan a, ReadOnlySpan b, ReadOnlySpan indices, int count) - { - Contracts.AssertNonEmpty(a); - Contracts.AssertNonEmpty(b); - Contracts.Assert(0 < count); - Contracts.Assert(count < a.Length); - Contracts.Assert(count <= b.Length); - Contracts.Assert(count <= indices.Length); - - unsafe - { - fixed (float* pa = &MemoryMarshal.GetReference(a)) - fixed (float* pb = &MemoryMarshal.GetReference(b)) - fixed (int* pi = &MemoryMarshal.GetReference(indices)) - return Thunk.DotSU(pa, pb, pi, count); - } - } - - public static float L2DistSquared(ReadOnlySpan a, ReadOnlySpan b, int count) - { - Contracts.AssertNonEmpty(a); - Contracts.AssertNonEmpty(b); - Contracts.Assert(0 < count && count <= a.Length); - Contracts.Assert(count <= b.Length); - - unsafe - { - fixed (float* pa = &MemoryMarshal.GetReference(a)) - fixed (float* pb = &MemoryMarshal.GetReference(b)) - return Thunk.Dist2(pa, pb, count); - } - } - - public static void ZeroMatrixItems(AlignedArray dst, int ccol, int cfltRow, int[] indices) - { - Contracts.Assert(0 < ccol && ccol <= cfltRow); - - unsafe - { - fixed (float* pdst = &dst.Items[0]) - fixed (int* pi = &indices[0]) - { - if (ccol == cfltRow) - Thunk.ZeroItemsU(Ptr(dst, pdst), dst.Size, pi, indices.Length); - else - Thunk.ZeroMatrixItemsCore(Ptr(dst, pdst), dst.Size, ccol, cfltRow, pi, indices.Length); - } - } - } - - public static void SdcaL1UpdateDense(float primalUpdate, int count, ReadOnlySpan src, float threshold, Span v, Span w) - { - Contracts.AssertNonEmpty(src); - Contracts.Assert(count <= src.Length); - Contracts.AssertNonEmpty(v); - Contracts.Assert(count <= v.Length); - Contracts.AssertNonEmpty(w); - Contracts.Assert(count <= w.Length); - Contracts.Assert(count > 0); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pd1 = &MemoryMarshal.GetReference(v)) - fixed (float* pd2 = &MemoryMarshal.GetReference(w)) - Thunk.SdcaL1UpdateU(primalUpdate, psrc, threshold, pd1, pd2, count); - } - } - - public static void SdcaL1UpdateSparse(float primalUpdate, int count, ReadOnlySpan source, 
ReadOnlySpan indices, float threshold, Span v, Span w) - { - Contracts.AssertNonEmpty(source); - Contracts.Assert(count <= source.Length); - Contracts.AssertNonEmpty(indices); - Contracts.Assert(count <= indices.Length); - Contracts.AssertNonEmpty(v); - Contracts.Assert(count <= v.Length); - Contracts.AssertNonEmpty(w); - Contracts.Assert(count <= w.Length); - Contracts.Assert(count > 0); - - unsafe - { - fixed (float* psrc = &MemoryMarshal.GetReference(source)) - fixed (int* pi = &MemoryMarshal.GetReference(indices)) - fixed (float* pd1 = &MemoryMarshal.GetReference(v)) - fixed (float* pd2 = &MemoryMarshal.GetReference(w)) - Thunk.SdcaL1UpdateSU(primalUpdate, psrc, pi, threshold, pd1, pd2, count); - } - } - } -} \ No newline at end of file diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index 89d6dbce03..44bf8abcaa 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -41,6 +41,26 @@ internal static class SseIntrinsics 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, }; + // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray + private const int Vector128Alignment = 16; + + [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] + private static bool HasCompatibleAlignment(AlignedArray alignedArray) + { + Contracts.AssertValue(alignedArray); + Contracts.Assert(alignedArray.Size > 0); + return (alignedArray.CbAlign % Vector128Alignment) == 0; + } + + [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] + private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase) + { + Contracts.AssertValue(alignedArray); + float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase); + Contracts.Assert(((long)alignedBase & (Vector128Alignment - 1)) == 0); + return alignedBase; + } + internal static readonly Vector128 AbsMask128 = Sse2.IsSupported ? Sse.StaticCast(Sse2.SetAllVector128(0x7FFFFFFF)) : Sse.SetAllVector128(BitConverter.Int32BitsToSingle(0x7FFFFFFF)); @@ -118,17 +138,18 @@ internal static Vector128 GetNewDst128(in Vector128 xDst1, in Vect // Multiply matrix times vector into vector. 
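// For reference only (not part of this patch): the vectorized MatMul below computes a
// row-major matrix-times-vector product. A scalar sketch of the same operation, with mat
// stored as crow x ccol:
static void MatMulScalar(float[] mat, float[] src, float[] dst, int crow, int ccol)
{
    // dst[i] is the dot product of matrix row i with src.
    for (int i = 0; i < crow; i++)
    {
        float sum = 0;
        for (int j = 0; j < ccol; j++)
            sum += mat[i * ccol + j] * src[j];
        dst[i] = sum;
    }
}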
public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - MatMul(mat.Items, src.Items, dst.Items, crow, ccol); - } + Contracts.Assert(HasCompatibleAlignment(mat)); + Contracts.Assert(HasCompatibleAlignment(src)); + Contracts.Assert(HasCompatibleAlignment(dst)); - public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crow, int ccol) - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - fixed (float* pmat = &MemoryMarshal.GetReference(mat)) - fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) - fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) + fixed (float* pSrcStart = &src.Items[0]) + fixed (float* pDstStart = &dst.Items[0]) + fixed (float* pMatStart = &mat.Items[0]) { + float* psrc = GetAlignedBase(src, pSrcStart); + float* pdst = GetAlignedBase(dst, pDstStart); + float* pmat = GetAlignedBase(mat, pMatStart); + float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pDstCurrent = pdst; @@ -137,128 +158,29 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr while (pDstCurrent < pDstEnd) { Vector128 res0 = Sse.SetZeroVector128(); - Vector128 res1 = Sse.SetZeroVector128(); - Vector128 res2 = Sse.SetZeroVector128(); - Vector128 res3 = Sse.SetZeroVector128(); + Vector128 res1 = res0; + Vector128 res2 = res0; + Vector128 res3 = res0; - int length = ccol; float* pSrcCurrent = psrc; - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 16); - int remainder = 0; - - if ((misalignment & 3) != 0) + while (pSrcCurrent < pSrcEnd) { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pSrcCurrent < pSrcEnd) - { - Vector128 vector = Sse.LoadVector128(pSrcCurrent); + float* pMatTemp = pMatCurrent; - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp)); - Vector128 x11 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x01 = Sse.LoadAlignedVector128(pMatTemp); + Vector128 x11 = Sse.LoadAlignedVector128(pMatTemp += ccol); + Vector128 x21 = Sse.LoadAlignedVector128(pMatTemp += ccol); + Vector128 x31 = Sse.LoadAlignedVector128(pMatTemp += ccol); + Vector128 x02 = Sse.LoadAlignedVector128(pSrcCurrent); - res0 = Sse.Add(res0, x01); - res1 = Sse.Add(res1, x11); - res2 = Sse.Add(res2, x21); - res3 = Sse.Add(res3, x31); + res0 = Sse.Add(res0, Sse.Multiply(x01, x02)); + res1 = Sse.Add(res1, Sse.Multiply(x11, x02)); + res2 = Sse.Add(res2, Sse.Multiply(x21, x02)); + res3 = Sse.Add(res3, Sse.Multiply(x31, x02)); - pSrcCurrent += 4; - pMatCurrent += 4; - } - } - else - { - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 4 - misalignment; - - Vector128 mask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); - - // We only align pMat since it has significantly more reads. 
- float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); - Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); - - res0 = Sse.Multiply(x01, vector); - res1 = Sse.Multiply(x11, vector); - res2 = Sse.Multiply(x21, vector); - res3 = Sse.Multiply(x31, vector); - - pMatCurrent += misalignment; - pSrcCurrent += misalignment; - length -= misalignment; - } - - if (length > 3) - { - // Handle all the 128-bit blocks that we can now that we have offset to an aligned address - remainder = length % 4; - - // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads - // (due to semantics of the legacy encoding). - // We don't need an assert, since the instruction will throw for unaligned inputs. - while (pSrcCurrent + 4 <= pSrcEnd) - { - Vector128 vector = Sse.LoadVector128(pSrcCurrent); - - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp)); - Vector128 x11 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.Multiply(vector, Sse.LoadAlignedVector128(pMatTemp += ccol)); - - res0 = Sse.Add(res0, x01); - res1 = Sse.Add(res1, x11); - res2 = Sse.Add(res2, x21); - res3 = Sse.Add(res3, x31); - - pSrcCurrent += 4; - pMatCurrent += 4; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not - // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. 
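// Illustrative scalar sketch (not part of this change) of the masking trick the removed fallback
// relied on: LeadingAlignmentMask[k] keeps the first k lanes of a 4-float block and
// TrailingAlignmentMask[k] keeps the last k, so the unaligned head and tail of a buffer are each
// read once with the unneeded lanes zeroed, and no element is accumulated twice around the
// aligned middle section. The method name is invented for the example.
internal static class AlignmentMaskSketch
{
    internal static float[] MaskBlock(float[] block4, int lanesToKeep, bool leading)
    {
        var masked = new float[4];
        for (int i = 0; i < 4; i++)
        {
            bool keep = leading ? i < lanesToKeep : i >= 4 - lanesToKeep;
            masked[i] = keep ? block4[i] : 0f; // zeroed lanes contribute nothing to the sums
        }
        return masked;
    }
}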
- remainder = length; - } - - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pMatCurrent -= (4 - remainder); - pSrcCurrent -= (4 - remainder); - - Vector128 mask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); - Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); - Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); - - res0 = Sse.Add(res0, Sse.Multiply(x01, vector)); - res1 = Sse.Add(res1, Sse.Multiply(x11, vector)); - res2 = Sse.Add(res2, Sse.Multiply(x21, vector)); - res3 = Sse.Add(res3, Sse.Multiply(x31, vector)); - - pMatCurrent += 4; - pSrcCurrent += 4; - } + pSrcCurrent += 4; + pMatCurrent += 4; } // Add up the entries of each, with the 4 results in res0 @@ -266,7 +188,8 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr res2 = Sse3.HorizontalAdd(res2, res3); res0 = Sse3.HorizontalAdd(res0, res2); - Sse.Store(pDstCurrent, res0); + Sse.StoreAligned(pDstCurrent, res0); + pDstCurrent += 4; pMatCurrent += 3 * ccol; } @@ -277,21 +200,23 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan rgposSrc, AlignedArray src, int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) { - MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol); - } + // REVIEW: For extremely sparse inputs, interchanging the loops would + // likely be more efficient. + Contracts.Assert(HasCompatibleAlignment(mat)); + Contracts.Assert(HasCompatibleAlignment(src)); + Contracts.Assert(HasCompatibleAlignment(dst)); - public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgposSrc, ReadOnlySpan src, - int posMin, int iposMin, int iposEnd, Span dst, int crow, int ccol) - { // REVIEW: For extremely sparse inputs, interchanging the loops would // likely be more efficient. 
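// Illustrative scalar sketch (not part of this change) of the Sse3.HorizontalAdd reduction used
// above: hadd(a, b) yields [a0+a1, a2+a3, b0+b1, b2+b3], so two rounds collapse the four
// partial-sum vectors into one vector holding the four finished dot products, in row order,
// ready for the single aligned store to pDstCurrent.
internal static class HorizontalAddSketch
{
    internal static float[] Hadd(float[] a, float[] b) =>
        new[] { a[0] + a[1], a[2] + a[3], b[0] + b[1], b[2] + b[3] };

    internal static float[] ReduceFourAccumulators(float[] res0, float[] res1, float[] res2, float[] res3)
    {
        var r01 = Hadd(res0, res1);   // [sum-halves of res0, sum-halves of res1]
        var r23 = Hadd(res2, res3);
        return Hadd(r01, r23);        // [sum(res0), sum(res1), sum(res2), sum(res3)]
    }
}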
- fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - fixed (float* pmat = &MemoryMarshal.GetReference(mat)) - fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc)) - fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) - fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) + fixed (float* pSrcStart = &src.Items[0]) + fixed (float* pDstStart = &dst.Items[0]) + fixed (float* pMatStart = &mat.Items[0]) + fixed (int* pposSrc = &rgposSrc[0]) { + float* psrc = GetAlignedBase(src, pSrcStart); + float* pdst = GetAlignedBase(dst, pDstStart); + float* pmat = GetAlignedBase(mat, pMatStart); + int* pposMin = pposSrc + iposMin; int* pposEnd = pposSrc + iposEnd; float* pDstEnd = pdst + crow; @@ -299,120 +224,7 @@ public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgp float* pSrcCurrent = psrc - posMin; float* pDstCurrent = pdst; - nuint address = (nuint)(pDstCurrent); - int misalignment = (int)(address % 16); - - int length = crow; - int remainder = 0; - - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pDstCurrent < pDstEnd) - { - Sse.Store(pDstCurrent, SparseMultiplicationAcrossRow()); - pDstCurrent += 4; - pm0 += 4 * ccol; - } - } - else - { - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - - misalignment >>= 2; - misalignment = 4 - misalignment; - - Vector128 mask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); - - float* pm1 = pm0 + ccol; - float* pm2 = pm1 + ccol; - float* pm3 = pm2 + ccol; - Vector128 result = Sse.SetZeroVector128(); - - int* ppos = pposMin; - - while (ppos < pposEnd) - { - int col = *ppos; - Vector128 x1 = Sse.SetVector128(pm3[col], pm2[col], pm1[col], pm0[col]); - - x1 = Sse.And(mask, x1); - Vector128 x2 = Sse.SetAllVector128(pSrcCurrent[col]); - x2 = Sse.Multiply(x2, x1); - result = Sse.Add(result, x2); - ppos++; - } - - Sse.Store(pDstCurrent, result); - pDstCurrent += misalignment; - pm0 += misalignment * ccol; - length -= misalignment; - } - - if (length > 3) - { - // Handle all the 128-bit blocks that we can now that we have offset to an aligned address - remainder = length % 4; - - // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads - // (due to semantics of the legacy encoding). - // We don't need an assert, since the instruction will throw for unaligned inputs. - while (pDstCurrent < pDstEnd) - { - Sse.Store(pDstCurrent, SparseMultiplicationAcrossRow()); - pDstCurrent += 4; - pm0 += 4 * ccol; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not - // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. 
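// Illustrative scalar sketch (not part of this change) of what MatMulP computes: mat is a dense
// row-major crow x ccol matrix, but src only carries values at the column positions listed in
// rgposSrc[iposMin..iposEnd), and posMin re-bases those positions into the arrays. Four rows are
// produced per outer step, mirroring the SetVector128(pm3[col], ...) gather in the intrinsics
// version. Plain arrays stand in for the AlignedArray arguments; crow is assumed a multiple of four.
internal static class MatMulPSketch
{
    internal static void MatMulPScalar(float[] mat, int[] rgposSrc, float[] src, int posMin,
        int iposMin, int iposEnd, float[] dst, int crow, int ccol)
    {
        for (int r = 0; r < crow; r += 4)
        {
            var acc = new float[4];
            for (int ipos = iposMin; ipos < iposEnd; ipos++)
            {
                int col = rgposSrc[ipos] - posMin;       // re-based active column
                float v = src[col];
                for (int k = 0; k < 4; k++)
                    acc[k] += mat[(r + k) * ccol + col] * v;
            }
            for (int k = 0; k < 4; k++)
                dst[r + k] = acc[k];
        }
    }
}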
- remainder = length; - } - - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - pDstCurrent -= (4 - remainder); - pm0 -= (4 - remainder) * ccol; - - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); - - float* pm1 = pm0 + ccol; - float* pm2 = pm1 + ccol; - float* pm3 = pm2 + ccol; - Vector128 result = Sse.SetZeroVector128(); - - int* ppos = pposMin; - - while (ppos < pposEnd) - { - int col = *ppos; - Vector128 x1 = Sse.SetVector128(pm3[col], pm2[col], pm1[col], pm0[col]); - x1 = Sse.And(x1, trailingMask); - - Vector128 x2 = Sse.SetAllVector128(pSrcCurrent[col]); - x2 = Sse.Multiply(x2, x1); - result = Sse.Add(result, x2); - ppos++; - } - - result = Sse.Add(result, Sse.And(leadingMask, Sse.LoadVector128(pDstCurrent))); - - Sse.Store(pDstCurrent, result); - pDstCurrent += 4; - pm0 += 4 * ccol; - } - } - - Vector128 SparseMultiplicationAcrossRow() + while (pDstCurrent < pDstEnd) { float* pm1 = pm0 + ccol; float* pm2 = pm1 + ccol; @@ -428,310 +240,107 @@ Vector128 SparseMultiplicationAcrossRow() Vector128 x2 = Sse.SetAllVector128(pSrcCurrent[col]); x2 = Sse.Multiply(x2, x1); result = Sse.Add(result, x2); + ppos++; } - return result; + Sse.StoreAligned(pDstCurrent, result); + pDstCurrent += 4; + pm0 += 4 * ccol; } } } public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol); - } + Contracts.Assert(HasCompatibleAlignment(mat)); + Contracts.Assert(HasCompatibleAlignment(src)); + Contracts.Assert(HasCompatibleAlignment(dst)); - public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crow, int ccol) - { - fixed (float* psrc = &MemoryMarshal.GetReference(src)) - fixed (float* pdst = &MemoryMarshal.GetReference(dst)) - fixed (float* pmat = &MemoryMarshal.GetReference(mat)) - fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) - fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) + fixed (float* pSrcStart = &src.Items[0]) + fixed (float* pDstStart = &dst.Items[0]) + fixed (float* pMatStart = &mat.Items[0]) { + float* psrc = GetAlignedBase(src, pSrcStart); + float* pdst = GetAlignedBase(dst, pDstStart); + float* pmat = GetAlignedBase(mat, pMatStart); + float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pSrcCurrent = psrc; float* pMatCurrent = pmat; - // The reason behind adding the if condtion instead of boolean flag - // is to avoid branching in codegen. - if (pSrcCurrent < pSrcEnd) - { - Vector128 x01 = Sse.LoadVector128(pSrcCurrent); - // Replicate each 32-bit slot of x01 (ABCD) into its own register. - Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B - Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C - Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D - x01 = Sse.Shuffle(x01, x01, 0x00); // A + Vector128 x01 = Sse.LoadAlignedVector128(pSrcCurrent); + // Replicate each 32-bit slot of x01 (ABCD) into its own register. 
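// Illustrative scalar sketch (not part of this change) of why the control bytes 0x00, 0x55, 0xAA
// and 0xFF broadcast lanes A..D: each two-bit field of the control selects the source lane for
// one destination lane, so repeating the same two-bit value four times copies a single lane
// everywhere. This models Sse.Shuffle for the same-source case used here.
internal static class ShuffleSketch
{
    internal static float[] Shuffle(float[] v, byte control)
    {
        var result = new float[4];
        for (int lane = 0; lane < 4; lane++)
            result[lane] = v[(control >> (2 * lane)) & 0b11]; // two control bits per lane
        return result;
    }

    // Shuffle(x, 0x55) == { x[1], x[1], x[1], x[1] } because 0x55 is 01_01_01_01 in binary (lane B);
    // 0xAA (10_10_10_10) picks lane C, 0xFF (11_11_11_11) lane D, and 0x00 lane A.
}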
+ Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B + Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C + Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D + x01 = Sse.Shuffle(x01, x01, 0x00); // A - int length = crow; - float* pDstCurrent = pdst; + pSrcCurrent += 4; - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 16); + float* pDstCurrent = pdst; - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - int remainder = 0; - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 4 - misalignment; - - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); - - // We only align pMat since it has significantly more reads. - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Or(x02, Sse.And(x3, trailingMask)); - - Sse.Store(pDstCurrent, x02); - pMatCurrent += misalignment; - pDstCurrent += misalignment; - length -= misalignment; - } - if (length > 3) - { - // Handle all the 128-bit blocks that we can now that we have offset to an aligned address - remainder = length % 4; - while (pDstCurrent + 4 <= pDstEnd) - { - // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads - // (due to semantics of the legacy encoding). - // We don't need an assert, since the instruction will throw for unaligned inputs. - float* pMatTemp = pMatCurrent; - - Vector128 x02 = Sse.Multiply(x01, Sse.LoadAlignedVector128(pMatTemp)); - Vector128 x12 = Sse.Multiply(x11, Sse.LoadAlignedVector128(pMatTemp += crow)); - Vector128 x22 = Sse.Multiply(x21, Sse.LoadAlignedVector128(pMatTemp += crow)); - Vector128 x32 = Sse.Multiply(x31, Sse.LoadAlignedVector128(pMatTemp += crow)); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not - // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. 
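// Illustrative scalar sketch (not part of this change) of what MatMulTran computes with the
// layout used above: the coefficients for one src element are contiguous, with a stride of crow
// between consecutive src elements, so dst[j] = sum over i of mat[i * crow + j] * src[i]. The
// intrinsics version consumes four src elements per pass; its first pass stores straight into dst
// and later passes accumulate, which is why it is written as two loops. This sketch zeroes dst up
// front instead, giving the same result.
internal static class MatMulTranSketch
{
    internal static void MatMulTranScalar(float[] mat, float[] src, float[] dst, int crow, int ccol)
    {
        for (int j = 0; j < crow; j++)
            dst[j] = 0f;

        for (int i = 0; i < ccol; i++)       // one broadcast src element per pass
        {
            float v = src[i];
            for (int j = 0; j < crow; j++)
                dst[j] += mat[i * crow + j] * v;
        }
    }
}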
- remainder = length; - } + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.LoadAlignedVector128(pMatTemp); + Vector128 x12 = Sse.LoadAlignedVector128(pMatTemp += crow); + Vector128 x22 = Sse.LoadAlignedVector128(pMatTemp += crow); + Vector128 x32 = Sse.LoadAlignedVector128(pMatTemp += crow); - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - pMatCurrent -= (4 - remainder); - pDstCurrent -= (4 - remainder); - - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Or(x02, Sse.And(x3, leadingMask)); - - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); - pMatCurrent += 3 * crow; - pSrcCurrent += 4; + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + Sse.StoreAligned(pDstCurrent, x02); + + pDstCurrent += 4; + pMatCurrent += 4; } - // We do 4-way unrolling + pMatCurrent += 3 * crow; + while (pSrcCurrent < pSrcEnd) { - Vector128 x01 = Sse.LoadVector128(pSrcCurrent); + x01 = Sse.LoadAlignedVector128(pSrcCurrent); // Replicate each 32-bit slot of x01 (ABCD) into its own register. 
- Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B - Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C - Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D + x11 = Sse.Shuffle(x01, x01, 0x55); // B + x21 = Sse.Shuffle(x01, x01, 0xAA); // C + x31 = Sse.Shuffle(x01, x01, 0xFF); // D x01 = Sse.Shuffle(x01, x01, 0x00); // A - int length = crow; - float* pDstCurrent = pdst; - - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 16); + pDstCurrent = pdst; - if ((misalignment & 3) != 0) + while (pDstCurrent < pDstEnd) { - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); + float* pMatTemp = pMatCurrent; - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); + Vector128 x02 = Sse.LoadAlignedVector128(pMatTemp); + Vector128 x12 = Sse.LoadAlignedVector128(pMatTemp += crow); + Vector128 x22 = Sse.LoadAlignedVector128(pMatTemp += crow); + Vector128 x32 = Sse.LoadAlignedVector128(pMatTemp += crow); + Vector128 x3 = Sse.LoadAlignedVector128(pDstCurrent); - x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - int remainder = 0; - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 4 - misalignment; - - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); - - // We only align pMat since it has significantly more reads. 
- float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Or(x02, Sse.And(x3, trailingMask)); - - x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); - - Sse.Store(pDstCurrent, x02); - pMatCurrent += misalignment; - pDstCurrent += misalignment; - length -= misalignment; - } - if (length > 3) - { - remainder = length % 4; - while (pDstCurrent + 4 <= pDstEnd) - { - float* pMatTemp = pMatCurrent; - - Vector128 x02 = Sse.Multiply(x01, Sse.LoadAlignedVector128(pMatTemp)); - Vector128 x12 = Sse.Multiply(x11, Sse.LoadAlignedVector128(pMatTemp += crow)); - Vector128 x22 = Sse.Multiply(x21, Sse.LoadAlignedVector128(pMatTemp += crow)); - Vector128 x32 = Sse.Multiply(x31, Sse.LoadAlignedVector128(pMatTemp += crow)); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - remainder = length; - } + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + x3 = Sse.Add(x02, x3); - if (remainder != 0) - { - pMatCurrent -= (4 - remainder); - pDstCurrent -= (4 - remainder); - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); - Vector128 x12 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x22 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - Vector128 x32 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); - - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Or(x02, Sse.And(x3, leadingMask)); - - x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); - Sse.Store(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } + Sse.StoreAligned(pDstCurrent, x3); + + pDstCurrent += 4; + pMatCurrent += 4; } pMatCurrent += 3 * crow; diff --git a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs index a473ccec29..b9534e8133 100644 --- a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs +++ b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs @@ -4,8 +4,6 @@ using System; using System.Collections.Generic; -using System.Globalization; -using System.Linq; using System.Numerics; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; @@ -290,8 +288,8 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, int trainSi 
_alpha = new Single[windowSize - 1]; _state = new Single[windowSize - 1]; - _x = new CpuAlignedVector(windowSize, SseUtils.CbAlign); - _xSmooth = new CpuAlignedVector(windowSize, SseUtils.CbAlign); + _x = new CpuAlignedVector(windowSize, CpuMathUtils.GetVectorAlignment()); + _xSmooth = new CpuAlignedVector(windowSize, CpuMathUtils.GetVectorAlignment()); ShouldComputeForecastIntervals = shouldComputeForecastIntervals; _observationNoiseVariance = 0; @@ -345,13 +343,13 @@ private AdaptiveSingularSpectrumSequenceModeler(AdaptiveSingularSpectrumSequence _state = new Single[_windowSize - 1]; Array.Copy(model._state, _state, _windowSize - 1); - _x = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); - _xSmooth = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); + _x = new CpuAlignedVector(_windowSize, CpuMathUtils.GetVectorAlignment()); + _xSmooth = new CpuAlignedVector(_windowSize, CpuMathUtils.GetVectorAlignment()); if (model._wTrans != null) { - _y = new CpuAlignedVector(_rank, SseUtils.CbAlign); - _wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, SseUtils.CbAlign); + _y = new CpuAlignedVector(_rank, CpuMathUtils.GetVectorAlignment()); + _wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, CpuMathUtils.GetVectorAlignment()); _wTrans.CopyFrom(model._wTrans); } } @@ -452,18 +450,18 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo { var tempArray = ctx.Reader.ReadFloatArray(); _host.CheckDecode(Utils.Size(tempArray) == _rank * _windowSize); - _wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, SseUtils.CbAlign); + _wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, CpuMathUtils.GetVectorAlignment()); int i = 0; _wTrans.CopyFrom(tempArray, ref i); tempArray = ctx.Reader.ReadFloatArray(); i = 0; - _y = new CpuAlignedVector(_rank, SseUtils.CbAlign); + _y = new CpuAlignedVector(_rank, CpuMathUtils.GetVectorAlignment()); _y.CopyFrom(tempArray, ref i); } _buffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(ctx.Reader, _host); - _x = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); - _xSmooth = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); + _x = new CpuAlignedVector(_windowSize, CpuMathUtils.GetVectorAlignment()); + _xSmooth = new CpuAlignedVector(_windowSize, CpuMathUtils.GetVectorAlignment()); } public override void Save(ModelSaveContext ctx) @@ -1130,8 +1128,8 @@ internal override void Consume(ref Single input, bool updateModel = false) if (_wTrans == null) { - _y = new CpuAlignedVector(_rank, SseUtils.CbAlign); - _wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, SseUtils.CbAlign); + _y = new CpuAlignedVector(_rank, CpuMathUtils.GetVectorAlignment()); + _wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, CpuMathUtils.GetVectorAlignment()); Single[] vecs = new Single[_rank * _windowSize]; for (i = 0; i < _rank; ++i) @@ -1311,8 +1309,8 @@ private void TrainCore(Single[] dataArray, int originalSeriesLength) _maxRank = _windowSize / 2; _alpha = new Single[_windowSize - 1]; _state = new Single[_windowSize - 1]; - _x = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); - _xSmooth = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); + _x = new CpuAlignedVector(_windowSize, CpuMathUtils.GetVectorAlignment()); + _xSmooth = new CpuAlignedVector(_windowSize, CpuMathUtils.GetVectorAlignment()); TrainCore(dataArray, originalSeriesLength); return; @@ -1349,10 +1347,10 @@ private void TrainCore(Single[] dataArray, int originalSeriesLength) } // Setting the the y vector - _y = new CpuAlignedVector(_rank, SseUtils.CbAlign); + _y = new 
CpuAlignedVector(_rank, CpuMathUtils.GetVectorAlignment()); // Setting the weight matrix - _wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, SseUtils.CbAlign); + _wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, CpuMathUtils.GetVectorAlignment()); i = 0; _wTrans.CopyFrom(leftSingularVecs, ref i); diff --git a/src/Native/CpuMathNative/Sse.cpp b/src/Native/CpuMathNative/Sse.cpp index 0a2715d4d8..48fa6fb8a4 100644 --- a/src/Native/CpuMathNative/Sse.cpp +++ b/src/Native/CpuMathNative/Sse.cpp @@ -60,133 +60,31 @@ const unsigned int TrailingAlignmentMask[16] = // Multiply matrix times vector into vector. EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) { - const float * pSrcEnd = psrc + ccol; - const float * pDstEnd = pdst + crow; - float* pDstCurrent = pdst; - const float* pMatCurrent = pmat; - - while (pDstCurrent < pDstEnd) + const float * psLim = psrc + ccol; + const float * pdLim = pdst + crow; + const float * pm = pmat; + for (float * pd = pdst; pd < pdLim; pd += 4, pm += 3 * ccol) { __m128 res0 = _mm_setzero_ps(); __m128 res1 = res0; __m128 res2 = res0; __m128 res3 = res0; - - int length = ccol; - const float* pSrcCurrent = psrc; - - uintptr_t address = (uintptr_t)(pMatCurrent); - uintptr_t misalignment = address % 16; - int remainder = 0; - - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pSrcCurrent < pSrcEnd) - { - __m128 vector = _mm_loadu_ps(pSrcCurrent); - - const float* pMatTemp = pMatCurrent; - __m128 x01 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp)); - __m128 x11 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x21 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x31 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); - - res0 = _mm_add_ps(res0, x01); - res1 = _mm_add_ps(res1, x11); - res2 = _mm_add_ps(res2, x21); - res3 = _mm_add_ps(res3, x31); - - pSrcCurrent += 4; - pMatCurrent += 4; - } - } - else + for (const float * ps = psrc; ps < psLim; ps += 4, pm += 4) { - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 4 - misalignment; - - __m128 mask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); - - // We only align pMat since it has significantly more reads. 
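// Illustrative sketch (not part of this change) of the pointer bookkeeping shared by the managed
// and native loops above: the inner loop advances the matrix cursor by one full row (ccol floats,
// four at a time), and the trailing "+= 3 * ccol" skips the three rows that were already read
// through the pmTmp/pMatTemp side pointer, landing on the next block of four rows.
internal static class RowWalkSketch
{
    internal static int WalkRows(int crow, int ccol)
    {
        int pm = 0;                              // index standing in for the pm / pMatCurrent pointer
        for (int row = 0; row < crow; row += 4)
        {
            for (int c = 0; c < ccol; c += 4)
                pm += 4;                         // consumed while accumulating the four-row block
            pm += 3 * ccol;                      // skip rows r+1..r+3, which pmTmp already covered
        }
        return pm;                               // equals crow * ccol: every row visited exactly once
    }
}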
- const float* pMatTemp = pMatCurrent; - __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); - __m128 x11 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x21 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x31 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); - - res0 = _mm_mul_ps(x01, vector); - res1 = _mm_mul_ps(x11, vector); - res2 = _mm_mul_ps(x21, vector); - res3 = _mm_mul_ps(x31, vector); - - pMatCurrent += misalignment; - pSrcCurrent += misalignment; - length -= misalignment; - } - - if (length > 3) - { - // Handle all the 128-bit blocks that we can now that we have offset to an aligned address - remainder = length % 4; - - while (pSrcCurrent + 4 <= pSrcEnd) - { - __m128 vector = _mm_loadu_ps(pSrcCurrent); - - const float* pMatTemp = pMatCurrent; - __m128 x01 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp)); - __m128 x11 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); - __m128 x21 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); - __m128 x31 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); - - res0 = _mm_add_ps(res0, x01); - res1 = _mm_add_ps(res1, x11); - res2 = _mm_add_ps(res2, x21); - res3 = _mm_add_ps(res3, x31); - - pSrcCurrent += 4; - pMatCurrent += 4; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not - // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. - remainder = length; - } - - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pMatCurrent -= (4 - remainder); - pSrcCurrent -= (4 - remainder); - - __m128 mask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); - - const float* pMatTemp = pMatCurrent; - __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); - __m128 x11 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x21 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 x31 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); - __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); - - res0 = _mm_add_ps(res0, _mm_mul_ps(x01, vector)); - res1 = _mm_add_ps(res1, _mm_mul_ps(x11, vector)); - res2 = _mm_add_ps(res2, _mm_mul_ps(x21, vector)); - res3 = _mm_add_ps(res3, _mm_mul_ps(x31, vector)); - - pMatCurrent += 4; - pSrcCurrent += 4; - } + const float * pmTmp; + __m128 x01 = _mm_load_ps(pmTmp = pm); + __m128 x11 = _mm_load_ps(pmTmp += ccol); + __m128 x21 = _mm_load_ps(pmTmp += ccol); + __m128 x31 = _mm_load_ps(pmTmp += ccol); + __m128 x02 = _mm_load_ps(ps); + x01 = _mm_mul_ps(x01, x02); + x11 = _mm_mul_ps(x11, x02); + x21 = _mm_mul_ps(x21, x02); + x31 = _mm_mul_ps(x31, x02); + res0 = _mm_add_ps(res0, x01); + res1 = _mm_add_ps(res1, x11); + res2 = _mm_add_ps(res2, x21); + res3 = _mm_add_ps(res3, x31); } // Add up the entries of each, with the 4 results in res0 @@ -194,10 +92,7 @@ EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout res2 = _mm_hadd_ps(res2, res3); res0 = _mm_hadd_ps(res0, res2); - _mm_storeu_ps(pDstCurrent, res0); - - pDstCurrent += 4; - pMatCurrent += 3 * ccol; + _mm_store_ps(pd, res0); } } @@ -208,443 +103,90 @@ EXPORT_API(void) MatMulP(_In_ const float * pmat, _In_ const int * pposSrc, _In_ // REVIEW: For extremely sparse inputs, interchanging the 
loops would // likely be more efficient. const int * pposMin = pposSrc + iposMin; - const int * pposEnd = pposSrc + iposLim; - const float * pDstEnd = pdst + crow; + const int * pposLim = pposSrc + iposLim; + const float * pdLim = pdst + crow; const float * pm0 = pmat - posMin; - const float * pSrcCurrent = psrc - posMin; - float* pDstCurrent = pdst; - - uintptr_t address = (uintptr_t)(pDstCurrent); - uintptr_t misalignment = address % 16; - int length = crow; - int remainder = 0; - - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pDstCurrent < pDstEnd) - { - const float* pm1 = pm0 + ccol; - const float* pm2 = pm1 + ccol; - const float* pm3 = pm2 + ccol; - - __m128 res = _mm_setzero_ps(); - const int* ppos = pposMin; - - while (ppos < pposEnd) - { - int col = *ppos; - __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]); - __m128 x2 = _mm_set1_ps(pSrcCurrent[col]); - x2 = _mm_mul_ps(x2, x1); - res = _mm_add_ps(res, x2); - ppos++; - } - - _mm_storeu_ps(pDstCurrent, res); - pDstCurrent += 4; - pm0 += 4 * ccol; - } - } - else - { - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 4 - misalignment; - - __m128 mask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); - - const float* pm1 = pm0 + ccol; - const float* pm2 = pm1 + ccol; - const float* pm3 = pm2 + ccol; - - __m128 res = _mm_setzero_ps(); - const int* ppos = pposMin; - - while (ppos < pposEnd) - { - int col = *ppos; - __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]); - x1 = _mm_and_ps(mask, x1); - - __m128 x2 = _mm_set1_ps(pSrcCurrent[col]); - x2 = _mm_mul_ps(x2, x1); - res = _mm_add_ps(res, x2); - ppos++; - } - - _mm_storeu_ps(pDstCurrent, res); - pDstCurrent += misalignment; - pm0 += misalignment * ccol; - length -= misalignment; - } - - if (length > 3) + const float * ps = psrc - posMin; + for (float * pd = pdst; pd < pdLim; pd += 4, pm0 += 4 * ccol) + { + const float * pm1 = pm0 + ccol; + const float * pm2 = pm1 + ccol; + const float * pm3 = pm2 + ccol; + __m128 res = _mm_setzero_ps(); + for (const int * ppos = pposMin; ppos < pposLim; ppos++) { - // Handle all the 128-bit blocks that we can now that we have offset to an aligned address - remainder = length % 4; - while (pDstCurrent < pDstEnd) - { - const float* pm1 = pm0 + ccol; - const float* pm2 = pm1 + ccol; - const float* pm3 = pm2 + ccol; - - const int* ppos = pposMin; - __m128 res = _mm_setzero_ps(); - - while (ppos < pposEnd) - { - int col = *ppos; - __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]); - __m128 x2 = _mm_set1_ps(pSrcCurrent[col]); - x2 = _mm_mul_ps(x2, x1); - res = _mm_add_ps(res, x2); - ppos++; - } - - _mm_store_ps(pDstCurrent, res); - pDstCurrent += 4; - pm0 += 4 * ccol; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not - // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. 
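// Illustrative sketch (not part of this change) of the loop interchange the REVIEW comment
// alludes to: instead of re-scanning the active-position list for every block of four output
// rows, walk the positions once in the outer loop and scatter each active column into all output
// rows. This touches mat column-wise, so it is only attractive when the position list is very
// short relative to crow. Names and layout follow the scalar MatMulP sketch earlier.
internal static class MatMulPInterchangeSketch
{
    internal static void MatMulPInterchanged(float[] mat, int[] rgposSrc, float[] src, int posMin,
        int iposMin, int iposEnd, float[] dst, int crow, int ccol)
    {
        System.Array.Clear(dst, 0, crow);
        for (int ipos = iposMin; ipos < iposEnd; ipos++)  // outer loop over active columns
        {
            int col = rgposSrc[ipos] - posMin;
            float v = src[col];
            for (int r = 0; r < crow; r++)                // inner loop over output rows
                dst[r] += mat[r * ccol + col] * v;
        }
    }
}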
- remainder = length; + int col = *ppos; + __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]); + __m128 x2 = _mm_set1_ps(ps[col]); + x2 = _mm_mul_ps(x2, x1); + res = _mm_add_ps(res, x2); } - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pDstCurrent -= (4 - remainder); - pm0 -= (4 - remainder) * ccol; - - __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); - __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4)); - - const float* pm1 = pm0 + ccol; - const float* pm2 = pm1 + ccol; - const float* pm3 = pm2 + ccol; - - const int* ppos = pposMin; - __m128 res = _mm_setzero_ps(); - - while (ppos < pposEnd) - { - int col = *ppos; - __m128 x1 = _mm_setr_ps(pm0[col], pm1[col], pm2[col], pm3[col]); - x1 = _mm_and_ps(x1, trailingMask); - - __m128 x2 = _mm_set1_ps(pSrcCurrent[col]); - x2 = _mm_mul_ps(x2, x1); - res = _mm_add_ps(res, x2); - ppos++; - } - - res = _mm_add_ps(res, _mm_and_ps(leadingMask, _mm_loadu_ps(pDstCurrent))); - _mm_storeu_ps(pDstCurrent, res); - pDstCurrent += 4; - pm0 += 4 * ccol; - } + _mm_store_ps(pd, res); } } EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) { - const float * pSrcEnd = psrc + ccol; - const float * pDstEnd = pdst + crow; - - const float* pMatCurrent = pmat; - const float* pSrcCurrent = psrc; - - if (pSrcCurrent < pSrcEnd) - { - __m128 x01 = _mm_loadu_ps(pSrcCurrent); - // Replicate each slot of x01 into its own register. - __m128 x11 = _mm_shuffle_ps(x01, x01, 0x55); - __m128 x21 = _mm_shuffle_ps(x01, x01, 0xAA); - __m128 x31 = _mm_shuffle_ps(x01, x01, 0xFF); - x01 = _mm_shuffle_ps(x01, x01, 0x00); - - int length = crow; - float* pDstCurrent = pdst; - - uintptr_t address = (uintptr_t)(pMatCurrent); - uintptr_t misalignment = address % 16; - int remainder = 0; - - if ((misalignment & 3) != 0) - { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pDstCurrent < pDstEnd) - { - const float* pMatTemp = pMatCurrent; - __m128 x02 = _mm_mul_ps(x01, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_mul_ps(x11, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_mul_ps(x21, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_mul_ps(x31, _mm_loadu_ps(pMatTemp += crow)); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - - _mm_storeu_ps(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - int remainder = 0; - if (misalignment != 0) - { - // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 4 - misalignment; - - __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); - - // We only align pMat since it has significantly more reads. 
- const float* pMatTemp = pMatCurrent; - __m128 x02 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); - - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - - __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + ((4 - misalignment) * 4)); - __m128 x3 = _mm_loadu_ps(pDstCurrent); - x02 = _mm_or_ps(x02, _mm_and_ps(x3, trailingMask)); - - _mm_storeu_ps(pDstCurrent, x02); - pMatCurrent += misalignment; - pDstCurrent += misalignment; - length -= misalignment; - } - - if (length > 3) - { - // Handle all the 128-bit blocks that we can now that we have offset to an aligned address - remainder = length % 4; - - while (pDstCurrent + 4 <= pDstEnd) - { - const float* pMatTemp = pMatCurrent; - __m128 x02 = _mm_mul_ps(x01, _mm_load_ps(pMatTemp)); - __m128 x12 = _mm_mul_ps(x11, _mm_load_ps(pMatTemp += crow)); - __m128 x22 = _mm_mul_ps(x21, _mm_load_ps(pMatTemp += crow)); - __m128 x32 = _mm_mul_ps(x31, _mm_load_ps(pMatTemp += crow)); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - - _mm_storeu_ps(pDstCurrent, x02); - - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not - // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two - // unaligned loads where we mask the input each time. - remainder = length; - } - - if (remainder != 0) - { - // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next - // unaligned load will read to the end of the array and then mask out any elements already processed - - pMatCurrent -= (4 - remainder); - pDstCurrent -= (4 - remainder); - - __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); - - const float* pMatTemp = pMatCurrent; - __m128 x02 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - - __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4)); - __m128 x3 = _mm_loadu_ps(pDstCurrent); - x02 = _mm_or_ps(x02, _mm_and_ps(x3, leadingMask)); - - _mm_storeu_ps(pDstCurrent, x02); - pMatCurrent += 4; - pDstCurrent += 4; - } - } - - pMatCurrent += 3 * crow; - pSrcCurrent += 4; - } - - while (pSrcCurrent < pSrcEnd) - { - __m128 x01 = _mm_loadu_ps(pSrcCurrent); + const float * psLim = psrc + ccol; + const float * pdLim = pdst + crow; + const float * pm = pmat; + const float * ps = psrc; + + __m128 x01 = _mm_load_ps(ps); + // Replicate each slot of x01 into its own register. 
+ __m128 x11 = _mm_shuffle_ps(x01, x01, 0x55); + __m128 x21 = _mm_shuffle_ps(x01, x01, 0xAA); + __m128 x31 = _mm_shuffle_ps(x01, x01, 0xFF); + x01 = _mm_shuffle_ps(x01, x01, 0x00); + ps += 4; + for (float * pd = pdst; pd < pdLim; pd += 4, pm += 4) + { + const float * pmTmp; + __m128 x02 = _mm_load_ps(pmTmp = pm); + __m128 x12 = _mm_load_ps(pmTmp += crow); + __m128 x22 = _mm_load_ps(pmTmp += crow); + __m128 x32 = _mm_load_ps(pmTmp += crow); + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + _mm_store_ps(pd, x02); + } + + pm += 3 * crow; + + for (; ps < psLim; ps += 4) + { + __m128 x01 = _mm_load_ps(ps); // Replicate each slot of x01 into its own register. __m128 x11 = _mm_shuffle_ps(x01, x01, 0x55); __m128 x21 = _mm_shuffle_ps(x01, x01, 0xAA); __m128 x31 = _mm_shuffle_ps(x01, x01, 0xFF); x01 = _mm_shuffle_ps(x01, x01, 0x00); - - int length = crow; - float* pDstCurrent = pdst; - - uintptr_t address = (uintptr_t)(pMatCurrent); - uintptr_t misalignment = address % 16; - int remainder = 0; - - if ((misalignment & 3) != 0) + for (float * pd = pdst; pd < pdLim; pd += 4, pm += 4) { - while (pDstCurrent < pDstEnd) - { - const float* pMatTemp = pMatCurrent; - __m128 x02 = _mm_mul_ps(x01, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_mul_ps(x11, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_mul_ps(x21, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_mul_ps(x31, _mm_loadu_ps(pMatTemp += crow)); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - - x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); - - _mm_storeu_ps(pDstCurrent, x02); - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - int remainder = 0; - if (misalignment != 0) - { - misalignment >>= 2; - misalignment = 4 - misalignment; - - __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); - - // We only align pMat since it has significantly more reads. 
- const float* pMatTemp = pMatCurrent; - __m128 x02 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); - - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - - __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + ((4 - misalignment) * 4)); - __m128 x3 = _mm_loadu_ps(pDstCurrent); - x02 = _mm_or_ps(x02, _mm_and_ps(x3, trailingMask)); - x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); - - _mm_storeu_ps(pDstCurrent, x02); - pMatCurrent += misalignment; - pDstCurrent += misalignment; - length -= misalignment; - } - - if (length > 3) - { - remainder = length % 4; - while (pDstCurrent + 4 <= pDstEnd) - { - const float* pMatTemp = pMatCurrent; - __m128 x02 = _mm_mul_ps(x01, _mm_load_ps(pMatTemp)); - __m128 x12 = _mm_mul_ps(x11, _mm_load_ps(pMatTemp += crow)); - __m128 x22 = _mm_mul_ps(x21, _mm_load_ps(pMatTemp += crow)); - __m128 x32 = _mm_mul_ps(x31, _mm_load_ps(pMatTemp += crow)); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - - x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); - - _mm_storeu_ps(pDstCurrent, x02); - - pDstCurrent += 4; - pMatCurrent += 4; - } - } - else - { - remainder = length; - } - - if (remainder != 0) - { - pMatCurrent -= (4 - remainder); - pDstCurrent -= (4 - remainder); - - __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); - - const float* pMatTemp = pMatCurrent; - __m128 x02 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp)); - __m128 x12 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x22 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - __m128 x32 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); - - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); - - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - - __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + ((4 - remainder) * 4)); - __m128 x3 = _mm_loadu_ps(pDstCurrent); - x02 = _mm_or_ps(x02, _mm_and_ps(x3, leadingMask)); - - x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); - _mm_storeu_ps(pDstCurrent, x02); - pMatCurrent += 4; - pDstCurrent += 4; - } + const float * pmTmp; + __m128 x02 = _mm_load_ps(pmTmp = pm); + __m128 x12 = _mm_load_ps(pmTmp += crow); + __m128 x22 = _mm_load_ps(pmTmp += crow); + __m128 x32 = _mm_load_ps(pmTmp += crow); + __m128 x3 = _mm_load_ps(pd); + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + x3 = _mm_add_ps(x02, x3); + _mm_store_ps(pd, x3); } - pMatCurrent += 3 * crow; - pSrcCurrent += 4; + pm += 3 * crow; } } diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs index 43efecd89c..50a3b06fbe 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs @@ -11,6 +11,8 @@ namespace 
Microsoft.ML.CpuMath.PerformanceTests { public class AvxPerformanceTests : PerformanceTests { + protected override int align { get; set; } = 32; + [Benchmark] public void AddScalarU() => AvxIntrinsics.AddScalarU(DefaultScale, new Span(dst, 0, Length)); @@ -112,15 +114,15 @@ public void SdcaL1UpdateSU() [Benchmark] [BenchmarkCategory("Fma")] public void MatMul() - => AvxIntrinsics.MatMul(src, src1, dst, 1000, 1000); - + => AvxIntrinsics.MatMul(testMatrixAligned, testSrcVectorAligned, testDstVectorAligned, matrixLength, matrixLength); + [Benchmark] public void MatMulTran() - => AvxIntrinsics.MatMulTran(src, src1, dst, 1000, 1000); + => AvxIntrinsics.MatMulTran(testMatrixAligned, testSrcVectorAligned, testDstVectorAligned, matrixLength, matrixLength); [Benchmark] [BenchmarkCategory("Fma")] public void MatMulP() - => AvxIntrinsics.MatMulP(src, matrixIdx, src1, 0, 0, MatrixIndexLength, dst, 1000, 1000); + => AvxIntrinsics.MatMulP(testMatrixAligned, matrixIdx, testSrcVectorAligned, 0, 0, MatrixIndexLength, testDstVectorAligned, matrixLength, matrixLength); } } diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/CpuMathNativeUtils.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/CpuMathNativeUtils.cs deleted file mode 100644 index 8624ba90e9..0000000000 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/CpuMathNativeUtils.cs +++ /dev/null @@ -1,91 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -// The exported function names need to be unique (can't be disambiguated based on signature), hence -// we introduce suffix letters to indicate the general patterns used. -// * A suffix means aligned and padded for SSE operations. -// * U suffix means unaligned and unpadded. -// * S suffix means sparse (unaligned) vector. -// * P suffix means sparse (unaligned) partial vector - the vector is only part of a larger sparse vector. -// * R suffix means sparse matrix. -// * C suffix means convolution matrix. -// * D suffix means convolution matrix, with implicit source padding. -// * Tran means the matrix is transposed. 
- -using System.Runtime.InteropServices; -using System.Security; - -namespace Microsoft.ML.CpuMath.PerformanceTests -{ - internal static class CpuMathNativeUtils - { - internal const string NativePath = "CpuMathNative"; - - [DllImport(NativePath, EntryPoint = "AddScalarU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe float AddScalarU(float a, /*_Inout_*/ float* pd, int c); - - [DllImport(NativePath, EntryPoint = "Scale"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe void Scale(float a, /*_Inout_*/ float* pd, int c); - - [DllImport(NativePath, EntryPoint = "ScaleSrcU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe void ScaleSrcU(float a, /*_In_ const*/ float* ps, /*_Inout_*/ float* pd, int c); - - [DllImport(NativePath, EntryPoint = "ScaleAddU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe void ScaleAddU(float a, float b, /*_Inout_*/ float* pd, int c); - - [DllImport(NativePath, EntryPoint = "AddScaleU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe void AddScaleU(float a, /*_In_ const*/ float* ps, /*_Inout_*/ float* pd, int c); - - [DllImport(NativePath, EntryPoint = "AddScaleSU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe void AddScaleSU(float a, /*_In_ const*/ float* ps, /*_In_ const*/ int* pi, /*_Inout_*/ float* pd, int c); - - [DllImport(NativePath, EntryPoint = "AddScaleCopyU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe void AddScaleCopyU(float a, /*_In_ const*/ float* ps, /*_In_ const*/ float* pd, /*_Inout_*/ float* pr, int c); - - [DllImport(NativePath, EntryPoint = "AddU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe void AddU(/*_In_ const*/ float* ps, /*_Inout_*/ float* pd, int c); - - [DllImport(NativePath, EntryPoint = "AddSU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe void AddSU(/*_In_ const*/ float* ps, /*_In_ const*/ int* pi, /*_Inout_*/ float* pd, int c); - - [DllImport(NativePath, EntryPoint = "MulElementWiseU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe void MulElementWiseU(/*_In_ const*/ float* ps1, /*_In_ const*/ float* ps2, /*_Inout_*/ float* pd, int c); - - [DllImport(NativePath, EntryPoint = "Sum"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe float Sum(/*const*/ float* pValues, int length); - - [DllImport(NativePath, EntryPoint = "SumSqU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe float SumSqU(/*const*/ float* ps, int c); - - [DllImport(NativePath, EntryPoint = "SumSqDiffU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe float SumSqDiffU(float mean, /*const*/ float* ps, int c); - - [DllImport(NativePath, EntryPoint = "SumAbsU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe float SumAbsU(/*const*/ float* ps, int c); - - [DllImport(NativePath, EntryPoint = "SumAbsDiffU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe float SumAbsDiffU(float mean, /*const*/ float* ps, int c); - - [DllImport(NativePath, EntryPoint = "MaxAbsU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe float MaxAbsU(/*const*/ float* ps, int c); - - [DllImport(NativePath, EntryPoint = "MaxAbsDiffU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe float MaxAbsDiffU(float mean, /*const*/ float* ps, int c); - - [DllImport(NativePath, EntryPoint = "DotU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe float DotU(/*const*/ float* pa, /*const*/ float* pb, int c); - - 
[DllImport(NativePath, EntryPoint = "DotSU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe float DotSU(/*const*/ float* pa, /*const*/ float* pb, /*const*/ int* pi, int c); - - [DllImport(NativePath, EntryPoint = "Dist2"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe float Dist2(/*const*/ float* px, /*const*/ float* py, int c); - - [DllImport(NativePath, EntryPoint = "SdcaL1UpdateU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe void SdcaL1UpdateU(float primalUpdate, /*_In_ const*/ float* ps, float threshold, /*_Inout_*/ float* pd1, /*_Inout_*/ float* pd2, int c); - - [DllImport(NativePath, EntryPoint = "SdcaL1UpdateSU"), SuppressUnmanagedCodeSecurity] - internal static extern unsafe void SdcaL1UpdateSU(float primalUpdate, /*_In_ const*/ float* ps, /*_In_ const*/ int* pi, float threshold, /*_Inout_*/ float* pd1, /*_Inout_*/ float* pd2, int c); - } -} diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs index fe1a3a8386..92c0cc86db 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs @@ -6,58 +6,69 @@ using BenchmarkDotNet.Attributes; using BenchmarkDotNet.Running; using Microsoft.ML.Runtime.Internal.CpuMath; +using Microsoft.ML.Runtime.Internal.CpuMath.Core; namespace Microsoft.ML.CpuMath.PerformanceTests { public class NativePerformanceTests : PerformanceTests { + private const int CbAlign = 16; + + private static unsafe float* Ptr(AlignedArray a, float* p) + { + Contracts.AssertValue(a); + float* q = p + a.GetBase((long)p); + Contracts.Assert(((long)q & (CbAlign - 1)) == 0); + return q; + } + [Benchmark] public unsafe void AddScalarU() { fixed (float* pdst = dst) { - CpuMathNativeUtils.AddScalarU(DefaultScale, pdst, Length); + Thunk.AddScalarU(DefaultScale, pdst, Length); } } - + [Benchmark] public unsafe void Scale() { fixed (float* pdst = dst) { - CpuMathNativeUtils.Scale(DefaultScale, pdst, Length); + Thunk.Scale(DefaultScale, pdst, Length); } } - + [Benchmark] public unsafe void ScaleSrcU() { fixed (float* psrc = src) fixed (float* pdst = dst) { - CpuMathNativeUtils.ScaleSrcU(DefaultScale, psrc, pdst, Length); + Thunk.ScaleSrcU(DefaultScale, psrc, pdst, Length); } } - + [Benchmark] public unsafe void ScaleAddU() { fixed (float* pdst = dst) { - CpuMathNativeUtils.ScaleAddU(DefaultScale, DefaultScale, pdst, Length); + Thunk.ScaleAddU(DefaultScale, DefaultScale, pdst, Length); } } - + [Benchmark] public unsafe void AddScaleU() { fixed (float* psrc = src) fixed (float* pdst = dst) { - CpuMathNativeUtils.AddScaleU(DefaultScale, psrc, pdst, Length); + Thunk.AddScaleU(DefaultScale, psrc, pdst, Length); } } - + [Benchmark] public unsafe void AddScaleSU() { @@ -65,10 +76,10 @@ public unsafe void AddScaleSU() fixed (float* pdst = dst) fixed (int* pidx = idx) { - CpuMathNativeUtils.AddScaleSU(DefaultScale, psrc, pidx, pdst, IndexLength); + Thunk.AddScaleSU(DefaultScale, psrc, pidx, pdst, IndexLength); } } - + [Benchmark] public unsafe void AddScaleCopyU() { @@ -76,20 +87,20 @@ public unsafe void AddScaleCopyU() fixed (float* pdst = dst) fixed (float* pres = result) { - CpuMathNativeUtils.AddScaleCopyU(DefaultScale, psrc, pdst, pres, Length); + Thunk.AddScaleCopyU(DefaultScale, psrc, pdst, pres, Length); } } - + [Benchmark] public unsafe void AddU() { fixed (float* psrc = src) fixed (float* pdst = dst) { - CpuMathNativeUtils.AddU(psrc, pdst, 
Length); + Thunk.AddU(psrc, pdst, Length); } } - + [Benchmark] public unsafe void AddSU() { @@ -97,10 +108,10 @@ public unsafe void AddSU() fixed (float* pdst = dst) fixed (int* pidx = idx) { - CpuMathNativeUtils.AddSU(psrc, pidx, pdst, IndexLength); + Thunk.AddSU(psrc, pidx, pdst, IndexLength); } } - + [Benchmark] public unsafe void MulElementWiseU() { @@ -108,83 +119,83 @@ public unsafe void MulElementWiseU() fixed (float* psrc2 = src2) fixed (float* pdst = dst) { - CpuMathNativeUtils.MulElementWiseU(psrc1, psrc2, pdst, Length); + Thunk.MulElementWiseU(psrc1, psrc2, pdst, Length); } } - + [Benchmark] public unsafe float Sum() { fixed (float* psrc = src) { - return CpuMathNativeUtils.Sum(psrc, Length); + return Thunk.Sum(psrc, Length); } } - + [Benchmark] public unsafe float SumSqU() { fixed (float* psrc = src) { - return CpuMathNativeUtils.SumSqU(psrc, Length); + return Thunk.SumSqU(psrc, Length); } } - + [Benchmark] public unsafe float SumSqDiffU() { fixed (float* psrc = src) { - return CpuMathNativeUtils.SumSqDiffU(DefaultScale, psrc, Length); + return Thunk.SumSqDiffU(DefaultScale, psrc, Length); } } - + [Benchmark] public unsafe float SumAbsU() { fixed (float* psrc = src) { - return CpuMathNativeUtils.SumAbsU(psrc, Length); + return Thunk.SumAbsU(psrc, Length); } } - + [Benchmark] public unsafe float SumAbsDiffU() { fixed (float* psrc = src) { - return CpuMathNativeUtils.SumAbsDiffU(DefaultScale, psrc, Length); + return Thunk.SumAbsDiffU(DefaultScale, psrc, Length); } } - + [Benchmark] public unsafe float MaxAbsU() { fixed (float* psrc = src) { - return CpuMathNativeUtils.MaxAbsU(psrc, Length); + return Thunk.MaxAbsU(psrc, Length); } } - + [Benchmark] public unsafe float MaxAbsDiffU() { fixed (float* psrc = src) { - return CpuMathNativeUtils.MaxAbsDiffU(DefaultScale, psrc, Length); + return Thunk.MaxAbsDiffU(DefaultScale, psrc, Length); } } - + [Benchmark] public unsafe float DotU() { fixed (float* psrc = src) fixed (float* pdst = dst) { - return CpuMathNativeUtils.DotU(psrc, pdst, Length); + return Thunk.DotU(psrc, pdst, Length); } } - + [Benchmark] public unsafe float DotSU() { @@ -192,7 +203,7 @@ public unsafe float DotSU() fixed (float* pdst = dst) fixed (int* pidx = idx) { - return CpuMathNativeUtils.DotSU(psrc, pdst, pidx, IndexLength); + return Thunk.DotSU(psrc, pdst, pidx, IndexLength); } } @@ -202,7 +213,7 @@ public unsafe float Dist2() fixed (float* psrc = src) fixed (float* pdst = dst) { - return CpuMathNativeUtils.Dist2(psrc, pdst, Length); + return Thunk.Dist2(psrc, pdst, Length); } } @@ -213,7 +224,7 @@ public unsafe void SdcaL1UpdateU() fixed (float* pdst = dst) fixed (float* pres = result) { - CpuMathNativeUtils.SdcaL1UpdateU(DefaultScale, psrc, DefaultScale, pdst, pres, Length); + Thunk.SdcaL1UpdateU(DefaultScale, psrc, DefaultScale, pdst, pres, Length); } } @@ -225,42 +236,36 @@ public unsafe void SdcaL1UpdateSU() fixed (float* pres = result) fixed (int* pidx = idx) { - CpuMathNativeUtils.SdcaL1UpdateSU(DefaultScale, psrc, pidx, DefaultScale, pdst, pres, IndexLength); + Thunk.SdcaL1UpdateSU(DefaultScale, psrc, pidx, DefaultScale, pdst, pres, IndexLength); } } [Benchmark] public unsafe void MatMul() { - fixed (float* psrc = &src[0]) - fixed (float* pdst = &dst[0]) - fixed (float* psrc1 = &src1[0]) - { - Thunk.MatMul(psrc1, psrc, pdst, 1000, 1000); - } + fixed (float* pmat = &testMatrixAligned.Items[0]) + fixed (float* psrc = &testSrcVectorAligned.Items[0]) + fixed (float* pdst = &testDstVectorAligned.Items[0]) + Thunk.MatMul(Ptr(testMatrixAligned, pmat), 
diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/PerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/PerformanceTests.cs
index 6726603b5a..d2dcf3cfff 100644
--- a/test/Microsoft.ML.CpuMath.PerformanceTests/PerformanceTests.cs
+++ b/test/Microsoft.ML.CpuMath.PerformanceTests/PerformanceTests.cs
@@ -17,10 +17,16 @@ public abstract class PerformanceTests
         protected const int IndexLength = 1000003;
         protected const int Length = 1000003;
-        protected const int MatrixIndexLength = 100;
+        protected const int MatrixIndexLength = 1000;
         private const int DefaultSeed = 253421;
         protected const float DefaultScale = 1.11f;
+        protected int matrixLength = 1000;
+        protected virtual int align { get; set; } = 16;
+
+        internal AlignedArray testMatrixAligned;
+        internal AlignedArray testSrcVectorAligned;
+        internal AlignedArray testDstVectorAligned;
         protected float[] src, dst, original, src1, src2, result;
         protected int[] idx;
@@ -93,6 +99,15 @@ public void Setup()
             {
                 matrixIdx[i] = rand.Next(0, 1000);
             }
+
+            testMatrixAligned = new AlignedArray(matrixLength * matrixLength, align);
+            testMatrixAligned.CopyFrom(src.AsSpan(0, (matrixLength - 1) * ( matrixLength - 1)));
+
+            testSrcVectorAligned = new AlignedArray(matrixLength, align);
+            testSrcVectorAligned.CopyFrom(src1.AsSpan(0, matrixLength - 1)); // odd input
+
+            testDstVectorAligned = new AlignedArray(matrixLength, align);
+            testDstVectorAligned.CopyFrom(dst.AsSpan(0, matrixLength));
         }
         [GlobalCleanup]
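Note: a purely illustrative reference, not part of the patch, for the (crow, ccol) arguments that the SSE benchmarks below pass as (matrixLength, matrixLength). MatMul(mat, src, dst, crow, ccol) presumably multiplies a row-major crow-by-ccol matrix with a ccol-length source vector into a crow-length destination, which is why testSrcVectorAligned and testDstVectorAligned are both sized to matrixLength; MatMulSketch and MatMulReference are hypothetical names, not part of the code base.

internal static class MatMulSketch
{
    // dst[r] = sum over c of mat[r * ccol + c] * src[c], for r in [0, crow)
    internal static void MatMulReference(float[] mat, float[] src, float[] dst, int crow, int ccol)
    {
        for (int r = 0; r < crow; r++)
        {
            float sum = 0f;
            for (int c = 0; c < ccol; c++)
            {
                sum += mat[r * ccol + c] * src[c];
            }
            dst[r] = sum;
        }
    }
}

With matrixLength = 1000, the benchmarks below exercise the 1000 x 1000 case.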
diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs
index 4499af9ee5..e079e8bd7e 100644
--- a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs
+++ b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs
@@ -101,14 +101,14 @@ public void SdcaL1UpdateSU()
         [Benchmark]
         public void MatMul()
-            => SseIntrinsics.MatMul(src, src1, dst, 1000, 1000);
+            => SseIntrinsics.MatMul(testMatrixAligned, testSrcVectorAligned, testDstVectorAligned, matrixLength, matrixLength);
         [Benchmark]
         public void MatMulTran()
-            => SseIntrinsics.MatMulTran(src, src1, dst, 1000, 1000);
+            => SseIntrinsics.MatMulTran(testMatrixAligned, testSrcVectorAligned, testDstVectorAligned, matrixLength, matrixLength);
         [Benchmark]
         public void MatMulP()
-            => SseIntrinsics.MatMulP(src, matrixIdx, src1, 0, 0, MatrixIndexLength, dst, 1000, 1000);
+            => SseIntrinsics.MatMulP(testMatrixAligned, matrixIdx, testSrcVectorAligned, 0, 0, MatrixIndexLength, testDstVectorAligned, matrixLength, matrixLength);
     }
 }
diff --git a/test/Microsoft.ML.Tests/Transformers/RffTests.cs b/test/Microsoft.ML.Tests/Transformers/RffTests.cs
index f231082757..d647eddbca 100644
--- a/test/Microsoft.ML.Tests/Transformers/RffTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/RffTests.cs
@@ -37,7 +37,7 @@ private class TestClassInvalidSchema
             public int A;
         }
-        [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // netcore3.0 output differs from Baseline
+        [Fact]
         public void RffWorkout()
         {
             Random rand = new Random();
diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeries.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeries.cs
index 39340d225b..f20de461f4 100644
--- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeries.cs
+++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeries.cs
@@ -158,7 +158,7 @@ public void SavePipePercentileThreshold()
             Done();
         }
-        [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // Test is Flaky on netcore 3.0
+        [Fact]
         public void SavePipeMovingAverageUniform()
         {
             TestCore(null, true,
diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs
index a7892701d7..3d7cfd5d32 100644
--- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs
+++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs
@@ -41,7 +41,7 @@ public TimeSeriesEstimatorTests(ITestOutputHelper output) : base(output)
         {
         }
-        [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // netcore3.0 output differs from Baseline
+        [Fact]
         void TestSsaChangePointEstimator()
         {
             int Confidence = 95;
@@ -75,7 +75,7 @@ void TestSsaChangePointEstimator()
             Done();
         }
-        [ConditionalFact(typeof(BaseTestBaseline), nameof(BaseTestBaseline.LessThanNetCore30OrNotNetCore))] // netcore3.0 output differs from Baseline
+        [Fact]
         void TestSsaSpikeEstimator()
         {
             int Confidence = 95;