diff --git a/bitshuffle/ext.pyx b/bitshuffle/ext.pyx
index a89a0b3e..2d4cc4c3 100644
--- a/bitshuffle/ext.pyx
+++ b/bitshuffle/ext.pyx
@@ -76,9 +76,11 @@ cdef extern int bshuf_shuffle_bit_eightelem_scal(void *A, void *B, int size, int
 cdef extern int bshuf_shuffle_bit_eightelem_SSE(void *A, void *B, int size, int elem_size)
 cdef extern int bshuf_shuffle_bit_eightelem_NEON(void *A, void *B, int size, int elem_size)
 cdef extern int bshuf_shuffle_bit_eightelem_AVX(void *A, void *B, int size, int elem_size)
+cdef extern int bshuf_shuffle_bit_eightelem_AVX512(void *A, void *B, int size, int elem_size)
 cdef extern int bshuf_untrans_bit_elem_SSE(void *A, void *B, int size, int elem_size)
 cdef extern int bshuf_untrans_bit_elem_NEON(void *A, void *B, int size, int elem_size)
 cdef extern int bshuf_untrans_bit_elem_AVX(void *A, void *B, int size, int elem_size)
+cdef extern int bshuf_untrans_bit_elem_AVX512(void *A, void *B, int size, int elem_size)
 cdef extern int bshuf_untrans_bit_elem_scal(void *A, void *B, int size, int elem_size)
 cdef extern int bshuf_trans_bit_elem(void *A, void *B, int size, int elem_size)
 cdef extern int bshuf_untrans_bit_elem(void *A, void *B, int size, int elem_size)
@@ -259,6 +261,10 @@ def shuffle_bit_eightelem_AVX(np.ndarray arr not None):
     return _wrap_C_fun(&bshuf_shuffle_bit_eightelem_AVX, arr)
 
 
+def shuffle_bit_eightelem_AVX512(np.ndarray arr not None):
+    return _wrap_C_fun(&bshuf_shuffle_bit_eightelem_AVX512, arr)
+
+
 def untrans_bit_elem_SSE(np.ndarray arr not None):
     return _wrap_C_fun(&bshuf_untrans_bit_elem_SSE, arr)
 
@@ -271,6 +277,10 @@ def untrans_bit_elem_AVX(np.ndarray arr not None):
     return _wrap_C_fun(&bshuf_untrans_bit_elem_AVX, arr)
 
 
+def untrans_bit_elem_AVX512(np.ndarray arr not None):
+    return _wrap_C_fun(&bshuf_untrans_bit_elem_AVX512, arr)
+
+
 def untrans_bit_elem_scal(np.ndarray arr not None):
     return _wrap_C_fun(&bshuf_untrans_bit_elem_scal, arr)
 
diff --git a/src/bitshuffle_core.c b/src/bitshuffle_core.c
index 3ec3047e..ba41473f 100644
--- a/src/bitshuffle_core.c
+++ b/src/bitshuffle_core.c
@@ -1384,99 +1384,6 @@ int64_t bshuf_shuffle_bit_eightelem_SSE(const void* in, void* out, const size_t
 
 #endif // #ifdef USESSE2
 
-#ifdef USEAVX512
-
-/* Transpose bits within bytes. */
-int64_t bshuf_trans_bit_byte_AVX512(const void* in, void* out, const size_t size,
-        const size_t elem_size) {
-
-    size_t ii, kk;
-    const char* in_b = (const char*) in;
-    char* out_b = (char*) out;
-    size_t nbyte = elem_size * size;
-    int64_t count;
-
-    int64_t* out_i64;
-    __m512i zmm;
-    __mmask64 bt;
-    if (nbyte >= 64) {
-        const __m512i mask = _mm512_set_epi8(0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-                                             0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-                                             0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-                                             0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0);
-
-        for (ii = 0; ii + 63 < nbyte; ii += 64) {
-            zmm = _mm512_loadu_si512((__m512i *) &in_b[ii]);
-            for (kk = 0; kk < 8; kk++) {
-                bt = _mm512_cmp_epi8_mask(zmm, mask, 1);
-                zmm = _mm512_slli_epi16(zmm, 1);
-                out_i64 = (int64_t*) &out_b[((7 - kk) * nbyte + ii) / 8];
-                *out_i64 = (int64_t)bt;
-            }
-        }
-    }
-
-    __m256i ymm;
-    int32_t bt32;
-    int32_t* out_i32;
-    size_t start = nbyte - nbyte % 64;
-    for (ii = start; ii + 31 < nbyte; ii += 32) {
-        ymm = _mm256_loadu_si256((__m256i *) &in_b[ii]);
-        for (kk = 0; kk < 8; kk++) {
-            bt32 = _mm256_movemask_epi8(ymm);
-            ymm = _mm256_slli_epi16(ymm, 1);
-            out_i32 = (int32_t*) &out_b[((7 - kk) * nbyte + ii) / 8];
-            *out_i32 = bt32;
-        }
-    }
-
-
-    count = bshuf_trans_bit_byte_remainder(in, out, size, elem_size,
-            nbyte - nbyte % 64 % 32);
-
-    return count;
-}
-
-
-/* Transpose bits within elements. */
-int64_t bshuf_trans_bit_elem_AVX512(const void* in, void* out, const size_t size,
-        const size_t elem_size) {
-
-    int64_t count;
-
-    CHECK_MULT_EIGHT(size);
-
-    void* tmp_buf = malloc(size * elem_size);
-    if (tmp_buf == NULL) return -1;
-
-    count = bshuf_trans_byte_elem_SSE(in, out, size, elem_size);
-    CHECK_ERR_FREE(count, tmp_buf);
-    count = bshuf_trans_bit_byte_AVX512(out, tmp_buf, size, elem_size);
-    CHECK_ERR_FREE(count, tmp_buf);
-    count = bshuf_trans_bitrow_eight(tmp_buf, out, size, elem_size);
-
-    free(tmp_buf);
-
-    return count;
-
-}
-
-#else // #ifdef USEAVX512
-
-int64_t bshuf_trans_bit_byte_AVX512(const void* in, void* out, const size_t size,
-        const size_t elem_size) {
-
-    return -14;
-}
-
-int64_t bshuf_trans_bit_elem_AVX512(const void* in, void* out, const size_t size,
-        const size_t elem_size) {
-    return -14;
-
-}
-
-
-#endif
 /* ---- Code that requires AVX2. Intel Haswell (2013) and later. ---- */
 
 /* ---- Worker code that uses AVX2 ----
@@ -1729,6 +1636,162 @@ int64_t bshuf_untrans_bit_elem_AVX(const void* in, void* out, const size_t size,
 
 #endif // #ifdef USEAVX2
 
+#ifdef USEAVX512
+
+/* Transpose bits within bytes. */
+int64_t bshuf_trans_bit_byte_AVX512(const void* in, void* out, const size_t size,
+        const size_t elem_size) {
+
+    size_t ii, kk;
+    const char* in_b = (const char*) in;
+    char* out_b = (char*) out;
+    size_t nbyte = elem_size * size;
+    int64_t count;
+
+    int64_t* out_i64;
+    __m512i zmm;
+    __mmask64 bt;
+    if (nbyte >= 64) {
+        const __m512i mask = _mm512_set1_epi8(0);
+
+        for (ii = 0; ii + 63 < nbyte; ii += 64) {
+            zmm = _mm512_loadu_si512((__m512i *) &in_b[ii]);
+            for (kk = 0; kk < 8; kk++) {
+                bt = _mm512_cmp_epi8_mask(zmm, mask, 1);
+                zmm = _mm512_slli_epi16(zmm, 1);
+                out_i64 = (int64_t*) &out_b[((7 - kk) * nbyte + ii) / 8];
+                *out_i64 = (int64_t)bt;
+            }
+        }
+    }
+
+    __m256i ymm;
+    int32_t bt32;
+    int32_t* out_i32;
+    size_t start = nbyte - nbyte % 64;
+    for (ii = start; ii + 31 < nbyte; ii += 32) {
+        ymm = _mm256_loadu_si256((__m256i *) &in_b[ii]);
+        for (kk = 0; kk < 8; kk++) {
+            bt32 = _mm256_movemask_epi8(ymm);
+            ymm = _mm256_slli_epi16(ymm, 1);
+            out_i32 = (int32_t*) &out_b[((7 - kk) * nbyte + ii) / 8];
+            *out_i32 = bt32;
+        }
+    }
+
+
+    count = bshuf_trans_bit_byte_remainder(in, out, size, elem_size,
+            nbyte - nbyte % 64 % 32);
+
+    return count;
+}
+
+
+/* Transpose bits within elements. */
+int64_t bshuf_trans_bit_elem_AVX512(const void* in, void* out, const size_t size,
+        const size_t elem_size) {
+
+    int64_t count;
+
+    CHECK_MULT_EIGHT(size);
+
+    void* tmp_buf = malloc(size * elem_size);
+    if (tmp_buf == NULL) return -1;
+
+    count = bshuf_trans_byte_elem_SSE(in, out, size, elem_size);
+    CHECK_ERR_FREE(count, tmp_buf);
+    count = bshuf_trans_bit_byte_AVX512(out, tmp_buf, size, elem_size);
+    CHECK_ERR_FREE(count, tmp_buf);
+    count = bshuf_trans_bitrow_eight(tmp_buf, out, size, elem_size);
+
+    free(tmp_buf);
+
+    return count;
+
+}
+
+/* Shuffle bits within the bytes of eight element blocks. */
+int64_t bshuf_shuffle_bit_eightelem_AVX512(const void* in, void* out, const size_t size,
+        const size_t elem_size) {
+
+    CHECK_MULT_EIGHT(size);
+
+    // With a bit of care, this could be written such that it is
+    // in_buf = out_buf safe.
+    const char* in_b = (const char*) in;
+    char* out_b = (char*) out;
+
+    size_t ii, jj, kk;
+    size_t nbyte = elem_size * size;
+
+    __m512i zmm;
+    __mmask64 bt;
+
+    if (elem_size % 8) {
+        return bshuf_shuffle_bit_eightelem_AVX(in, out, size, elem_size);
+    } else {
+        const __m512i mask = _mm512_set1_epi8(0);
+        for (jj = 0; jj + 63 < 8 * elem_size; jj += 64) {
+            for (ii = 0; ii + 8 * elem_size - 1 < nbyte;
+                    ii += 8 * elem_size) {
+                zmm = _mm512_loadu_si512((__m512i *) &in_b[ii + jj]);
+                for (kk = 0; kk < 8; kk++) {
+                    bt = _mm512_cmp_epi8_mask(zmm, mask, 1);
+                    zmm = _mm512_slli_epi16(zmm, 1);
+                    size_t ind = (ii + jj / 8 + (7 - kk) * elem_size);
+                    * (int64_t *) &out_b[ind] = bt;
+                }
+            }
+        }
+
+    }
+    return size * elem_size;
+}
+
+/* Untranspose bits within elements. */
+int64_t bshuf_untrans_bit_elem_AVX512(const void* in, void* out, const size_t size,
+        const size_t elem_size) {
+
+    int64_t count;
+
+    CHECK_MULT_EIGHT(size);
+
+    void* tmp_buf = malloc(size * elem_size);
+    if (tmp_buf == NULL) return -1;
+
+    count = bshuf_trans_byte_bitrow_AVX(in, tmp_buf, size, elem_size);
+    CHECK_ERR_FREE(count, tmp_buf);
+    count = bshuf_shuffle_bit_eightelem_AVX512(tmp_buf, out, size, elem_size);
+
+    free(tmp_buf);
+    return count;
+}
+
+#else // #ifdef USEAVX512
+
+int64_t bshuf_trans_bit_byte_AVX512(const void* in, void* out, const size_t size,
+        const size_t elem_size) {
+
+    return -14;
+}
+
+int64_t bshuf_trans_bit_elem_AVX512(const void* in, void* out, const size_t size,
+        const size_t elem_size) {
+    return -14;
+
+}
+
+int64_t bshuf_shuffle_bit_eightelem_AVX512(const void* in, void* out, const size_t size,
+        const size_t elem_size) {
+    return -14;
+}
+
+int64_t bshuf_untrans_bit_elem_AVX512(const void* in, void* out, const size_t size,
+        const size_t elem_size) {
+    return -14;
+}
+
+#endif
 
 /* ---- Drivers selecting best instruction set at compile time. ---- */
 
@@ -1755,7 +1818,9 @@ int64_t bshuf_untrans_bit_elem(const void* in, void* out, const size_t size,
          const size_t elem_size) {
 
     int64_t count;
-#ifdef USEAVX2
+#ifdef USEAVX512
+    count = bshuf_untrans_bit_elem_AVX512(in, out, size, elem_size);
+#elif defined USEAVX2
     count = bshuf_untrans_bit_elem_AVX(in, out, size, elem_size);
 #elif defined(USESSE2)
     count = bshuf_untrans_bit_elem_SSE(in, out, size, elem_size);
diff --git a/tests/test_ext.py b/tests/test_ext.py
index 0b2557b8..db87057c 100644
--- a/tests/test_ext.py
+++ b/tests/test_ext.py
@@ -300,6 +300,20 @@ def test_06g_untrans_bit_elem_64(self):
         self.fun = ext.untrans_bit_elem_scal
         self.check_data = pre_trans
 
+    def test_06h_untrans_bit_elem_32(self):
+        self.case = "bit U elem AVX512 32"
+        pre_trans = self.data.view(np.float32)
+        self.data = trans_bit_elem(pre_trans)
+        self.fun = ext.untrans_bit_elem_AVX512
+        self.check_data = pre_trans
+
+    def test_06i_untrans_bit_elem_64(self):
+        self.case = "bit U elem AVX512 64"
+        pre_trans = self.data.view(np.float64)
+        self.data = trans_bit_elem(pre_trans)
+        self.fun = ext.untrans_bit_elem_AVX512
+        self.check_data = pre_trans
+
     def test_07a_trans_byte_bitrow_64(self):
         self.case = "byte T row scal 64"
         self.data = self.data.view(np.float64)
@@ -352,6 +366,30 @@ def test_08f_shuffle_bit_eight_AVX_128(self):
         self.fun = ext.shuffle_bit_eightelem_AVX
         self.check = ext.shuffle_bit_eightelem_scal
 
+    def test_08g_shuffle_bit_eight_AVX512_32(self):
+        self.case = "bit S eight AVX512 32"
+        self.data = self.data.view(np.float32)
+        self.fun = ext.shuffle_bit_eightelem_AVX512
+        self.check = ext.shuffle_bit_eightelem_scal
+
+    def test_08h_shuffle_bit_eight_AVX512_64(self):
+        self.case = "bit S eight AVX512 64"
+        self.data = self.data.view(np.float64)
+        self.fun = ext.shuffle_bit_eightelem_AVX512
+        self.check = ext.shuffle_bit_eightelem_scal
+
+    def test_08i_shuffle_bit_eight_AVX512_16(self):
+        self.case = "bit S eight AVX512 16"
+        self.data = self.data.view(np.int16)
+        self.fun = ext.shuffle_bit_eightelem_AVX512
+        self.check = ext.shuffle_bit_eightelem_scal
+
+    def test_08j_shuffle_bit_eight_AVX512_128(self):
+        self.case = "bit S eight AVX512 128"
+        self.data = self.data.view(np.complex128)
+        self.fun = ext.shuffle_bit_eightelem_AVX512
+        self.check = ext.shuffle_bit_eightelem_scal
+
     def test_09a_trans_bit_elem_scal_64(self):
         self.case = "bit T elem scal 64"
         self.data = self.data.view(np.float64)
@@ -391,6 +429,13 @@ def test_09f_untrans_bit_elem_AVX_64(self):
         self.fun = ext.untrans_bit_elem_AVX
         self.check_data = pre_trans
 
+    def test_09g_untrans_bit_elem_AVX512_64(self):
+        self.case = "bit U elem AVX512 64"
+        pre_trans = self.data.view(np.float64)
+        self.data = trans_bit_elem(pre_trans)
+        self.fun = ext.untrans_bit_elem_AVX512
+        self.check_data = pre_trans
+
     def test_10a_bitshuffle_64(self):
         self.case = "bitshuffle 64"
         self.data = self.data.view(np.float64)
@@ -527,6 +572,10 @@ def test_untrans_bit_elem_AVX(self):
         self.fun = lambda x: ext.untrans_bit_elem_SSE(ext.trans_bit_elem(x))
         self.check = lambda x: x
 
+    def test_untrans_bit_elem_AVX512(self):
+        self.fun = lambda x: ext.untrans_bit_elem_AVX512(ext.trans_bit_elem(x))
+        self.check = lambda x: x
+
     def test_trans_bit_elem_scal(self):
         self.fun = ext.trans_bit_elem_scal
         self.check = trans_bit_elem
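
For reference, a minimal round-trip check of the AVX512 entry points exposed by this diff. This is a usage sketch, not part of the patch; the dtype and length are illustrative, and it assumes the extension was built and is running on hardware with AVX-512 support (otherwise the _AVX512 functions report an error instead of data).

    import numpy as np
    from bitshuffle import ext

    # Mirror the test_06h/test_06i cases: bit-transpose with the generic kernel,
    # then undo it with the AVX512 kernel and check that the input is recovered.
    data = np.arange(1024, dtype=np.uint32)  # element count must be a multiple of 8
    transposed = ext.trans_bit_elem(data)
    restored = ext.untrans_bit_elem_AVX512(transposed)
    assert np.array_equal(restored, data)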