diff --git a/Makefile b/Makefile
index 3af92aad..ed5d03be 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-CXX ?= g++
+CXX = g++-12
 SRCDIR = ./src
 TESTDIR = ./tests
 BENCHDIR = ./benchmarks
@@ -6,11 +6,10 @@ UTILS = ./utils
 SRCS = $(wildcard $(SRCDIR)/*.hpp)
 TESTS = $(wildcard $(TESTDIR)/*.cpp)
 TESTOBJS = $(patsubst $(TESTDIR)/%.cpp,$(TESTDIR)/%.o,$(TESTS))
-TESTOBJS := $(filter-out $(TESTDIR)/main.o ,$(TESTOBJS))
 CXXFLAGS += -I$(SRCDIR) -I$(UTILS)
-GTESTCFLAGS = `pkg-config --cflags gtest`
-GTESTLDFLAGS = `pkg-config --libs gtest`
-MARCHFLAG = -march=icelake-client -O3
+GTESTCFLAGS = `pkg-config --cflags gtest_main`
+GTESTLDFLAGS = `pkg-config --libs gtest_main`
+MARCHFLAG = -march=sapphirerapids -O3
 
 all : test bench
 
@@ -20,11 +19,15 @@ $(UTILS)/cpuinfo.o : $(UTILS)/cpuinfo.cpp
 $(TESTDIR)/%.o : $(TESTDIR)/%.cpp $(SRCS)
 	$(CXX) $(CXXFLAGS) $(MARCHFLAG) $(GTESTCFLAGS) -c $< -o $@
 
-test: $(TESTDIR)/main.cpp $(TESTOBJS) $(UTILS)/cpuinfo.o $(SRCS)
-	$(CXX) tests/main.cpp $(TESTOBJS) $(UTILS)/cpuinfo.o $(MARCHFLAG) $(CXXFLAGS) $(GTESTLDFLAGS) -o testexe
+test: $(TESTOBJS) $(UTILS)/cpuinfo.o $(SRCS)
+	$(CXX) $(TESTOBJS) $(UTILS)/cpuinfo.o $(MARCHFLAG) $(CXXFLAGS) -lgtest_main $(GTESTLDFLAGS) -o testexe
 
 bench: $(BENCHDIR)/main.cpp $(SRCS) $(UTILS)/cpuinfo.o
 	$(CXX) $(BENCHDIR)/main.cpp $(CXXFLAGS) $(UTILS)/cpuinfo.o $(MARCHFLAG) -o benchexe
 
+meson:
+	meson setup --warnlevel 0 --buildtype plain builddir
+	cd builddir && ninja
+
 clean:
-	rm -f $(TESTDIR)/*.o testexe benchexe
+	$(RM) -rf $(TESTDIR)/*.o $(UTILS)/*.o testexe benchexe builddir
diff --git a/meson.build b/meson.build
index 3141d3e1..45dfd255 100644
--- a/meson.build
+++ b/meson.build
@@ -1,32 +1,24 @@
-project('x86-simd-sort', 'c', 'cpp',
+project('x86-simd-sort', 'cpp',
         version : '1.0.0',
         license : 'BSD 3-clause')
-cc = meson.get_compiler('c')
 cpp = meson.get_compiler('cpp')
-src = include_directories('./src')
-bench = include_directories('./benchmarks')
-utils = include_directories('./utils')
-tests = include_directories('./tests')
-gtest_dep = dependency('gtest', fallback : ['gtest', 'gtest_dep'])
-subdir('./tests')
+src = include_directories('src')
+bench = include_directories('benchmarks')
+utils = include_directories('utils')
+tests = include_directories('tests')
+gtest_dep = dependency('gtest_main', required : true)
+subdir('utils')
+subdir('tests')
 
-testexe = executable('testexe', 'tests/main.cpp',
+testexe = executable('testexe',
+                     include_directories : [src, utils],
                      dependencies : gtest_dep,
-                     link_whole : [
-                         libtests,
-                     ]
-                     )
+                     link_whole : [libtests, libcpuinfo]
+                     )
 
 benchexe = executable('benchexe', 'benchmarks/main.cpp',
-                      include_directories : [
-                          src,
-                          utils,
-                          bench,
-                      ],
-                      cpp_args : [
-                          '-O3',
-                          '-march=icelake-client',
-                      ],
-                      dependencies : [],
-                      link_whole : [],
+                      include_directories : [src, utils, bench],
+                      cpp_args : [ '-O3', '-march=icelake-client' ],
+                      dependencies : [],
+                      link_whole : [libcpuinfo],
                       )
diff --git a/src/avx512-16bit-common.h b/src/avx512-16bit-common.h
new file mode 100644
index 00000000..7ab22123
--- /dev/null
+++ b/src/avx512-16bit-common.h
@@ -0,0 +1,292 @@
+/*******************************************************************
+ * Copyright (C) 2022 Intel Corporation
+ * SPDX-License-Identifier: BSD-3-Clause
+ * Authors: Raghuveer Devulapalli
+ * ****************************************************************/
+
+#ifndef AVX512_16BIT_COMMON
+#define AVX512_16BIT_COMMON
+
+#include "avx512-common-qsort.h"
+
+/*
+ * Constants used in sorting 32 elements in a ZMM registers. Based on Bitonic
+ * sorting network (see
+ * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
+ */
+// ZMM register: 31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0
+static const uint16_t network[6][32]
+        = {{7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8,
+            23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24},
+           {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
+            31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16},
+           {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11,
+            20, 21, 22, 23, 16, 17, 18, 19, 28, 29, 30, 31, 24, 25, 26, 27},
+           {31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
+            15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
+           {8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7,
+            24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23},
+           {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}};
+
+/*
+ * Assumes zmm is random and performs a full sorting network defined in
+ * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
+ */
+template <typename vtype, typename zmm_t = typename vtype::zmm_t>
+X86_SIMD_SORT_INLINE zmm_t sort_zmm_16bit(zmm_t zmm)
+{
+    // Level 1
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
+            0xAAAAAAAA);
+    // Level 2
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(0, 1, 2, 3)>(zmm),
+            0xCCCCCCCC);
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
+            0xAAAAAAAA);
+    // Level 3
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(1), zmm), 0xF0F0F0F0);
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
+            0xCCCCCCCC);
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
+            0xAAAAAAAA);
+    // Level 4
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(2), zmm), 0xFF00FF00);
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
+            0xCCCCCCCC);
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
+            0xAAAAAAAA);
+    // Level 5
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(4), zmm), 0xFFFF0000);
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(5), zmm), 0xFF00FF00);
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
+            0xCCCCCCCC);
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
+            0xAAAAAAAA);
+    return zmm;
+}
+
+// Assumes zmm is bitonic and performs a recursive half cleaner
+template <typename vtype, typename zmm_t = typename vtype::zmm_t>
+X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
+{
+    // 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc ..
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(6), zmm), 0xFFFF0000);
+    // 2) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc ..
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(5), zmm), 0xFF00FF00);
+    // 3) half_cleaner[8]
+    zmm = cmp_merge<vtype>(
+            zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
+    // 3) half_cleaner[4]
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
+            0xCCCCCCCC);
+    // 3) half_cleaner[2]
+    zmm = cmp_merge<vtype>(
+            zmm,
+            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
+            0xAAAAAAAA);
+    return zmm;
+}
+
+// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner
+template <typename vtype, typename zmm_t = typename vtype::zmm_t>
+X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1, zmm_t &zmm2)
+{
+    // 1) First step of a merging network: coex of zmm1 and zmm2 reversed
+    zmm2 = vtype::permutexvar(vtype::get_network(4), zmm2);
+    zmm_t zmm3 = vtype::min(zmm1, zmm2);
+    zmm_t zmm4 = vtype::max(zmm1, zmm2);
+    // 2) Recursive half cleaner for each
+    zmm1 = bitonic_merge_zmm_16bit<vtype>(zmm3);
+    zmm2 = bitonic_merge_zmm_16bit<vtype>(zmm4);
+}
+
+// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive
+// half cleaner
+template <typename vtype, typename zmm_t = typename vtype::zmm_t>
+X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
+{
+    zmm_t zmm2r = vtype::permutexvar(vtype::get_network(4), zmm[2]);
+    zmm_t zmm3r = vtype::permutexvar(vtype::get_network(4), zmm[3]);
+    zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r);
+    zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r);
+    zmm_t zmm_t3 = vtype::permutexvar(vtype::get_network(4),
+                                      vtype::max(zmm[1], zmm2r));
+    zmm_t zmm_t4 = vtype::permutexvar(vtype::get_network(4),
+                                      vtype::max(zmm[0], zmm3r));
+    zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2);
+    zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2);
+    zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4);
+    zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4);
+    zmm[0] = bitonic_merge_zmm_16bit<vtype>(zmm0);
+    zmm[1] = bitonic_merge_zmm_16bit<vtype>(zmm1);
+    zmm[2] = bitonic_merge_zmm_16bit<vtype>(zmm2);
+    zmm[3] = bitonic_merge_zmm_16bit<vtype>(zmm3);
+}
+
+template <typename vtype, typename type_t>
+X86_SIMD_SORT_INLINE void sort_32_16bit(type_t *arr, int32_t N)
+{
+    typename vtype::opmask_t load_mask = ((0x1ull << N) - 0x1ull) & 0xFFFFFFFF;
+    typename vtype::zmm_t zmm
+            = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr);
+    vtype::mask_storeu(arr, load_mask, sort_zmm_16bit<vtype>(zmm));
+}
+
+template <typename vtype, typename type_t>
+X86_SIMD_SORT_INLINE void sort_64_16bit(type_t *arr, int32_t N)
+{
+    if (N <= 32) {
+        sort_32_16bit<vtype>(arr, N);
+        return;
+    }
+    using zmm_t = typename vtype::zmm_t;
+    typename vtype::opmask_t load_mask
+            = ((0x1ull << (N - 32)) - 0x1ull) & 0xFFFFFFFF;
+    zmm_t zmm1 = vtype::loadu(arr);
+    zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 32);
+    zmm1 = sort_zmm_16bit<vtype>(zmm1);
+    zmm2 = sort_zmm_16bit<vtype>(zmm2);
+    bitonic_merge_two_zmm_16bit<vtype>(zmm1, zmm2);
+    vtype::storeu(arr, zmm1);
+    vtype::mask_storeu(arr + 32, load_mask, zmm2);
+}
+
+template <typename vtype, typename type_t>
+X86_SIMD_SORT_INLINE void sort_128_16bit(type_t *arr, int32_t N)
+{
+    if (N <= 64) {
+        sort_64_16bit<vtype>(arr, N);
+        return;
+    }
+    using zmm_t = typename vtype::zmm_t;
+    using opmask_t = typename vtype::opmask_t;
+    zmm_t zmm[4];
+    zmm[0] = vtype::loadu(arr);
+    zmm[1] = vtype::loadu(arr + 32);
+    opmask_t load_mask1 = 0xFFFFFFFF, load_mask2 = 0xFFFFFFFF;
+    if (N != 128) {
+        uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull;
+        load_mask1 = combined_mask & 0xFFFFFFFF;
+        load_mask2 = (combined_mask >> 32) & 0xFFFFFFFF;
+    }
+    zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64);
+    zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 96);
+    zmm[0] = sort_zmm_16bit<vtype>(zmm[0]);
+    zmm[1] = sort_zmm_16bit<vtype>(zmm[1]);
+    zmm[2] = sort_zmm_16bit<vtype>(zmm[2]);
+    zmm[3] = sort_zmm_16bit<vtype>(zmm[3]);
+    bitonic_merge_two_zmm_16bit<vtype>(zmm[0], zmm[1]);
+    bitonic_merge_two_zmm_16bit<vtype>(zmm[2], zmm[3]);
+    bitonic_merge_four_zmm_16bit<vtype>(zmm);
+    vtype::storeu(arr, zmm[0]);
+    vtype::storeu(arr + 32, zmm[1]);
+    vtype::mask_storeu(arr + 64, load_mask1, zmm[2]);
+    vtype::mask_storeu(arr + 96, load_mask2, zmm[3]);
+}
+
+template <typename vtype, typename type_t>
+X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr,
+                                            const int64_t left,
+                                            const int64_t right)
+{
+    // median of 32
+    int64_t size = (right - left) / 32;
+    type_t vec_arr[32] = {arr[left],
+                          arr[left + size],
+                          arr[left + 2 * size],
+                          arr[left + 3 * size],
+                          arr[left + 4 * size],
+                          arr[left + 5 * size],
+                          arr[left + 6 * size],
+                          arr[left + 7 * size],
+                          arr[left + 8 * size],
+                          arr[left + 9 * size],
+                          arr[left + 10 * size],
+                          arr[left + 11 * size],
+                          arr[left + 12 * size],
+                          arr[left + 13 * size],
+                          arr[left + 14 * size],
+                          arr[left + 15 * size],
+                          arr[left + 16 * size],
+                          arr[left + 17 * size],
+                          arr[left + 18 * size],
+                          arr[left + 19 * size],
+                          arr[left + 20 * size],
+                          arr[left + 21 * size],
+                          arr[left + 22 * size],
+                          arr[left + 23 * size],
+                          arr[left + 24 * size],
+                          arr[left + 25 * size],
+                          arr[left + 26 * size],
+                          arr[left + 27 * size],
+                          arr[left + 28 * size],
+                          arr[left + 29 * size],
+                          arr[left + 30 * size],
+                          arr[left + 31 * size]};
+    typename vtype::zmm_t rand_vec = vtype::loadu(vec_arr);
+    typename vtype::zmm_t sort = sort_zmm_16bit<vtype>(rand_vec);
+    return ((type_t *)&sort)[16];
+}
+
+template <typename vtype, typename type_t>
+static void
+qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
+{
+    /*
+     * Resort to std::sort if quicksort isnt making any progress
+     */
+    if (max_iters <= 0) {
+        std::sort(arr + left, arr + right + 1, comparison_func<vtype>);
+        return;
+    }
+    /*
+     * Base case: use bitonic networks to sort arrays <= 128
+     */
+    if (right + 1 - left <= 128) {
+        sort_128_16bit<vtype>(arr + left, (int32_t)(right + 1 - left));
+        return;
+    }
+
+    type_t pivot = get_pivot_16bit<vtype>(arr, left, right);
+    type_t smallest = vtype::type_max();
+    type_t biggest = vtype::type_min();
+    int64_t pivot_index = partition_avx512<vtype>(
+            arr, left, right + 1, pivot, &smallest, &biggest);
+    if (pivot != smallest)
+        qsort_16bit_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
+    if (pivot != biggest)
+        qsort_16bit_<vtype>(arr, pivot_index, right, max_iters - 1);
+}
+
+#endif // AVX512_16BIT_COMMON
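Note on the cmp_merge<vtype>(zmm, permuted, mask) pattern used throughout the new header: each level of the network pairs every lane with a partner lane (chosen either by an in-lane shuffle or by permutexvar with one of the network[] rows) and keeps the minimum in one lane of the pair and the maximum in the other, as selected by the bit mask. A scalar sketch for intuition only — cmp_merge_scalar is a hypothetical name, not part of this patch:

    #include <algorithm>
    #include <array>
    #include <cstdint>

    // One compare-exchange level over 32 lanes: lane i is paired with lane
    // perm[i]; lanes whose mask bit is set keep the max of the pair, the
    // rest keep the min. This mirrors cmp_merge(zmm, permuted, mask).
    static std::array<uint16_t, 32> cmp_merge_scalar(
            const std::array<uint16_t, 32> &v,
            const uint16_t *perm,
            uint32_t mask)
    {
        std::array<uint16_t, 32> out {};
        for (int i = 0; i < 32; ++i) {
            uint16_t a = v[i], b = v[perm[i]];
            out[i] = ((mask >> i) & 1) ? std::max(a, b) : std::min(a, b);
        }
        return out;
    }

For example, cmp_merge_scalar(v, network[3], 0xFFFF0000) reproduces the full-reverse step at the start of Level 5 of sort_zmm_16bit, assuming v holds the 32 lanes in index order.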
diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp
index b7130e2f..fcbaf879 100644
--- a/src/avx512-16bit-qsort.hpp
+++ b/src/avx512-16bit-qsort.hpp
@@ -7,27 +7,7 @@
 #ifndef AVX512_QSORT_16BIT
 #define AVX512_QSORT_16BIT
 
-#include "avx512-common-qsort.h"
-
-/*
- * Constants used in sorting 32 elements in a ZMM registers. Based on Bitonic
- * sorting network (see
- * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
- */
-// ZMM register: 31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0
-static const uint16_t network[6][32]
-        = {{7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8,
-            23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24},
-           {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
-            31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16},
-           {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11,
-            20, 21, 22, 23, 16, 17, 18, 19, 28, 29, 30, 31, 24, 25, 26, 27},
-           {31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
-            15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
-           {8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7,
-            24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23},
-           {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
-            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}};
+#include "avx512-16bit-common.h"
 
 struct float16 {
     uint16_t val;
@@ -369,236 +349,6 @@ struct zmm_vector<int16_t> {
     }
 };
 
-/*
- * Assumes zmm is random and performs a full sorting network defined in
- * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
- */
-template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_INLINE zmm_t sort_zmm_16bit(zmm_t zmm)
-{
-    // Level 1
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAAAAAA);
-    // Level 2
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(0, 1, 2, 3)>(zmm),
-            0xCCCCCCCC);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAAAAAA);
-    // Level 3
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(1), zmm), 0xF0F0F0F0);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
-            0xCCCCCCCC);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAAAAAA);
-    // Level 4
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(2), zmm), 0xFF00FF00);
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
-            0xCCCCCCCC);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAAAAAA);
-    // Level 5
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(4), zmm), 0xFFFF0000);
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(5), zmm), 0xFF00FF00);
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
-            0xCCCCCCCC);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAAAAAA);
-    return zmm;
-}
-
-// Assumes zmm is bitonic and performs a recursive half cleaner
-template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
-{
-    // 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc ..
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(6), zmm), 0xFFFF0000);
-    // 2) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc ..
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(5), zmm), 0xFF00FF00);
-    // 3) half_cleaner[8]
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
-    // 3) half_cleaner[4]
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
-            0xCCCCCCCC);
-    // 3) half_cleaner[2]
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAAAAAA);
-    return zmm;
-}
-
-// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner
-template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1, zmm_t &zmm2)
-{
-    // 1) First step of a merging network: coex of zmm1 and zmm2 reversed
-    zmm2 = vtype::permutexvar(vtype::get_network(4), zmm2);
-    zmm_t zmm3 = vtype::min(zmm1, zmm2);
-    zmm_t zmm4 = vtype::max(zmm1, zmm2);
-    // 2) Recursive half cleaner for each
-    zmm1 = bitonic_merge_zmm_16bit<vtype>(zmm3);
-    zmm2 = bitonic_merge_zmm_16bit<vtype>(zmm4);
-}
-
-// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive
-// half cleaner
-template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
-{
-    zmm_t zmm2r = vtype::permutexvar(vtype::get_network(4), zmm[2]);
-    zmm_t zmm3r = vtype::permutexvar(vtype::get_network(4), zmm[3]);
-    zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r);
-    zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r);
-    zmm_t zmm_t3 = vtype::permutexvar(vtype::get_network(4),
-                                      vtype::max(zmm[1], zmm2r));
-    zmm_t zmm_t4 = vtype::permutexvar(vtype::get_network(4),
-                                      vtype::max(zmm[0], zmm3r));
-    zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2);
-    zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2);
-    zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4);
-    zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4);
-    zmm[0] = bitonic_merge_zmm_16bit<vtype>(zmm0);
-    zmm[1] = bitonic_merge_zmm_16bit<vtype>(zmm1);
-    zmm[2] = bitonic_merge_zmm_16bit<vtype>(zmm2);
-    zmm[3] = bitonic_merge_zmm_16bit<vtype>(zmm3);
-}
-
-template <typename vtype, typename type_t>
-X86_SIMD_SORT_INLINE void sort_32_16bit(type_t *arr, int32_t N)
-{
-    typename vtype::opmask_t load_mask = ((0x1ull << N) - 0x1ull) & 0xFFFFFFFF;
-    typename vtype::zmm_t zmm
-            = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr);
-    vtype::mask_storeu(arr, load_mask, sort_zmm_16bit<vtype>(zmm));
-}
-
-template <typename vtype, typename type_t>
-X86_SIMD_SORT_INLINE void sort_64_16bit(type_t *arr, int32_t N)
-{
-    if (N <= 32) {
-        sort_32_16bit<vtype>(arr, N);
-        return;
-    }
-    using zmm_t = typename vtype::zmm_t;
-    typename vtype::opmask_t load_mask
-            = ((0x1ull << (N - 32)) - 0x1ull) & 0xFFFFFFFF;
-    zmm_t zmm1 = vtype::loadu(arr);
-    zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 32);
-    zmm1 = sort_zmm_16bit<vtype>(zmm1);
-    zmm2 = sort_zmm_16bit<vtype>(zmm2);
-    bitonic_merge_two_zmm_16bit<vtype>(zmm1, zmm2);
-    vtype::storeu(arr, zmm1);
-    vtype::mask_storeu(arr + 32, load_mask, zmm2);
-}
-
-template <typename vtype, typename type_t>
-X86_SIMD_SORT_INLINE void sort_128_16bit(type_t *arr, int32_t N)
-{
-    if (N <= 64) {
-        sort_64_16bit<vtype>(arr, N);
-        return;
-    }
-    using zmm_t = typename vtype::zmm_t;
-    using opmask_t = typename vtype::opmask_t;
-    zmm_t zmm[4];
-    zmm[0] = vtype::loadu(arr);
-    zmm[1] = vtype::loadu(arr + 32);
-    opmask_t load_mask1 = 0xFFFFFFFF, load_mask2 = 0xFFFFFFFF;
-    if (N != 128) {
-        uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull;
-        load_mask1 = combined_mask & 0xFFFFFFFF;
-        load_mask2 = (combined_mask >> 32) & 0xFFFFFFFF;
-    }
-    zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64);
-    zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 96);
-    zmm[0] = sort_zmm_16bit<vtype>(zmm[0]);
-    zmm[1] = sort_zmm_16bit<vtype>(zmm[1]);
-    zmm[2] = sort_zmm_16bit<vtype>(zmm[2]);
-    zmm[3] = sort_zmm_16bit<vtype>(zmm[3]);
-    bitonic_merge_two_zmm_16bit<vtype>(zmm[0], zmm[1]);
-    bitonic_merge_two_zmm_16bit<vtype>(zmm[2], zmm[3]);
-    bitonic_merge_four_zmm_16bit<vtype>(zmm);
-    vtype::storeu(arr, zmm[0]);
-    vtype::storeu(arr + 32, zmm[1]);
-    vtype::mask_storeu(arr + 64, load_mask1, zmm[2]);
-    vtype::mask_storeu(arr + 96, load_mask2, zmm[3]);
-}
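Note on the code this diff relocates into avx512-16bit-common.h: get_pivot_16bit samples 32 evenly spaced elements, sorts them in-register with sort_zmm_16bit, and uses the middle lane as the partition pivot. A scalar equivalent of that median-of-32 estimate (a sketch using std::nth_element in place of the vector network; median_of_32 is a hypothetical name):

    #include <algorithm>
    #include <cstdint>

    // Median-of-32 pivot estimate, scalar version of get_pivot_16bit:
    // sample 32 evenly spaced elements of arr[left..right] and return the
    // element that lands at index 16 of the sorted sample.
    static uint16_t median_of_32(const uint16_t *arr, int64_t left, int64_t right)
    {
        int64_t stride = (right - left) / 32;
        uint16_t sample[32];
        for (int i = 0; i < 32; ++i)
            sample[i] = arr[left + i * stride];
        std::nth_element(sample, sample + 16, sample + 32);
        return sample[16];
    }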
-
-template <typename vtype, typename type_t>
-X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr,
-                                            const int64_t left,
-                                            const int64_t right)
-{
-    // median of 32
-    int64_t size = (right - left) / 32;
-    type_t vec_arr[32] = {arr[left],
-                          arr[left + size],
-                          arr[left + 2 * size],
-                          arr[left + 3 * size],
-                          arr[left + 4 * size],
-                          arr[left + 5 * size],
-                          arr[left + 6 * size],
-                          arr[left + 7 * size],
-                          arr[left + 8 * size],
-                          arr[left + 9 * size],
-                          arr[left + 10 * size],
-                          arr[left + 11 * size],
-                          arr[left + 12 * size],
-                          arr[left + 13 * size],
-                          arr[left + 14 * size],
-                          arr[left + 15 * size],
-                          arr[left + 16 * size],
-                          arr[left + 17 * size],
-                          arr[left + 18 * size],
-                          arr[left + 19 * size],
-                          arr[left + 20 * size],
-                          arr[left + 21 * size],
-                          arr[left + 22 * size],
-                          arr[left + 23 * size],
-                          arr[left + 24 * size],
-                          arr[left + 25 * size],
-                          arr[left + 26 * size],
-                          arr[left + 27 * size],
-                          arr[left + 28 * size],
-                          arr[left + 29 * size],
-                          arr[left + 30 * size],
-                          arr[left + 31 * size]};
-    __m512i rand_vec = _mm512_loadu_si512(vec_arr);
-    __m512i sort = sort_zmm_16bit<vtype>(rand_vec);
-    return ((type_t *)&sort)[16];
-}
-
 template <>
 bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
 {
@@ -627,36 +377,6 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
     //return npy_half_to_float(a) < npy_half_to_float(b);
 }
 
-template <typename vtype, typename type_t>
-static void
-qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
-{
-    /*
-     * Resort to std::sort if quicksort isnt making any progress
-     */
-    if (max_iters <= 0) {
-        std::sort(arr + left, arr + right + 1, comparison_func<vtype>);
-        return;
-    }
-    /*
-     * Base case: use bitonic networks to sort arrays <= 128
-     */
-    if (right + 1 - left <= 128) {
-        sort_128_16bit<vtype>(arr + left, (int32_t)(right + 1 - left));
-        return;
-    }
-
-    type_t pivot = get_pivot_16bit<vtype>(arr, left, right);
-    type_t smallest = vtype::type_max();
-    type_t biggest = vtype::type_min();
-    int64_t pivot_index = partition_avx512<vtype>(
-            arr, left, right + 1, pivot, &smallest, &biggest);
-    if (pivot != smallest)
-        qsort_16bit_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
-    if (pivot != biggest)
-        qsort_16bit_<vtype>(arr, pivot_index, right, max_iters - 1);
-}
-
 X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr,
                                                   int64_t arrsize)
 {
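Aside on the avx512_qsort_fp16(uint16_t *, int64_t) declaration added below: it gives callers that cannot name _Float16 (for example, half-precision data stored as raw uint16_t bit patterns) an entry point with IEEE half-precision ordering. Its definition is not among the hunks shown here; a sketch of what it presumably looks like, reusing the zmm_vector<float16> traits and uint16_t NaN helpers already present in avx512-16bit-qsort.hpp (an assumption, not confirmed by this diff):

    // Hypothetical sketch: sort raw fp16 bit patterns held in uint16_t,
    // mirroring the _Float16 specialization further down in this patch.
    void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
    {
        if (arrsize > 1) {
            int64_t nan_count = replace_nan_with_inf(arr, arrsize);
            qsort_16bit_<zmm_vector<float16>, uint16_t>(
                    arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
            replace_inf_with_nan(arr, arrsize, nan_count);
        }
    }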
diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h
index 1816baf3..a80f2721 100644
--- a/src/avx512-common-qsort.h
+++ b/src/avx512-common-qsort.h
@@ -36,6 +36,7 @@
 #include <algorithm>
 #include <cmath>
 #include <cstdint>
+#include <cstring>
 #include <immintrin.h>
 #include <limits>
 
@@ -58,6 +59,7 @@
 #define ZMM_MAX_FLOAT _mm512_set1_ps(X86_SIMD_SORT_INFINITYF)
 #define ZMM_MAX_UINT _mm512_set1_epi32(X86_SIMD_SORT_MAX_UINT32)
 #define ZMM_MAX_INT _mm512_set1_epi32(X86_SIMD_SORT_MAX_INT32)
+#define ZMM_MAX_HALF _mm512_set1_epi16(X86_SIMD_SORT_INFINITYH)
 #define YMM_MAX_HALF _mm256_set1_epi16(X86_SIMD_SORT_INFINITYH)
 #define ZMM_MAX_UINT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_UINT16)
 #define ZMM_MAX_INT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_INT16)
@@ -87,6 +89,8 @@ struct zmm_vector;
 template <typename T>
 void avx512_qsort(T *arr, int64_t arrsize);
 
+void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize);
+
 template <typename T>
 bool comparison_func(const T &a, const T &b)
 {
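A minimal caller for the new _Float16 path, under the build assumptions this patch sets up (g++-12 with -march=sapphirerapids). Note that compiling for the ISA is not enough — the runtime check is still required — and NaNs come out at the end of the sorted array thanks to the replace_nan_with_inf/replace_inf_with_nan round-trip:

    // example.cpp — sketch; build roughly as:
    //   g++-12 -O3 -march=sapphirerapids -I src -I utils example.cpp utils/cpuinfo.cpp
    #include <cstdint>
    #include <vector>
    #include "avx512fp16-16bit-qsort.hpp"
    #include "cpuinfo.h"

    int main()
    {
        std::vector<_Float16> v(1000);
        for (size_t i = 0; i < v.size(); ++i)
            v[i] = (_Float16)((float)(v.size() - i) * 0.5f); // descending input
        if (cpu_has_avx512fp16()) // runtime ISA check; skip the sort otherwise
            avx512_qsort<_Float16>(v.data(), (int64_t)v.size());
        return 0;
    }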
diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp
new file mode 100644
index 00000000..363d2b55
--- /dev/null
+++ b/src/avx512fp16-16bit-qsort.hpp
@@ -0,0 +1,157 @@
+/*******************************************************************
+ * Copyright (C) 2022 Intel Corporation
+ * SPDX-License-Identifier: BSD-3-Clause
+ * Authors: Raghuveer Devulapalli
+ * ****************************************************************/
+
+#ifndef AVX512FP16_QSORT_16BIT
+#define AVX512FP16_QSORT_16BIT
+
+#include "avx512-16bit-common.h"
+
+typedef union {
+    _Float16 f_;
+    uint16_t i_;
+} Fp16Bits;
+
+template <>
+struct zmm_vector<_Float16> {
+    using type_t = _Float16;
+    using zmm_t = __m512h;
+    using ymm_t = __m256h;
+    using opmask_t = __mmask32;
+    static const uint8_t numlanes = 32;
+
+    static __m512i get_network(int index)
+    {
+        return _mm512_loadu_si512(&network[index - 1][0]);
+    }
+    static type_t type_max()
+    {
+        Fp16Bits val;
+        val.i_ = X86_SIMD_SORT_INFINITYH;
+        return val.f_;
+    }
+    static type_t type_min()
+    {
+        Fp16Bits val;
+        val.i_ = X86_SIMD_SORT_NEGINFINITYH;
+        return val.f_;
+    }
+    static zmm_t zmm_max()
+    {
+        return _mm512_set1_ph(type_max());
+    }
+    static opmask_t knot_opmask(opmask_t x)
+    {
+        return _knot_mask32(x);
+    }
+
+    static opmask_t ge(zmm_t x, zmm_t y)
+    {
+        return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ);
+    }
+    static zmm_t loadu(void const *mem)
+    {
+        return _mm512_loadu_ph(mem);
+    }
+    static zmm_t max(zmm_t x, zmm_t y)
+    {
+        return _mm512_max_ph(x, y);
+    }
+    static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x)
+    {
+        __m512i temp = _mm512_castph_si512(x);
+        // AVX512_VBMI2
+        return _mm512_mask_compressstoreu_epi16(mem, mask, temp);
+    }
+    static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
+    {
+        // AVX512BW
+        return _mm512_castsi512_ph(
+                _mm512_mask_loadu_epi16(_mm512_castph_si512(x), mask, mem));
+    }
+    static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y)
+    {
+        return _mm512_castsi512_ph(_mm512_mask_mov_epi16(
+                _mm512_castph_si512(x), mask, _mm512_castph_si512(y)));
+    }
+    static void mask_storeu(void *mem, opmask_t mask, zmm_t x)
+    {
+        return _mm512_mask_storeu_epi16(mem, mask, _mm512_castph_si512(x));
+    }
+    static zmm_t min(zmm_t x, zmm_t y)
+    {
+        return _mm512_min_ph(x, y);
+    }
+    static zmm_t permutexvar(__m512i idx, zmm_t zmm)
+    {
+        return _mm512_permutexvar_ph(idx, zmm);
+    }
+    static type_t reducemax(zmm_t v)
+    {
+        return _mm512_reduce_max_ph(v);
+    }
+    static type_t reducemin(zmm_t v)
+    {
+        return _mm512_reduce_min_ph(v);
+    }
+    static zmm_t set1(type_t v)
+    {
+        return _mm512_set1_ph(v);
+    }
+    template <int mask>
+    static zmm_t shuffle(zmm_t zmm)
+    {
+        __m512i temp = _mm512_shufflehi_epi16(_mm512_castph_si512(zmm),
+                                              (_MM_PERM_ENUM)mask);
+        return _mm512_castsi512_ph(
+                _mm512_shufflelo_epi16(temp, (_MM_PERM_ENUM)mask));
+    }
+    static void storeu(void *mem, zmm_t x)
+    {
+        return _mm512_storeu_ph(mem, x);
+    }
+};
+
+X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(_Float16 *arr,
+                                                  int64_t arrsize)
+{
+    int64_t nan_count = 0;
+    __mmask32 loadmask = 0xFFFFFFFF;
+    __m512h in_zmm;
+    while (arrsize > 0) {
+        if (arrsize < 32) {
+            loadmask = (0x00000001 << arrsize) - 0x00000001;
+            in_zmm = _mm512_castsi512_ph(
+                    _mm512_maskz_loadu_epi16(loadmask, arr));
+        }
+        else {
+            in_zmm = _mm512_loadu_ph(arr);
+        }
+        __mmask32 nanmask = _mm512_cmp_ph_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
+        nan_count += _mm_popcnt_u32((int32_t)nanmask);
+        _mm512_mask_storeu_epi16(arr, nanmask, ZMM_MAX_HALF);
+        arr += 32;
+        arrsize -= 32;
+    }
+    return nan_count;
+}
+
+X86_SIMD_SORT_INLINE void
+replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
+{
+    memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
+}
+
+template <>
+void avx512_qsort(_Float16 *arr, int64_t arrsize)
+{
+    if (arrsize > 1) {
+        int64_t nan_count = replace_nan_with_inf(arr, arrsize);
+        qsort_16bit_<zmm_vector<_Float16>, _Float16>(
+                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+        replace_inf_with_nan(arr, arrsize, nan_count);
+    }
+}
+#endif // AVX512FP16_QSORT_16BIT
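With gtest_main linked in (see the pkg-config and dependency() changes above), GoogleTest supplies main(), which is why tests/main.cpp is deleted below and each test file reduces to a set of TEST cases compiled with its own -march flag. For illustration, a test translation unit now needs nothing beyond:

    #include <gtest/gtest.h>

    // No main() anywhere in the test sources: gtest_main registers and runs this.
    TEST(example, links_against_gtest_main)
    {
        EXPECT_EQ(1 + 1, 2);
    }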
SPDX-License-Identifier: BSD-3-Clause
+ * *******************************************/
+
+#include "avx512fp16-16bit-qsort.hpp"
+#include "cpuinfo.h"
+#include "rand_array.h"
+#include <gtest/gtest.h>
+#include <vector>
+
+TEST(avx512_qsort_float16, test_arrsizes)
+{
+    if (cpu_has_avx512fp16()) {
+        std::vector<int64_t> arrsizes;
+        for (int64_t ii = 0; ii < 1024; ++ii) {
+            arrsizes.push_back(ii);
+        }
+        std::vector<_Float16> arr;
+        std::vector<_Float16> sortedarr;
+
+        for (size_t ii = 0; ii < arrsizes.size(); ++ii) {
+            /* Random array */
+            for (size_t jj = 0; jj < arrsizes[ii]; ++jj) {
+                _Float16 temp = (float)rand() / (float)(RAND_MAX);
+                arr.push_back(temp);
+                sortedarr.push_back(temp);
+            }
+            /* Sort with std::sort for comparison */
+            std::sort(sortedarr.begin(), sortedarr.end());
+            avx512_qsort<_Float16>(arr.data(), arr.size());
+            ASSERT_EQ(sortedarr, arr);
+            arr.clear();
+            sortedarr.clear();
+        }
+    }
+    else {
+        GTEST_SKIP() << "Skipping this test, it requires avx512fp16 ISA";
+    }
+}
+
+TEST(avx512_qsort_float16, test_special_floats)
+{
+    if (cpu_has_avx512fp16()) {
+        const int arrsize = 1111;
+        std::vector<_Float16> arr;
+        std::vector<_Float16> sortedarr;
+        Fp16Bits temp;
+        for (size_t jj = 0; jj < arrsize; ++jj) {
+            temp.f_ = (float)rand() / (float)(RAND_MAX);
+            switch (rand() % 10) {
+                case 0: temp.i_ = 0xFFFF; break;
+                case 1: temp.i_ = X86_SIMD_SORT_INFINITYH; break;
+                case 2: temp.i_ = X86_SIMD_SORT_NEGINFINITYH; break;
+                default: break;
+            }
+            arr.push_back(temp.f_);
+            sortedarr.push_back(temp.f_);
+        }
+        /* Cannot use std::sort because it treats NAN differently */
+        avx512_qsort_fp16(reinterpret_cast<uint16_t *>(sortedarr.data()),
+                          sortedarr.size());
+        avx512_qsort<_Float16>(arr.data(), arr.size());
+        // Cannot rely on ASSERT_EQ since it returns false if there are NAN's
+        if (memcmp(arr.data(), sortedarr.data(), arrsize * 2) != 0) {
+            ASSERT_EQ(sortedarr, arr);
+        }
+        arr.clear();
+        sortedarr.clear();
+    }
+    else {
+        GTEST_SKIP() << "Skipping this test, it requires avx512fp16 ISA";
+    }
+}
diff --git a/utils/cpuinfo.cpp b/utils/cpuinfo.cpp
index 722b82c4..c05acf34 100644
--- a/utils/cpuinfo.cpp
+++ b/utils/cpuinfo.cpp
@@ -29,6 +29,13 @@ int cpu_has_avx512bw()
     return (ebx >> 30) & 0x1;
 }
 
+int cpu_has_avx512fp16()
+{
+    uint32_t eax(0), ebx(0), ecx(0), edx(0);
+    cpuid(0x07, &eax, &ebx, &ecx, &edx);
+    return (edx >> 23) & 0x1;
+}
+
 // TODO:
 //int check_os_supports_avx512()
 //{
diff --git a/utils/cpuinfo.h b/utils/cpuinfo.h
index 1049a343..96f167e5 100644
--- a/utils/cpuinfo.h
+++ b/utils/cpuinfo.h
@@ -3,9 +3,11 @@
  * * SPDX-License-Identifier: BSD-3-Clause
  * *******************************************/
 
-#include <stdint.h>
 #include <cpuid.h>
+#include <stdint.h>
 
 int cpu_has_avx512_vbmi2();
 int cpu_has_avx512bw();
+
+int cpu_has_avx512fp16();
diff --git a/utils/rand_array.h b/utils/rand_array.h
index 01163582..24a03ddd 100644
--- a/utils/rand_array.h
+++ b/utils/rand_array.h
@@ -43,14 +43,14 @@ static std::vector<T> get_uniform_rand_array(
 }
 
 template <typename T>
-static std::vector<T> get_uniform_rand_array_with_uniquevalues(
-        int64_t arrsize,
-        T max = std::numeric_limits<T>::max(),
-        T min = std::numeric_limits<T>::min())
+static std::vector<T>
+get_uniform_rand_array_with_uniquevalues(int64_t arrsize,
+                                         T max = std::numeric_limits<T>::max(),
+                                         T min = std::numeric_limits<T>::min())
 {
     std::vector<T> arr = get_uniform_rand_array(arrsize, max, min);
-    typename std::vector<T>::iterator ip =
-            std::unique(arr.begin(), arr.begin() + arrsize);
+    typename std::vector<T>::iterator ip
+            = std::unique(arr.begin(), arr.begin() + arrsize);
     arr.resize(std::distance(arr.begin(), ip));
     return arr;
 }
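For reference, the new cpu_has_avx512fp16() probes CPUID leaf 0x07 (sub-leaf 0), where EDX bit 23 reports AVX512-FP16 support. A standalone check with the same logic, using GCC's <cpuid.h> directly rather than this repo's cpuid wrapper:

    #include <cpuid.h>
    #include <cstdio>

    int main()
    {
        unsigned int eax = 0, ebx = 0, ecx = 0, edx = 0;
        // CPUID leaf 7, sub-leaf 0: EDX bit 23 == AVX512-FP16
        __get_cpuid_count(0x07, 0, &eax, &ebx, &ecx, &edx);
        std::printf("AVX512-FP16: %s\n", ((edx >> 23) & 1) ? "yes" : "no");
        return 0;
    }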