Skip to content

Commit e2d5efd

Browse files
authored
[libc] Update the memory helper functions for simd types (#160174)
Summary: This unifies the interface to just be a bunch of `load` and `store` functions that optionally accept a mask / indices for gathers and scatters with masks. I had to rename this from `load` and `store` because it conflicts with the other version in `op_generic`. I might just work around that with a trait instead.
1 parent bb38b48 commit e2d5efd

File tree

3 files changed

+119
-18
lines changed

3 files changed

+119
-18
lines changed

libc/src/__support/CPP/simd.h

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -287,34 +287,72 @@ LIBC_INLINE constexpr static T hmax(simd<T, N> v) {
287287
}
288288

289289
// Accessor helpers.
290-
template <typename T, internal::enable_if_simd_t<T> = 0>
291-
LIBC_INLINE T load_unaligned(const void *ptr) {
290+
template <typename T>
291+
LIBC_INLINE T constexpr static load(const void *ptr, bool aligned = false) {
292+
if (aligned)
293+
ptr = __builtin_assume_aligned(ptr, alignof(T));
292294
T tmp;
293-
__builtin_memcpy(&tmp, ptr, sizeof(T));
295+
__builtin_memcpy_inline(
296+
&tmp, reinterpret_cast<const simd_element_type_t<T> *>(ptr), sizeof(T));
294297
return tmp;
295298
}
296299
template <typename T, internal::enable_if_simd_t<T> = 0>
297-
LIBC_INLINE T load_aligned(const void *ptr) {
298-
return load_unaligned<T>(__builtin_assume_aligned(ptr, alignof(T)));
300+
LIBC_INLINE constexpr static void store(T v, void *ptr, bool aligned = false) {
301+
if (aligned)
302+
ptr = __builtin_assume_aligned(ptr, alignof(T));
303+
__builtin_memcpy_inline(ptr, &v, sizeof(T));
299304
}
300305
template <typename T, internal::enable_if_simd_t<T> = 0>
301-
LIBC_INLINE T store_unaligned(T v, void *ptr) {
302-
__builtin_memcpy(ptr, &v, sizeof(T));
306+
LIBC_INLINE constexpr static T
307+
load_masked(simd<bool, simd_size_v<T>> mask, const void *ptr,
308+
T passthru = internal::poison<T>(), bool aligned = false) {
309+
if (aligned)
310+
ptr = __builtin_assume_aligned(ptr, alignof(T));
311+
return __builtin_masked_load(
312+
mask, reinterpret_cast<const simd_element_type_t<T> *>(ptr), passthru);
303313
}
304314
template <typename T, internal::enable_if_simd_t<T> = 0>
305-
LIBC_INLINE T store_aligned(T v, void *ptr) {
306-
store_unaligned<T>(v, __builtin_assume_aligned(ptr, alignof(T)));
315+
LIBC_INLINE constexpr static void store_masked(simd<bool, simd_size_v<T>> mask,
316+
T v, void *ptr,
317+
bool aligned = false) {
318+
if (aligned)
319+
ptr = __builtin_assume_aligned(ptr, alignof(T));
320+
__builtin_masked_store(mask, v,
321+
reinterpret_cast<simd_element_type_t<T> *>(ptr));
322+
}
323+
template <typename T, typename Idx, internal::enable_if_simd_t<T> = 0>
324+
LIBC_INLINE constexpr static T gather(simd<bool, simd_size_v<T>> mask, Idx idx,
325+
const void *base, bool aligned = false) {
326+
if (aligned)
327+
base = __builtin_assume_aligned(base, alignof(T));
328+
return __builtin_masked_gather(
329+
mask, idx, reinterpret_cast<const simd_element_type_t<T> *>(base));
330+
}
331+
template <typename T, typename Idx, internal::enable_if_simd_t<T> = 0>
332+
LIBC_INLINE constexpr static void scatter(simd<bool, simd_size_v<T>> mask,
333+
Idx idx, T v, void *base,
334+
bool aligned = false) {
335+
if (aligned)
336+
base = __builtin_assume_aligned(base, alignof(T));
337+
__builtin_masked_scatter(mask, idx, v,
338+
reinterpret_cast<simd_element_type_t<T> *>(base));
307339
}
308340
template <typename T, internal::enable_if_simd_t<T> = 0>
309-
LIBC_INLINE T
310-
masked_load(simd<bool, simd_size_v<T>> m, void *ptr,
311-
T passthru = internal::poison<simd_element_type<T>>()) {
312-
return __builtin_masked_load(m, ptr, passthru);
341+
LIBC_INLINE constexpr static T
342+
expand(simd<bool, simd_size_v<T>> mask, const void *ptr,
343+
T passthru = internal::poison<T>(), bool aligned = false) {
344+
if (aligned)
345+
ptr = __builtin_assume_aligned(ptr, alignof(T));
346+
return __builtin_masked_expand_load(
347+
mask, reinterpret_cast<const simd_element_type_t<T> *>(ptr), passthru);
313348
}
314349
template <typename T, internal::enable_if_simd_t<T> = 0>
315-
LIBC_INLINE T masked_store(simd<bool, simd_size_v<T>> m, T v, void *ptr) {
316-
__builtin_masked_store(
317-
m, v, static_cast<T *>(__builtin_assume_aligned(ptr, alignof(T))));
350+
LIBC_INLINE constexpr static void compress(simd<bool, simd_size_v<T>> mask, T v,
351+
void *ptr, bool aligned = false) {
352+
if (aligned)
353+
ptr = __builtin_assume_aligned(ptr, alignof(T));
354+
__builtin_masked_compress_store(
355+
mask, v, reinterpret_cast<simd_element_type_t<T> *>(ptr));
318356
}
319357

320358
// Construction helpers.

libc/src/string/memory_utils/generic/inline_strlen.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,15 @@ string_length(const char *src) {
3232
const cpp::simd<char> *aligned = reinterpret_cast<const cpp::simd<char> *>(
3333
__builtin_align_down(src, alignment));
3434

35-
cpp::simd<char> chars = cpp::load_aligned<cpp::simd<char>>(aligned);
35+
cpp::simd<char> chars = cpp::load<cpp::simd<char>>(aligned, /*aligned=*/true);
3636
cpp::simd_mask<char> mask = chars == null_byte;
3737
size_t offset = src - reinterpret_cast<const char *>(aligned);
3838
if (cpp::any_of(shift_mask(mask, offset)))
3939
return cpp::find_first_set(shift_mask(mask, offset));
4040

4141
for (;;) {
42-
cpp::simd<char> chars = cpp::load_aligned<cpp::simd<char>>(++aligned);
42+
cpp::simd<char> chars = cpp::load<cpp::simd<char>>(++aligned,
43+
/*aligned=*/true);
4344
cpp::simd_mask<char> mask = chars == null_byte;
4445
if (cpp::any_of(mask))
4546
return (reinterpret_cast<const char *>(aligned) - src) +

libc/test/src/__support/CPP/simd_test.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,65 @@ TEST(LlvmLibcSIMDTest, SplitConcat) {
8686
cpp::simd<char, 8> n = cpp::concat(c, c, c, c, c, c, c, c);
8787
EXPECT_TRUE(cpp::all_of(n == ~0));
8888
}
89+
90+
TEST(LlvmLibcSIMDTest, LoadStore) {
91+
constexpr size_t SIZE = cpp::simd_size_v<cpp::simd<int>>;
92+
alignas(alignof(cpp::simd<int>)) int buf[SIZE];
93+
94+
cpp::simd<int> v1 = cpp::splat(1);
95+
cpp::store(v1, buf);
96+
cpp::simd<int> v2 = cpp::load<cpp::simd<int>>(buf);
97+
98+
EXPECT_TRUE(cpp::all_of(v1 == 1));
99+
EXPECT_TRUE(cpp::all_of(v2 == 1));
100+
101+
cpp::simd<int> v3 = cpp::splat(2);
102+
cpp::store(v3, buf, /*aligned=*/true);
103+
cpp::simd<int> v4 = cpp::load<cpp::simd<int>>(buf, /*aligned=*/true);
104+
105+
EXPECT_TRUE(cpp::all_of(v3 == 2));
106+
EXPECT_TRUE(cpp::all_of(v4 == 2));
107+
}
108+
109+
TEST(LlvmLibcSIMDTest, MaskedLoadStore) {
110+
constexpr size_t SIZE = cpp::simd_size_v<cpp::simd<int>>;
111+
alignas(alignof(cpp::simd<int>)) int buf[SIZE] = {0};
112+
113+
cpp::simd<int> mask = cpp::iota(0) % 2 == 0;
114+
cpp::simd<int> v1 = cpp::splat(1);
115+
116+
cpp::store_masked<cpp::simd<int>>(mask, v1, buf);
117+
cpp::simd<int> v2 = cpp::load_masked<cpp::simd<int>>(mask, buf);
118+
119+
EXPECT_TRUE(cpp::all_of((v2 == 1) == mask));
120+
}
121+
122+
TEST(LlvmLibcSIMDTest, GatherScatter) {
123+
constexpr int SIZE = cpp::simd_size_v<cpp::simd<int>>;
124+
alignas(alignof(cpp::simd<int>)) int buf[SIZE];
125+
126+
cpp::simd<int> mask = cpp::iota(1);
127+
cpp::simd<int> idx = cpp::iota(0);
128+
cpp::simd<int> v1 = cpp::splat(1);
129+
130+
cpp::scatter<cpp::simd<int>>(mask, idx, v1, buf);
131+
cpp::simd<int> v2 = cpp::gather<cpp::simd<int>>(mask, idx, buf);
132+
133+
EXPECT_TRUE(cpp::all_of(v1 == 1));
134+
EXPECT_TRUE(cpp::all_of(v2 == 1));
135+
}
136+
137+
TEST(LlvmLibcSIMDTest, MaskedCompressExpand) {
138+
constexpr size_t SIZE = cpp::simd_size_v<cpp::simd<int>>;
139+
alignas(alignof(cpp::simd<int>)) int buf[SIZE] = {0};
140+
141+
cpp::simd<int> mask_expand = cpp::iota(0) % 2 == 0;
142+
cpp::simd<int> mask_compress = 1;
143+
144+
cpp::simd<int> v1 = cpp::iota(0);
145+
146+
cpp::compress<cpp::simd<int>>(mask_compress, v1, buf);
147+
cpp::simd<int> v2 = cpp::expand<cpp::simd<int>>(mask_expand, buf);
148+
149+
EXPECT_TRUE(cpp::all_of(!mask_expand || v2 <= SIZE / 2));
150+
}

0 commit comments

Comments
 (0)