|
16 | 16 | #include "hdr/stdint_proxy.h"
|
17 | 17 | #include "src/__support/CPP/algorithm.h"
|
18 | 18 | #include "src/__support/CPP/limits.h"
|
| 19 | +#include "src/__support/CPP/tuple.h" |
19 | 20 | #include "src/__support/CPP/type_traits.h"
|
| 21 | +#include "src/__support/CPP/utility/integer_sequence.h" |
20 | 22 | #include "src/__support/macros/attributes.h"
|
21 | 23 | #include "src/__support/macros/config.h"
|
22 | 24 |
|
@@ -51,6 +53,7 @@ template <typename T> inline constexpr size_t native_vector_size = 1;
|
51 | 53 | template <typename T> LIBC_INLINE constexpr T poison() {
|
52 | 54 | return __builtin_nondeterministic_value(T());
|
53 | 55 | }
|
| 56 | + |
54 | 57 | } // namespace internal
|
55 | 58 |
|
56 | 59 | // Type aliases.
|
@@ -267,6 +270,77 @@ LIBC_INLINE constexpr simd<T, N> select(simd<bool, N> m, simd<T, N> x,
|
267 | 270 | return m ? x : y;
|
268 | 271 | }
|
269 | 272 |
|
| 273 | +namespace internal { |
| 274 | +template <typename T, size_t N, size_t O, size_t... I> |
| 275 | +LIBC_INLINE constexpr static cpp::simd<T, sizeof...(I)> |
| 276 | +extend(cpp::simd<T, N> x, cpp::index_sequence<I...>) { |
| 277 | + return __builtin_shufflevector(x, x, (I < O ? static_cast<int>(I) : -1)...); |
| 278 | +} |
| 279 | +template <typename T, size_t N, size_t M, size_t O> |
| 280 | +LIBC_INLINE constexpr static auto extend(cpp::simd<T, N> x) { |
| 281 | + if constexpr (N == M) |
| 282 | + return x; |
| 283 | + else if constexpr (M <= 2 * N) |
| 284 | + return extend<T, N, M>(x, cpp::make_index_sequence<M>{}); |
| 285 | + else |
| 286 | + return extend<T, 2 * N, M, O>( |
| 287 | + extend<T, N, 2 * N>(x, cpp::make_index_sequence<2 * N>{})); |
| 288 | +} |
| 289 | +template <typename T, size_t N, size_t M, size_t... I> |
| 290 | +LIBC_INLINE constexpr static cpp::simd<T, N + M> |
| 291 | +concat(cpp::simd<T, N> x, cpp::simd<T, M> y, cpp::index_sequence<I...>) { |
| 292 | + constexpr size_t L = (N > M ? N : M); |
| 293 | + |
| 294 | + auto x_ext = extend<T, N, L, N>(x); |
| 295 | + auto y_ext = extend<T, M, L, M>(y); |
| 296 | + |
| 297 | + auto remap = [](size_t idx) -> int { |
| 298 | + if (idx < N) |
| 299 | + return static_cast<int>(idx); |
| 300 | + if (idx < N + M) |
| 301 | + return static_cast<int>((idx - N) + L); |
| 302 | + return -1; |
| 303 | + }; |
| 304 | + |
| 305 | + return __builtin_shufflevector(x_ext, y_ext, remap(I)...); |
| 306 | +} |
| 307 | + |
| 308 | +template <typename T, size_t N, size_t Count, size_t Offset, size_t... I> |
| 309 | +LIBC_INLINE constexpr static cpp::simd<T, Count> |
| 310 | +slice(cpp::simd<T, N> x, cpp::index_sequence<I...>) { |
| 311 | + return __builtin_shufflevector(x, x, (Offset + I)...); |
| 312 | +} |
| 313 | +template <typename T, size_t N, size_t Offset, size_t Head, size_t... Tail> |
| 314 | +LIBC_INLINE constexpr static auto split(cpp::simd<T, N> x) { |
| 315 | + auto first = cpp::make_tuple( |
| 316 | + slice<T, N, Head, Offset>(x, cpp::make_index_sequence<Head>{})); |
| 317 | + if constexpr (sizeof...(Tail) > 0) |
| 318 | + return cpp::tuple_cat(first, split<T, N, Offset + Head, Tail...>(x)); |
| 319 | + else |
| 320 | + return first; |
| 321 | +} |
| 322 | + |
| 323 | +} // namespace internal |
| 324 | + |
| 325 | +// Shuffling helpers. |
| 326 | +template <typename T, size_t N, size_t M> |
| 327 | +LIBC_INLINE constexpr static auto concat(cpp::simd<T, N> x, cpp::simd<T, M> y) { |
| 328 | + return internal::concat(x, y, make_index_sequence<N + M>{}); |
| 329 | +} |
| 330 | +template <typename T, size_t N, size_t M, typename... Rest> |
| 331 | +LIBC_INLINE constexpr static auto concat(cpp::simd<T, N> x, cpp::simd<T, M> y, |
| 332 | + Rest... rest) { |
| 333 | + auto xy = concat(x, y); |
| 334 | + if constexpr (sizeof...(Rest)) |
| 335 | + return concat(xy, rest...); |
| 336 | + else |
| 337 | + return xy; |
| 338 | +} |
| 339 | +template <size_t... Sizes, typename T, size_t N> auto split(cpp::simd<T, N> x) { |
| 340 | + static_assert((... + Sizes) == N, "split sizes must sum to vector size"); |
| 341 | + return internal::split<T, N, 0, Sizes...>(x); |
| 342 | +} |
| 343 | + |
270 | 344 | // TODO: where expressions, scalar overloads, ABI types.
|
271 | 345 |
|
272 | 346 | } // namespace cpp
|
|
0 commit comments