diff --git a/libc/src/string/memory_utils/memset_implementations.h b/libc/src/string/memory_utils/memset_implementations.h index 16c11470572b5..dfc7df87aceb2 100644 --- a/libc/src/string/memory_utils/memset_implementations.h +++ b/libc/src/string/memory_utils/memset_implementations.h @@ -26,86 +26,101 @@ namespace __llvm_libc { inline_memset_embedded_tiny(Ptr dst, uint8_t value, size_t count) { LIBC_LOOP_NOUNROLL for (size_t offset = 0; offset < count; ++offset) - generic::Memset<1, 1>::block(dst + offset, value); + generic::Memset::block(dst + offset, value); } #if defined(LIBC_TARGET_ARCH_IS_X86) -template [[maybe_unused]] LIBC_INLINE static void inline_memset_x86(Ptr dst, uint8_t value, size_t count) { +#if defined(__AVX512F__) + using uint128_t = uint8x16_t; + using uint256_t = uint8x32_t; + using uint512_t = uint8x64_t; +#elif defined(__AVX__) + using uint128_t = uint8x16_t; + using uint256_t = uint8x32_t; + using uint512_t = cpp::array; +#elif defined(__SSE2__) + using uint128_t = uint8x16_t; + using uint256_t = cpp::array; + using uint512_t = cpp::array; +#else + using uint128_t = cpp::array; + using uint256_t = cpp::array; + using uint512_t = cpp::array; +#endif + if (count == 0) return; if (count == 1) - return generic::Memset<1, MaxSize>::block(dst, value); + return generic::Memset::block(dst, value); if (count == 2) - return generic::Memset<2, MaxSize>::block(dst, value); + return generic::Memset::block(dst, value); if (count == 3) - return generic::Memset<3, MaxSize>::block(dst, value); + return generic::Memset::block(dst, value); if (count <= 8) - return generic::Memset<4, MaxSize>::head_tail(dst, value, count); + return generic::Memset::head_tail(dst, value, count); if (count <= 16) - return generic::Memset<8, MaxSize>::head_tail(dst, value, count); + return generic::Memset::head_tail(dst, value, count); if (count <= 32) - return generic::Memset<16, MaxSize>::head_tail(dst, value, count); + return generic::Memset::head_tail(dst, value, count); if (count <= 64) - return generic::Memset<32, MaxSize>::head_tail(dst, value, count); + return generic::Memset::head_tail(dst, value, count); if (count <= 128) - return generic::Memset<64, MaxSize>::head_tail(dst, value, count); + return generic::Memset::head_tail(dst, value, count); // Aligned loop - generic::Memset<32, MaxSize>::block(dst, value); + generic::Memset::block(dst, value); align_to_next_boundary<32>(dst, count); - return generic::Memset<32, MaxSize>::loop_and_tail(dst, value, count); + return generic::Memset::loop_and_tail(dst, value, count); } #endif // defined(LIBC_TARGET_ARCH_IS_X86) #if defined(LIBC_TARGET_ARCH_IS_AARCH64) -template [[maybe_unused]] LIBC_INLINE static void inline_memset_aarch64(Ptr dst, uint8_t value, size_t count) { + static_assert(aarch64::kNeon, "aarch64 supports vector types"); + using uint128_t = uint8x16_t; + using uint256_t = uint8x32_t; + using uint512_t = uint8x64_t; if (count == 0) return; if (count <= 3) { - generic::Memset<1, MaxSize>::block(dst, value); + generic::Memset::block(dst, value); if (count > 1) - generic::Memset<2, MaxSize>::tail(dst, value, count); + generic::Memset::tail(dst, value, count); return; } if (count <= 8) - return generic::Memset<4, MaxSize>::head_tail(dst, value, count); + return generic::Memset::head_tail(dst, value, count); if (count <= 16) - return generic::Memset<8, MaxSize>::head_tail(dst, value, count); + return generic::Memset::head_tail(dst, value, count); if (count <= 32) - return generic::Memset<16, MaxSize>::head_tail(dst, value, count); + return generic::Memset::head_tail(dst, value, count); if (count <= (32 + 64)) { - generic::Memset<32, MaxSize>::block(dst, value); + generic::Memset::block(dst, value); if (count <= 64) - return generic::Memset<32, MaxSize>::tail(dst, value, count); - generic::Memset<32, MaxSize>::block(dst + 32, value); - generic::Memset<32, MaxSize>::tail(dst, value, count); + return generic::Memset::tail(dst, value, count); + generic::Memset::block(dst + 32, value); + generic::Memset::tail(dst, value, count); return; } if (count >= 448 && value == 0 && aarch64::neon::hasZva()) { - generic::Memset<64, MaxSize>::block(dst, 0); + generic::Memset::block(dst, 0); align_to_next_boundary<64>(dst, count); - return aarch64::neon::BzeroCacheLine<64>::loop_and_tail(dst, 0, count); + return aarch64::neon::BzeroCacheLine::loop_and_tail(dst, 0, count); } else { - generic::Memset<16, MaxSize>::block(dst, value); + generic::Memset::block(dst, value); align_to_next_boundary<16>(dst, count); - return generic::Memset<64, MaxSize>::loop_and_tail(dst, value, count); + return generic::Memset::loop_and_tail(dst, value, count); } } #endif // defined(LIBC_TARGET_ARCH_IS_AARCH64) LIBC_INLINE static void inline_memset(Ptr dst, uint8_t value, size_t count) { #if defined(LIBC_TARGET_ARCH_IS_X86) - static constexpr size_t kMaxSize = x86::kAvx512F ? 64 - : x86::kAvx ? 32 - : x86::kSse2 ? 16 - : 8; - return inline_memset_x86(dst, value, count); + return inline_memset_x86(dst, value, count); #elif defined(LIBC_TARGET_ARCH_IS_AARCH64) - static constexpr size_t kMaxSize = aarch64::kNeon ? 16 : 8; - return inline_memset_aarch64(dst, value, count); + return inline_memset_aarch64(dst, value, count); #else return inline_memset_embedded_tiny(dst, value, count); #endif diff --git a/libc/src/string/memory_utils/op_aarch64.h b/libc/src/string/memory_utils/op_aarch64.h index f9aabd0fbcade..e8c8b211e57b5 100644 --- a/libc/src/string/memory_utils/op_aarch64.h +++ b/libc/src/string/memory_utils/op_aarch64.h @@ -30,11 +30,10 @@ static inline constexpr bool kNeon = LLVM_LIBC_IS_DEFINED(__ARM_NEON); namespace neon { -template struct BzeroCacheLine { - static constexpr size_t SIZE = Size; +struct BzeroCacheLine { + static constexpr size_t SIZE = 64; LIBC_INLINE static void block(Ptr dst, uint8_t) { - static_assert(Size == 64); #if __SIZEOF_POINTER__ == 4 asm("dc zva, %w[dst]" : : [dst] "r"(dst) : "memory"); #else @@ -43,15 +42,13 @@ template struct BzeroCacheLine { } LIBC_INLINE static void loop_and_tail(Ptr dst, uint8_t value, size_t count) { - static_assert(Size > 1, "a loop of size 1 does not need tail"); size_t offset = 0; do { block(dst + offset, value); offset += SIZE; } while (offset < count - SIZE); // Unaligned store, we can't use 'dc zva' here. - static constexpr size_t kMaxSize = kNeon ? 16 : 8; - generic::Memset::tail(dst, value, count); + generic::Memset::tail(dst, value, count); } }; diff --git a/libc/src/string/memory_utils/op_generic.h b/libc/src/string/memory_utils/op_generic.h index fd63ac67d005a..a7c5636c2d1ca 100644 --- a/libc/src/string/memory_utils/op_generic.h +++ b/libc/src/string/memory_utils/op_generic.h @@ -33,8 +33,7 @@ #include -namespace __llvm_libc::generic { - +namespace __llvm_libc { // Compiler types using the vector attributes. using uint8x1_t = uint8_t __attribute__((__vector_size__(1))); using uint8x2_t = uint8_t __attribute__((__vector_size__(2))); @@ -43,13 +42,14 @@ using uint8x8_t = uint8_t __attribute__((__vector_size__(8))); using uint8x16_t = uint8_t __attribute__((__vector_size__(16))); using uint8x32_t = uint8_t __attribute__((__vector_size__(32))); using uint8x64_t = uint8_t __attribute__((__vector_size__(64))); +} // namespace __llvm_libc +namespace __llvm_libc::generic { // We accept three types of values as elements for generic operations: // - scalar : unsigned integral types // - vector : compiler types using the vector attributes // - array : a cpp::array where T is itself either a scalar or a vector. // The following traits help discriminate between these cases. - template constexpr bool is_scalar_v = cpp::is_integral_v && cpp::is_unsigned_v; @@ -109,23 +109,11 @@ template T splat(uint8_t value) { T Out; // This for loop is optimized out for vector types. for (size_t i = 0; i < sizeof(T); ++i) - Out[i] = static_cast(value); + Out[i] = value; return Out; } } -template void set(Ptr dst, uint8_t value) { - static_assert(is_element_type_v); - if constexpr (is_scalar_v || is_vector_v) { - store(dst, splat(value)); - } else if constexpr (is_array_v) { - using value_type = typename T::value_type; - const value_type Splat = splat(value); - for (size_t I = 0; I < array_size_v; ++I) - store(dst + (I * sizeof(value_type)), Splat); - } -} - static_assert((UINTPTR_MAX == 4294967295U) || (UINTPTR_MAX == 18446744073709551615UL), "We currently only support 32- or 64-bit platforms"); @@ -149,9 +137,7 @@ constexpr bool is_decreasing_size() { } template struct Largest; -template struct Largest { - using type = uint8_t; -}; +template struct Largest : cpp::type_identity {}; template struct Largest { using next = Largest; @@ -179,6 +165,11 @@ template struct SupportedTypes { using TypeFor = typename details::Largest::type; }; +// Returns the sum of the sizeof of all the TS types. +template static constexpr size_t sum_sizeof() { + return (... + sizeof(TS)); +} + // Map from sizes to structures offering static load, store and splat methods. // Note: On platforms lacking vector support, we use the ArrayType below and // decompose the operation in smaller pieces. @@ -220,27 +211,23 @@ using getTypeFor = cpp::conditional_t< /////////////////////////////////////////////////////////////////////////////// // Memset -// The MaxSize template argument gives the maximum size handled natively by the -// platform. For instance on x86 with AVX support this would be 32. If a size -// greater than MaxSize is requested we break the operation down in smaller -// pieces of size MaxSize. /////////////////////////////////////////////////////////////////////////////// -template struct Memset { - static_assert(is_power2(MaxSize)); - static constexpr size_t SIZE = Size; + +template struct Memset { + static constexpr size_t SIZE = sum_sizeof(); LIBC_INLINE static void block(Ptr dst, uint8_t value) { - if constexpr (Size == 3) { - Memset<1, MaxSize>::block(dst + 2, value); - Memset<2, MaxSize>::block(dst, value); - } else { - using T = details::getTypeFor; - if constexpr (details::is_void_v) { - deferred_static_assert("Unimplemented Size"); - } else { - set(dst, value); - } + static_assert(is_element_type_v); + if constexpr (is_scalar_v || is_vector_v) { + store(dst, splat(value)); + } else if constexpr (is_array_v) { + using value_type = typename T::value_type; + const auto Splat = splat(value); + for (size_t I = 0; I < array_size_v; ++I) + store(dst + (I * sizeof(value_type)), Splat); } + if constexpr (sizeof...(TS)) + Memset::block(dst + sizeof(T), value); } LIBC_INLINE static void tail(Ptr dst, uint8_t value, size_t count) { @@ -253,7 +240,7 @@ template struct Memset { } LIBC_INLINE static void loop_and_tail(Ptr dst, uint8_t value, size_t count) { - static_assert(SIZE > 1); + static_assert(SIZE > 1, "a loop of size 1 does not need tail"); size_t offset = 0; do { block(dst + offset, value); diff --git a/libc/test/src/string/memory_utils/op_tests.cpp b/libc/test/src/string/memory_utils/op_tests.cpp index 7f5d4d4ed460a..b63a629da3f05 100644 --- a/libc/test/src/string/memory_utils/op_tests.cpp +++ b/libc/test/src/string/memory_utils/op_tests.cpp @@ -119,24 +119,20 @@ using MemsetImplementations = testing::TypeList< builtin::Memset<64>, #endif #ifdef LLVM_LIBC_HAS_UINT64 - generic::Memset<8, 8>, // - generic::Memset<16, 8>, // - generic::Memset<32, 8>, // - generic::Memset<64, 8>, // + generic::Memset, generic::Memset>, #endif #ifdef __AVX512F__ - generic::Memset<64, 64>, // prevents warning about avx512f + generic::Memset, generic::Memset>, #endif - generic::Memset<1, 1>, // - generic::Memset<2, 1>, // - generic::Memset<2, 2>, // - generic::Memset<4, 2>, // - generic::Memset<4, 4>, // - generic::Memset<16, 16>, // - generic::Memset<32, 16>, // - generic::Memset<64, 16>, // - generic::Memset<32, 32>, // - generic::Memset<64, 32> // +#ifdef __AVX__ + generic::Memset, generic::Memset>, +#endif +#ifdef __SSE2__ + generic::Memset, generic::Memset>, +#endif + generic::Memset, generic::Memset>, // + generic::Memset, generic::Memset>, // + generic::Memset, generic::Memset> // >; // Adapt CheckMemset signature to op implementation signatures.