-
Notifications
You must be signed in to change notification settings - Fork 0
/
argminmax.h
172 lines (155 loc) · 7.44 KB
/
argminmax.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#pragma once
#ifndef SIMD_SAMPLING_API
#define SIMD_SAMPLING_API
#endif
#ifndef __cplusplus
#include <stdint.h>
#include <stdlib.h>
#else
#include <cstdint>
#include <cstdlib>
#include <algorithm>
#include <queue>
#endif
#include "macros.h"
/* Direction of an arg-reduction: pick the smallest or the largest weight. */
enum ArgReduction {
    ARGMIN = 0, /* select the minimum element(s) */
    ARGMAX = 1  /* select the maximum element(s) */
};
#ifdef __cplusplus
using std::uint64_t;
using std::size_t;
using std::ptrdiff_t;
extern "C" {
#endif
/*
 * C-linkage entry points (wrapped in extern "C" when compiled as C++).
 * Naming: the 'f' prefix takes const float *, the 'd' prefix takes const double *.
 * NOTE(review): `mt` is presumably a "use multiple threads" flag and the
 * _st/_mt suffixes presumably mean single-/multi-threaded -- confirm against
 * the implementation file.
 */

/* Single-index selection: direction chosen at run time through `ar`, or fixed by name. */
SIMD_SAMPLING_API uint64_t fargsel(const float *weights, size_t n, enum ArgReduction ar, int mt);
SIMD_SAMPLING_API uint64_t fargmin(const float *weights, size_t n, int mt);
SIMD_SAMPLING_API uint64_t fargmax(const float *weights, size_t n, int mt);
SIMD_SAMPLING_API uint64_t dargsel(const double *weights, size_t n, enum ArgReduction ar, int mt);
SIMD_SAMPLING_API uint64_t dargmin(const double *weights, size_t n, int mt);
SIMD_SAMPLING_API uint64_t dargmax(const double *weights, size_t n, int mt);
/* Single-index variants without the `mt` parameter (_st / _mt suffix). */
SIMD_SAMPLING_API uint64_t fargmin_st(const float *weights, size_t n);
SIMD_SAMPLING_API uint64_t fargmax_st(const float *weights, size_t n);
SIMD_SAMPLING_API uint64_t dargmin_st(const double *weights, size_t n);
SIMD_SAMPLING_API uint64_t dargmax_st(const double *weights, size_t n);
SIMD_SAMPLING_API uint64_t fargmin_mt(const float *weights, size_t n);
SIMD_SAMPLING_API uint64_t fargmax_mt(const float *weights, size_t n);
SIMD_SAMPLING_API uint64_t dargmin_mt(const double *weights, size_t n);
SIMD_SAMPLING_API uint64_t dargmax_mt(const double *weights, size_t n);
/*
 * Top-k selection: takes `k` and an output buffer `ret` of uint64_t.
 * NOTE(review): `ret` presumably needs room for k indices, and the ptrdiff_t
 * result is presumably a count -- neither is visible from this header; confirm.
 */
SIMD_SAMPLING_API ptrdiff_t fargsel_k(const float *weights, size_t n, ptrdiff_t k, uint64_t *ret, enum ArgReduction ar, int mt);
SIMD_SAMPLING_API ptrdiff_t dargsel_k(const double *weights, size_t n, ptrdiff_t k, uint64_t *ret, enum ArgReduction ar, int mt);
/* Top-k with run-time direction, without the `mt` parameter. */
SIMD_SAMPLING_API ptrdiff_t fargsel_k_st(const float *weights, size_t n, ptrdiff_t k, uint64_t *ret, enum ArgReduction ar);
SIMD_SAMPLING_API ptrdiff_t dargsel_k_st(const double *weights, size_t n, ptrdiff_t k, uint64_t *ret, enum ArgReduction ar);
SIMD_SAMPLING_API ptrdiff_t fargsel_k_mt(const float *weights, size_t n, ptrdiff_t k, uint64_t *ret, enum ArgReduction ar);
SIMD_SAMPLING_API ptrdiff_t dargsel_k_mt(const double *weights, size_t n, ptrdiff_t k, uint64_t *ret, enum ArgReduction ar);
/* Top-k with the direction fixed by name, without the `mt` parameter. */
SIMD_SAMPLING_API ptrdiff_t fargmin_k_st(const float *weights, size_t n, ptrdiff_t k, uint64_t *ret);
SIMD_SAMPLING_API ptrdiff_t dargmin_k_st(const double *weights, size_t n, ptrdiff_t k, uint64_t *ret);
SIMD_SAMPLING_API ptrdiff_t fargmin_k_mt(const float *weights, size_t n, ptrdiff_t k, uint64_t *ret);
SIMD_SAMPLING_API ptrdiff_t dargmin_k_mt(const double *weights, size_t n, ptrdiff_t k, uint64_t *ret);
SIMD_SAMPLING_API ptrdiff_t fargmax_k_st(const float *weights, size_t n, ptrdiff_t k, uint64_t *ret);
SIMD_SAMPLING_API ptrdiff_t dargmax_k_st(const double *weights, size_t n, ptrdiff_t k, uint64_t *ret);
SIMD_SAMPLING_API ptrdiff_t fargmax_k_mt(const float *weights, size_t n, ptrdiff_t k, uint64_t *ret);
SIMD_SAMPLING_API ptrdiff_t dargmax_k_mt(const double *weights, size_t n, ptrdiff_t k, uint64_t *ret);
/* Top-k with the direction fixed by name and an explicit `mt` parameter. */
SIMD_SAMPLING_API ptrdiff_t dargmin_k(const double *weights, size_t n, ptrdiff_t k, uint64_t *ret, int mt);
SIMD_SAMPLING_API ptrdiff_t fargmin_k(const float *weights, size_t n, ptrdiff_t k, uint64_t *ret, int mt);
SIMD_SAMPLING_API ptrdiff_t dargmax_k(const double *weights, size_t n, ptrdiff_t k, uint64_t *ret, int mt);
SIMD_SAMPLING_API ptrdiff_t fargmax_k(const float *weights, size_t n, ptrdiff_t k, uint64_t *ret, int mt);
#ifdef __cplusplus
}
#include <vector>
#include <stdexcept>
namespace reservoir_simd {
// Generic single-index selection for element types without a dedicated kernel.
// Returns the offset of the first smallest element when ar == ARGMIN, otherwise
// the offset of the first largest element. Returns 0 when n == 0.
// `mt` is accepted for signature parity with the specializations and is ignored.
template<typename T>
static inline uint64_t argsel(const T *weights, size_t n, ArgReduction ar, int mt) {
    (void)mt;
    const T *const last = weights + n;
    const T *hit;
    if(ar == ARGMIN) {
        hit = std::min_element(weights, last);
    } else {
        hit = std::max_element(weights, last);
    }
    return hit - weights;
}
// double specialization: forwards all arguments unchanged to the C entry point dargsel.
template<>
inline uint64_t argsel<double>(const double *weights, size_t n, ArgReduction ar, int mt) {
    const uint64_t index = dargsel(weights, n, ar, mt);
    return index;
}
// float specialization: forwards all arguments unchanged to the C entry point fargsel.
template<>
inline uint64_t argsel<float>(const float *weights, size_t n, ArgReduction ar, int mt) {
    const uint64_t index = fargsel(weights, n, ar, mt);
    return index;
}
// Generic argmax: offset of the first largest element in weights[0..n).
// Returns 0 when n == 0. `mt` is ignored by this generic version.
template<typename T> inline uint64_t argmax(const T *weights, size_t n, int mt) {
    (void)mt;
    const T *const last = weights + n;
    const T *const best = std::max_element(weights, last);
    return best - weights;
}
// Generic argmin: offset of the first smallest element in weights[0..n).
// Returns 0 when n == 0. `mt` is ignored by this generic version.
template<typename T> inline uint64_t argmin(const T *weights, size_t n, int mt) {
    (void)mt;
    const T *const last = weights + n;
    const T *const best = std::min_element(weights, last);
    return best - weights;
}
// double specialization: delegates to dargsel with the direction fixed to ARGMIN.
template<>
inline uint64_t argmin<double>(const double *weights, size_t n, int mt) {
    const ArgReduction direction = ARGMIN;
    return dargsel(weights, n, direction, mt);
}
// float specialization: delegates to fargsel with the direction fixed to ARGMIN.
template<>
inline uint64_t argmin<float>(const float *weights, size_t n, int mt) {
    const ArgReduction direction = ARGMIN;
    return fargsel(weights, n, direction, mt);
}
// double specialization: delegates to dargsel with the direction fixed to ARGMAX.
template<>
inline uint64_t argmax<double>(const double *weights, size_t n, int mt) {
    const ArgReduction direction = ARGMAX;
    return dargsel(weights, n, direction, mt);
}
// float specialization: delegates to fargsel with the direction fixed to ARGMAX.
template<>
inline uint64_t argmax<float>(const float *weights, size_t n, int mt) {
    const ArgReduction direction = ARGMAX;
    return fargsel(weights, n, direction, mt);
}
// Container convenience overload of argmax.
// Container must expose data() and size(); pointer types are excluded from this
// overload by the enable_if guard. `mt` is passed through to the pointer overload.
template<typename Container, typename=typename std::enable_if<!std::is_pointer<typename std::decay<Container>::type>::value>::type>
INLINE uint64_t argmax(const Container &x, bool mt) {
    const auto first = x.data();
    const size_t count = x.size();
    return argmax(first, count, mt);
}
// Container convenience overload of argmax that always passes mt = true to the
// pointer overload. Container must expose data() and size(); pointer types are
// excluded from this overload by the enable_if guard.
template<typename Container, typename=typename std::enable_if<!std::is_pointer<typename std::decay<Container>::type>::value>::type>
INLINE uint64_t argmax(const Container &x) {
    const auto first = x.data();
    const size_t count = x.size();
    return argmax(first, count, true);
}
// Container convenience overload of argmin.
// Container must expose data() and size(); pointer types are excluded from this
// overload by the enable_if guard. `mt` is passed through to the pointer overload.
template<typename Container, typename=typename std::enable_if<!std::is_pointer<typename std::decay<Container>::type>::value>::type>
INLINE uint64_t argmin(const Container &x, bool mt) {
    const auto first = x.data();
    const size_t count = x.size();
    return argmin(first, count, mt);
}
// Container convenience overload of argmin that always passes mt = true to the
// pointer overload. Container must expose data() and size(); pointer types are
// excluded from this overload by the enable_if guard.
template<typename Container, typename=typename std::enable_if<!std::is_pointer<typename std::decay<Container>::type>::value>::type>
INLINE uint64_t argmin(const Container &x) {
    const auto first = x.data();
    const size_t count = x.size();
    return argmin(first, count, true);
}
// Generic top-k selection for element types other than float/double.
//
// Writes the indices of the min(k, n) best elements of ptr[0..n) into ret, best
// first: largest first when ar == ARGMAX, smallest first otherwise. Elements
// that compare equal are ordered by ascending index, so the output is
// deterministic. Nothing is written when k <= 0 or n == 0.
//
// Parameters:
//   ptr  input values; only operator< on T is required.
//   n    number of input values.
//   k    number of indices requested.
//   ret  output buffer; must have room for min(k, n) entries. It is also used
//        as heap storage while scanning, so no extra memory is allocated.
//   ar   selection direction.
//   mt   must be false; the generic path has no multi-threaded implementation.
//
// Throws std::invalid_argument when mt is true.
// Returns n, exactly as the previous implementation did.
// NOTE(review): the float/double specializations forward the return value of
// fargsel_k/dargsel_k; confirm whether callers expect n or the number of
// indices written, and align this generic path if needed.
template<typename T>
INLINE ptrdiff_t argsel(T *ptr, size_t n, ptrdiff_t k, uint64_t *ret, ArgReduction ar, bool mt) {
    if(mt) throw std::invalid_argument("mt only available for float/double");
    if(k > 0 && n > 0) {
        const bool want_max = (ar == ARGMAX);
        // Strict ordering "a ranks ahead of b": by value in the requested
        // direction, then by lower index to break ties.
        auto better = [ptr, want_max](uint64_t a, uint64_t b) {
            const auto &x = ptr[a];
            const auto &y = ptr[b];
            if(want_max ? (y < x) : (x < y)) return true;
            if(want_max ? (x < y) : (y < x)) return false;
            return a < b;
        };
        // k > 0 was checked above, so the cast cannot wrap.
        const size_t cap = std::min(static_cast<size_t>(k), n);
        size_t i = 0;
        for(; i < cap; ++i) ret[i] = i;
        // With `better` as the comparator, the heap front is the worst kept index.
        std::make_heap(ret, ret + cap, better);
        for(; i < n; ++i) {
            if(better(i, ret[0])) {
                std::pop_heap(ret, ret + cap, better); // moves the worst to ret[cap - 1]
                ret[cap - 1] = i;
                std::push_heap(ret, ret + cap, better);
            }
        }
        // Ascending under `better` means best first.
        std::sort_heap(ret, ret + cap, better);
    }
    return n;
}
// double specialization of top-k selection: forwards to the C entry point
// dargsel_k, passing `mt` as an int.
template<>
INLINE ptrdiff_t argsel<double>(double *ptr, size_t n, ptrdiff_t k, uint64_t *ret, ArgReduction ar, bool mt) {
    const int mt_flag = mt;
    return dargsel_k(ptr, n, k, ret, ar, mt_flag);
}
// float specialization of top-k selection: forwards to the C entry point
// fargsel_k, passing `mt` as an int.
template<>
INLINE ptrdiff_t argsel<float>(float *ptr, size_t n, ptrdiff_t k, uint64_t *ret, ArgReduction ar, bool mt) {
    const int mt_flag = mt;
    return fargsel_k(ptr, n, k, ret, ar, mt_flag);
}
// Top-k selection returning the indices by value.
// Allocates a vector of k zero-initialized entries, lets the buffer-based
// argsel overload fill it, and returns it. Entries the callee does not write
// keep the value 0. The callee's return value is discarded.
template<typename T>
inline std::vector<uint64_t> argsel(T *ptr, size_t n, ptrdiff_t k, ArgReduction ar, bool mt) {
    std::vector<uint64_t> indices(k);
    uint64_t *const out = indices.data();
    argsel(ptr, n, k, out, ar, mt);
    return indices;
}
} // namespace reservoir_simd
#endif