-
Notifications
You must be signed in to change notification settings - Fork 4.7k
/
TensorPrimitives.IndexOfMinMagnitude.cs
137 lines (123 loc) · 7.12 KB
/
TensorPrimitives.IndexOfMinMagnitude.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
namespace System.Numerics.Tensors
{
    public static partial class TensorPrimitives
    {
        /// <summary>Searches for the index of the number with the smallest magnitude in the specified tensor.</summary>
        /// <param name="x">The tensor, represented as a span.</param>
        /// <returns>The index of the element in <paramref name="x"/> with the smallest magnitude (absolute value), or -1 if <paramref name="x"/> is empty.</returns>
        /// <remarks>
        /// <para>
        /// The determination of the minimum magnitude matches the IEEE 754:2019 `minimumMagnitude` function. If any value equal to NaN
        /// is present, the index of the first is returned. If two values have the same magnitude and one is positive and the other is negative,
        /// the negative value is considered to have the smaller magnitude.
        /// </para>
        /// <para>
        /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different
        /// operating systems or architectures.
        /// </para>
        /// </remarks>
        public static int IndexOfMinMagnitude<T>(ReadOnlySpan<T> x)
            where T : INumber<T> =>
            IndexOfMinMaxCore<T, IndexOfMinMagnitudeOperator<T>>(x);

        /// <summary>
        /// Reduction step used by <see cref="IndexOfMinMagnitude{T}(ReadOnlySpan{T})"/>: folds a candidate element
        /// (or a vector of candidates) into the running minimum-magnitude value and its index.
        /// </summary>
        /// <remarks>
        /// On equal magnitude the negative value wins; if both have the same sign, the lower index wins.
        /// For signed integer types, <c>MinValue</c> is treated as having the largest magnitude.
        /// NOTE(review): NaN inputs are not special-cased here; presumably IndexOfMinMaxCore short-circuits on NaN — confirm.
        /// </remarks>
        internal readonly struct IndexOfMinMagnitudeOperator<T> : IIndexOfOperator<T> where T : INumber<T>
        {
            /// <summary>True when <typeparamref name="T"/> is a two's-complement signed integer, whose Abs(MinValue) is not representable.</summary>
            private static bool IsSignedIntegerType
            {
                [MethodImpl(MethodImplOptions.AggressiveInlining)]
                get => typeof(T) == typeof(sbyte) || typeof(T) == typeof(short) || typeof(T) == typeof(int) ||
                       typeof(T) == typeof(long) || typeof(T) == typeof(nint);
            }

            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            public static void Invoke(ref Vector128<T> result, Vector128<T> current, ref Vector128<T> resultIndex, Vector128<T> currentIndex)
            {
                Vector128<T> resultMag = Vector128.Abs(result), currentMag = Vector128.Abs(current);
                Vector128<T> useResult = Vector128.LessThan(resultMag, currentMag);
                Vector128<T> equalMask = Vector128.Equals(resultMag, currentMag);

                if (equalMask != Vector128<T>.Zero)
                {
                    Vector128<T> lessThanIndexMask = IndexLessThan(resultIndex, currentIndex);
                    if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                    {
                        // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(result));
                        Vector128<T> resultNegative = IsNegative(result);
                        Vector128<T> sameSign = Vector128.Equals(resultNegative.AsInt32(), IsNegative(current).AsInt32()).As<int, T>();
                        useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, resultNegative);
                    }
                    else if (IsSignedIntegerType)
                    {
                        // Same tie-break as floats (negative wins, else lower index). The sign masks are compared at T's
                        // own width because the AsInt32 reinterpretation above is only valid for 4- and 8-byte elements.
                        Vector128<T> resultNegative = Vector128.LessThan(result, Vector128<T>.Zero);
                        Vector128<T> sameSign = Vector128.Equals(resultNegative, Vector128.LessThan(current, Vector128<T>.Zero));
                        useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, resultNegative);
                    }
                    else
                    {
                        // Unsigned: equal magnitude means equal value, so only the index matters.
                        useResult |= equalMask & lessThanIndexMask;
                    }
                }

                if (IsSignedIntegerType)
                {
                    // Abs(MinValue) wraps back to MinValue, so a negative "magnitude" identifies MinValue, whose true
                    // magnitude exceeds every other value's. Where exactly one side is MinValue, keep the result only
                    // when the current element is the MinValue one.
                    Vector128<T> resultOverflow = Vector128.LessThan(resultMag, Vector128<T>.Zero);
                    Vector128<T> currentOverflow = Vector128.LessThan(currentMag, Vector128<T>.Zero);
                    useResult = ElementWiseSelect(resultOverflow ^ currentOverflow, currentOverflow, useResult);
                }

                result = ElementWiseSelect(useResult, result, current);
                resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex);
            }

            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            public static void Invoke(ref Vector256<T> result, Vector256<T> current, ref Vector256<T> resultIndex, Vector256<T> currentIndex)
            {
                Vector256<T> resultMag = Vector256.Abs(result), currentMag = Vector256.Abs(current);
                Vector256<T> useResult = Vector256.LessThan(resultMag, currentMag);
                Vector256<T> equalMask = Vector256.Equals(resultMag, currentMag);

                if (equalMask != Vector256<T>.Zero)
                {
                    Vector256<T> lessThanIndexMask = IndexLessThan(resultIndex, currentIndex);
                    if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                    {
                        // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(result));
                        Vector256<T> resultNegative = IsNegative(result);
                        Vector256<T> sameSign = Vector256.Equals(resultNegative.AsInt32(), IsNegative(current).AsInt32()).As<int, T>();
                        useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, resultNegative);
                    }
                    else if (IsSignedIntegerType)
                    {
                        // Same tie-break as floats (negative wins, else lower index). The sign masks are compared at T's
                        // own width because the AsInt32 reinterpretation above is only valid for 4- and 8-byte elements.
                        Vector256<T> resultNegative = Vector256.LessThan(result, Vector256<T>.Zero);
                        Vector256<T> sameSign = Vector256.Equals(resultNegative, Vector256.LessThan(current, Vector256<T>.Zero));
                        useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, resultNegative);
                    }
                    else
                    {
                        // Unsigned: equal magnitude means equal value, so only the index matters.
                        useResult |= equalMask & lessThanIndexMask;
                    }
                }

                if (IsSignedIntegerType)
                {
                    // Abs(MinValue) wraps back to MinValue, so a negative "magnitude" identifies MinValue, whose true
                    // magnitude exceeds every other value's. Where exactly one side is MinValue, keep the result only
                    // when the current element is the MinValue one.
                    Vector256<T> resultOverflow = Vector256.LessThan(resultMag, Vector256<T>.Zero);
                    Vector256<T> currentOverflow = Vector256.LessThan(currentMag, Vector256<T>.Zero);
                    useResult = ElementWiseSelect(resultOverflow ^ currentOverflow, currentOverflow, useResult);
                }

                result = ElementWiseSelect(useResult, result, current);
                resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex);
            }

            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            public static void Invoke(ref Vector512<T> result, Vector512<T> current, ref Vector512<T> resultIndex, Vector512<T> currentIndex)
            {
                Vector512<T> resultMag = Vector512.Abs(result), currentMag = Vector512.Abs(current);
                Vector512<T> useResult = Vector512.LessThan(resultMag, currentMag);
                Vector512<T> equalMask = Vector512.Equals(resultMag, currentMag);

                if (equalMask != Vector512<T>.Zero)
                {
                    Vector512<T> lessThanIndexMask = IndexLessThan(resultIndex, currentIndex);
                    if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                    {
                        // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(result));
                        Vector512<T> resultNegative = IsNegative(result);
                        Vector512<T> sameSign = Vector512.Equals(resultNegative.AsInt32(), IsNegative(current).AsInt32()).As<int, T>();
                        useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, resultNegative);
                    }
                    else if (IsSignedIntegerType)
                    {
                        // Same tie-break as floats (negative wins, else lower index). The sign masks are compared at T's
                        // own width because the AsInt32 reinterpretation above is only valid for 4- and 8-byte elements.
                        Vector512<T> resultNegative = Vector512.LessThan(result, Vector512<T>.Zero);
                        Vector512<T> sameSign = Vector512.Equals(resultNegative, Vector512.LessThan(current, Vector512<T>.Zero));
                        useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, resultNegative);
                    }
                    else
                    {
                        // Unsigned: equal magnitude means equal value, so only the index matters.
                        useResult |= equalMask & lessThanIndexMask;
                    }
                }

                if (IsSignedIntegerType)
                {
                    // Abs(MinValue) wraps back to MinValue, so a negative "magnitude" identifies MinValue, whose true
                    // magnitude exceeds every other value's. Where exactly one side is MinValue, keep the result only
                    // when the current element is the MinValue one.
                    Vector512<T> resultOverflow = Vector512.LessThan(resultMag, Vector512<T>.Zero);
                    Vector512<T> currentOverflow = Vector512.LessThan(currentMag, Vector512<T>.Zero);
                    useResult = ElementWiseSelect(resultOverflow ^ currentOverflow, currentOverflow, useResult);
                }

                result = ElementWiseSelect(useResult, result, current);
                resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex);
            }

            /// <summary>Scalar reduction step: returns the index of whichever of result/current has the smaller magnitude, updating <paramref name="result"/> accordingly.</summary>
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            public static int Invoke(ref T result, T current, int resultIndex, int currentIndex)
            {
                bool resultNegative = IsNegative(result);
                bool currentNegative = IsNegative(current);

                // Compare magnitudes without T.Abs, which throws OverflowException for signed-integer MinValue.
                // When the signs differ, negate the non-negative side (always representable) so both operands
                // share a sign; then for negatives the larger value has the smaller magnitude.
                T r = result;
                T c = current;
                if (resultNegative != currentNegative)
                {
                    if (resultNegative)
                    {
                        c = -c;
                    }
                    else
                    {
                        r = -r;
                    }
                }

                bool bothNonPositive = resultNegative || currentNegative;

                if (c == r)
                {
                    // Equal magnitude: prefer the negative value; with matching signs prefer the lower index.
                    if ((resultNegative == currentNegative) ? (currentIndex < resultIndex) : currentNegative)
                    {
                        result = current;
                        return currentIndex;
                    }
                }
                else if (bothNonPositive ? c > r : c < r)
                {
                    result = current;
                    return currentIndex;
                }

                return resultIndex;
            }
        }
    }
}