Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve trans and untrans with AVX512 #117

Merged
merged 2 commits into from
May 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions bitshuffle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using_NEON
using_SSE2
using_AVX2
using_AVX512
bitshuffle
bitunshuffle
compress_lz4
Expand All @@ -28,6 +29,7 @@
using_NEON,
using_SSE2,
using_AVX2,
using_AVX512,
compress_lz4,
decompress_lz4,
)
Expand All @@ -49,6 +51,7 @@
"using_NEON",
"using_SSE2",
"using_AVX2",
"using_AVX512",
"compress_lz4",
"decompress_lz4",
] + zstd_api
29 changes: 29 additions & 0 deletions bitshuffle/ext.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ cdef extern from b"bitshuffle.h":
int bshuf_using_NEON()
int bshuf_using_SSE2()
int bshuf_using_AVX2()
int bshuf_using_AVX512()
int bshuf_bitshuffle(void *A, void *B, int size, int elem_size,
int block_size) nogil
int bshuf_bitunshuffle(void *A, void *B, int size, int elem_size,
Expand Down Expand Up @@ -60,7 +61,9 @@ cdef extern int bshuf_trans_bit_byte_scal(void *A, void *B, int size, int elem_s
cdef extern int bshuf_trans_bit_byte_SSE(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_trans_bit_byte_NEON(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_trans_bit_byte_AVX(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_trans_bit_byte_AVX512(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_trans_bitrow_eight(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_trans_bit_elem_AVX512(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_trans_bit_elem_AVX(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_trans_bit_elem_SSE(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_trans_bit_elem_NEON(void *A, void *B, int size, int elem_size)
Expand All @@ -73,9 +76,11 @@ cdef extern int bshuf_shuffle_bit_eightelem_scal(void *A, void *B, int size, int
cdef extern int bshuf_shuffle_bit_eightelem_SSE(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_shuffle_bit_eightelem_NEON(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_shuffle_bit_eightelem_AVX(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_shuffle_bit_eightelem_AVX512(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_untrans_bit_elem_SSE(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_untrans_bit_elem_NEON(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_untrans_bit_elem_AVX(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_untrans_bit_elem_AVX512(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_untrans_bit_elem_scal(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_trans_bit_elem(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_untrans_bit_elem(void *A, void *B, int size, int elem_size)
Expand Down Expand Up @@ -108,6 +113,14 @@ def using_AVX2():
return False


def using_AVX512():
    """Return True if the C library reports an AVX512 build, False otherwise."""
    # bshuf_using_AVX512() yields a C int (1 or 0); expose it as a Python bool.
    return bool(bshuf_using_AVX512())


def _setup_arr(arr):
shape = tuple(arr.shape)
if not arr.flags['C_CONTIGUOUS']:
Expand Down Expand Up @@ -188,10 +201,18 @@ def trans_bit_byte_AVX(np.ndarray arr not None):
return _wrap_C_fun(&bshuf_trans_bit_byte_AVX, arr)


def trans_bit_byte_AVX512(np.ndarray arr not None):
    """Apply the C routine ``bshuf_trans_bit_byte_AVX512`` to `arr` via ``_wrap_C_fun``.

    `arr` must be a NumPy array and may not be None. The result is whatever
    ``_wrap_C_fun`` returns for this routine.
    """
    return _wrap_C_fun(&bshuf_trans_bit_byte_AVX512, arr)


def trans_bitrow_eight(np.ndarray arr not None):
    """Apply the C routine ``bshuf_trans_bitrow_eight`` to `arr` via ``_wrap_C_fun``.

    `arr` must be a NumPy array and may not be None. The result is whatever
    ``_wrap_C_fun`` returns for this routine.
    """
    return _wrap_C_fun(&bshuf_trans_bitrow_eight, arr)


def trans_bit_elem_AVX512(np.ndarray arr not None):
    """Apply the C routine ``bshuf_trans_bit_elem_AVX512`` to `arr` via ``_wrap_C_fun``.

    `arr` must be a NumPy array and may not be None. The result is whatever
    ``_wrap_C_fun`` returns for this routine.
    """
    return _wrap_C_fun(&bshuf_trans_bit_elem_AVX512, arr)


def trans_bit_elem_AVX(np.ndarray arr not None):
    """Apply the C routine ``bshuf_trans_bit_elem_AVX`` to `arr` via ``_wrap_C_fun``.

    `arr` must be a NumPy array and may not be None. The result is whatever
    ``_wrap_C_fun`` returns for this routine.
    """
    return _wrap_C_fun(&bshuf_trans_bit_elem_AVX, arr)

Expand Down Expand Up @@ -240,6 +261,10 @@ def shuffle_bit_eightelem_AVX(np.ndarray arr not None):
return _wrap_C_fun(&bshuf_shuffle_bit_eightelem_AVX, arr)


def shuffle_bit_eightelem_AVX512(np.ndarray arr not None):
    """Apply the C routine ``bshuf_shuffle_bit_eightelem_AVX512`` to `arr` via ``_wrap_C_fun``.

    `arr` must be a NumPy array and may not be None. The result is whatever
    ``_wrap_C_fun`` returns for this routine.
    """
    return _wrap_C_fun(&bshuf_shuffle_bit_eightelem_AVX512, arr)


def untrans_bit_elem_SSE(np.ndarray arr not None):
    """Apply the C routine ``bshuf_untrans_bit_elem_SSE`` to `arr` via ``_wrap_C_fun``.

    `arr` must be a NumPy array and may not be None. The result is whatever
    ``_wrap_C_fun`` returns for this routine.
    """
    return _wrap_C_fun(&bshuf_untrans_bit_elem_SSE, arr)

Expand All @@ -252,6 +277,10 @@ def untrans_bit_elem_AVX(np.ndarray arr not None):
return _wrap_C_fun(&bshuf_untrans_bit_elem_AVX, arr)


def untrans_bit_elem_AVX512(np.ndarray arr not None):
    """Apply the C routine ``bshuf_untrans_bit_elem_AVX512`` to `arr` via ``_wrap_C_fun``.

    `arr` must be a NumPy array and may not be None. The result is whatever
    ``_wrap_C_fun`` returns for this routine.
    """
    return _wrap_C_fun(&bshuf_untrans_bit_elem_AVX512, arr)


def untrans_bit_elem_scal(np.ndarray arr not None):
    """Apply the C routine ``bshuf_untrans_bit_elem_scal`` to `arr` via ``_wrap_C_fun``.

    `arr` must be a NumPy array and may not be None. The result is whatever
    ``_wrap_C_fun`` returns for this routine.
    """
    return _wrap_C_fun(&bshuf_untrans_bit_elem_scal, arr)

Expand Down
177 changes: 174 additions & 3 deletions src/bitshuffle_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
#include <string.h>


/* USEAVX512 is defined only when the compiler enables AVX512F and AVX512BW
 * together with AVX2 and SSE2: the AVX512 routines in this file also call
 * the *_AVX and *_SSE helpers, so those instruction sets must be present. */
#if defined(__AVX512F__) && defined (__AVX512BW__) && defined(__AVX2__) && defined(__SSE2__)
#define USEAVX512
#endif

/* USEAVX2 is defined only when the compiler enables both AVX2 and SSE2. */
#if defined(__AVX2__) && defined (__SSE2__)
#define USEAVX2
#endif
Expand Down Expand Up @@ -79,6 +83,14 @@ int bshuf_using_AVX2(void) {
}


/* Report whether this translation unit was built with the AVX512 code paths.
 *
 * Returns 1 when USEAVX512 is defined at compile time, 0 otherwise.
 */
int bshuf_using_AVX512(void) {
#ifdef USEAVX512
    const int built_with_avx512 = 1;
#else
    const int built_with_avx512 = 0;
#endif
    return built_with_avx512;
}

/* ---- Worker code not requiring special instruction sets. ----
*
* The following code does not use any x86 specific vectorized instructions
Expand Down Expand Up @@ -1384,7 +1396,6 @@ int64_t bshuf_shuffle_bit_eightelem_SSE(const void* in, void* out, const size_t
*/

#ifdef USEAVX2

/* Transpose bits within bytes. */
int64_t bshuf_trans_bit_byte_AVX(const void* in, void* out, const size_t size,
const size_t elem_size) {
Expand Down Expand Up @@ -1625,14 +1636,172 @@ int64_t bshuf_untrans_bit_elem_AVX(const void* in, void* out, const size_t size,

#endif // #ifdef USEAVX2

#ifdef USEAVX512

/* Transpose bits within bytes.
 *
 * Treats the input as nbyte = size * elem_size bytes. Whole 64-byte chunks
 * are processed with 512-bit vectors, a following whole 32-byte chunk with a
 * 256-bit vector, and everything from offset (nbyte - nbyte % 32) onwards is
 * handed to bshuf_trans_bit_byte_remainder.
 *
 * For each chunk starting at byte ii, pass kk (0..7) gathers the top bit of
 * every byte into a mask and stores it at byte offset
 * ((7 - kk) * nbyte + ii) / 8 of `out`.
 *
 * Parameters
 * ----------
 *  in : input buffer, nbyte bytes.
 *  out : output buffer, nbyte bytes. No alignment is required.
 *  size : number of elements.
 *  elem_size : element size in bytes.
 *
 * Returns
 * -------
 *  The value returned by bshuf_trans_bit_byte_remainder.
 */
int64_t bshuf_trans_bit_byte_AVX512(const void* in, void* out, const size_t size,
        const size_t elem_size) {

    const char* in_b = (const char*) in;
    char* out_b = (char*) out;
    const size_t nbyte = elem_size * size;
    size_t ii, kk;

    /* 512-bit path: 64 input bytes per iteration. */
    const __m512i zero = _mm512_setzero_si512();
    for (ii = 0; ii + 63 < nbyte; ii += 64) {
        __m512i zmm = _mm512_loadu_si512(&in_b[ii]);
        for (kk = 0; kk < 8; kk++) {
            /* A signed byte is < 0 exactly when its top bit is set, so the
             * compare mask collects the top bit of each of the 64 bytes. */
            const uint64_t bits =
                (uint64_t) _mm512_cmp_epi8_mask(zmm, zero, _MM_CMPINT_LT);
            /* Shift the next lower bit into the top position. The shift is
             * per 16-bit lane, but over at most 7 shifts bits carried in
             * from the low byte never reach the top bit of the high byte. */
            zmm = _mm512_slli_epi16(zmm, 1);
            /* memcpy instead of a cast-and-store: the destination offset is
             * not 8-byte aligned in general. */
            memcpy(&out_b[((7 - kk) * nbyte + ii) / 8], &bits, sizeof bits);
        }
    }

    /* 256-bit path: at most one remaining 32-byte chunk after the 64-byte
     * chunks. */
    for (ii = nbyte - nbyte % 64; ii + 31 < nbyte; ii += 32) {
        __m256i ymm = _mm256_loadu_si256((const __m256i *) &in_b[ii]);
        for (kk = 0; kk < 8; kk++) {
            const int32_t bits = _mm256_movemask_epi8(ymm);
            ymm = _mm256_slli_epi16(ymm, 1);
            memcpy(&out_b[((7 - kk) * nbyte + ii) / 8], &bits, sizeof bits);
        }
    }

    /* nbyte % 64 % 32 == nbyte % 32: the tail starts after the last whole
     * 32-byte chunk. */
    return bshuf_trans_bit_byte_remainder(in, out, size, elem_size,
            nbyte - nbyte % 32);
}


/* Transpose bits within elements.
 *
 * Runs three stages, using a heap scratch buffer of size * elem_size bytes:
 *   1. bshuf_trans_byte_elem_SSE     : in      -> out
 *   2. bshuf_trans_bit_byte_AVX512   : out     -> scratch
 *   3. bshuf_trans_bitrow_eight      : scratch -> out
 *
 * CHECK_MULT_EIGHT and CHECK_ERR_FREE are macros defined elsewhere in this
 * file; they are used here to validate `size` and to bail out (releasing the
 * scratch buffer) when a stage reports an error.
 *
 * Returns -1 if the scratch buffer cannot be allocated, otherwise the result
 * of the last stage that ran.
 */
int64_t bshuf_trans_bit_elem_AVX512(const void* in, void* out, const size_t size,
        const size_t elem_size) {

    int64_t status;
    void* scratch;

    CHECK_MULT_EIGHT(size);

    scratch = malloc(size * elem_size);
    if (scratch == NULL) {
        return -1;
    }

    status = bshuf_trans_byte_elem_SSE(in, out, size, elem_size);
    CHECK_ERR_FREE(status, scratch);

    status = bshuf_trans_bit_byte_AVX512(out, scratch, size, elem_size);
    CHECK_ERR_FREE(status, scratch);

    status = bshuf_trans_bitrow_eight(scratch, out, size, elem_size);

    free(scratch);
    return status;
}

/* Shuffle bits within the bytes of eight element blocks.
 *
 * The input is walked in blocks of eight elements (8 * elem_size bytes).
 * Within each block, every 64-byte slice starting at offset jj is loaded and,
 * for pass kk (0..7), the top bit of each of its bytes is gathered into a
 * 64-bit mask that is stored at byte offset
 * ii + jj / 8 + (7 - kk) * elem_size of `out`.
 *
 * The 512-bit path is only taken when elem_size is a multiple of 8, so that
 * a block is a whole number of 64-byte slices; otherwise the work is
 * delegated to bshuf_shuffle_bit_eightelem_AVX.
 *
 * With a bit of care, this could be written such that it is
 * in_buf = out_buf safe; as written, `in` and `out` must not overlap.
 *
 * Parameters
 * ----------
 *  in : input buffer, size * elem_size bytes.
 *  out : output buffer, size * elem_size bytes. No alignment is required.
 *  size : number of elements; checked by CHECK_MULT_EIGHT (macro defined
 *      elsewhere in this file).
 *  elem_size : element size in bytes.
 *
 * Returns
 * -------
 *  size * elem_size on the 512-bit path, or the result of the AVX fallback.
 */
int64_t bshuf_shuffle_bit_eightelem_AVX512(const void* in, void* out, const size_t size,
        const size_t elem_size) {

    CHECK_MULT_EIGHT(size);

    if (elem_size % 8) {
        return bshuf_shuffle_bit_eightelem_AVX(in, out, size, elem_size);
    }

    const char* in_b = (const char*) in;
    char* out_b = (char*) out;
    const size_t nbyte = elem_size * size;
    const size_t block_bytes = 8 * elem_size;
    const __m512i zero = _mm512_setzero_si512();
    size_t ii, jj, kk;

    for (jj = 0; jj + 63 < block_bytes; jj += 64) {
        for (ii = 0; ii + block_bytes - 1 < nbyte; ii += block_bytes) {
            __m512i zmm = _mm512_loadu_si512(&in_b[ii + jj]);
            for (kk = 0; kk < 8; kk++) {
                /* Signed "< 0" per byte is exactly "top bit set". */
                const uint64_t bits =
                    (uint64_t) _mm512_cmp_epi8_mask(zmm, zero, _MM_CMPINT_LT);
                /* Move the next lower bit of every byte into the top
                 * position for the following pass. */
                zmm = _mm512_slli_epi16(zmm, 1);
                /* memcpy instead of a cast-and-store: `out` comes from the
                 * caller and need not be 8-byte aligned. */
                memcpy(&out_b[ii + jj / 8 + (7 - kk) * elem_size],
                       &bits, sizeof bits);
            }
        }
    }

    return size * elem_size;
}

/* Untranspose bits within elements.
 *
 * Runs two stages, using a heap scratch buffer of size * elem_size bytes:
 *   1. bshuf_trans_byte_bitrow_AVX          : in      -> scratch
 *   2. bshuf_shuffle_bit_eightelem_AVX512   : scratch -> out
 *
 * CHECK_MULT_EIGHT and CHECK_ERR_FREE are macros defined elsewhere in this
 * file; they are used here to validate `size` and to bail out (releasing the
 * scratch buffer) when the first stage reports an error.
 *
 * Returns -1 if the scratch buffer cannot be allocated, otherwise the result
 * of the last stage that ran.
 */
int64_t bshuf_untrans_bit_elem_AVX512(const void* in, void* out, const size_t size,
        const size_t elem_size) {

    int64_t status;
    void* scratch;

    CHECK_MULT_EIGHT(size);

    scratch = malloc(size * elem_size);
    if (scratch == NULL) {
        return -1;
    }

    status = bshuf_trans_byte_bitrow_AVX(in, scratch, size, elem_size);
    CHECK_ERR_FREE(status, scratch);

    status = bshuf_shuffle_bit_eightelem_AVX512(scratch, out, size, elem_size);

    free(scratch);
    return status;
}

#else // #ifdef USEAVX512

/* Stand-in used when the file is built without USEAVX512: performs no work
 * and always returns -14. */
int64_t bshuf_trans_bit_byte_AVX512(const void* in, void* out, const size_t size,
        const size_t elem_size) {
    /* Arguments are intentionally ignored. */
    (void) in;
    (void) out;
    (void) size;
    (void) elem_size;
    return -14;
}

/* Stand-in used when the file is built without USEAVX512: performs no work
 * and always returns -14. */
int64_t bshuf_trans_bit_elem_AVX512(const void* in, void* out, const size_t size,
        const size_t elem_size) {
    /* Arguments are intentionally ignored. */
    (void) in;
    (void) out;
    (void) size;
    (void) elem_size;
    return -14;
}

/* Stand-in used when the file is built without USEAVX512: performs no work
 * and always returns -14. */
int64_t bshuf_shuffle_bit_eightelem_AVX512(const void* in, void* out, const size_t size,
        const size_t elem_size) {
    /* Arguments are intentionally ignored. */
    (void) in;
    (void) out;
    (void) size;
    (void) elem_size;
    return -14;
}

/* Stand-in used when the file is built without USEAVX512: performs no work
 * and always returns -14. */
int64_t bshuf_untrans_bit_elem_AVX512(const void* in, void* out, const size_t size,
        const size_t elem_size) {
    /* Arguments are intentionally ignored. */
    (void) in;
    (void) out;
    (void) size;
    (void) elem_size;
    return -14;
}

#endif

/* ---- Drivers selecting best instruction set at compile time. ---- */

int64_t bshuf_trans_bit_elem(const void* in, void* out, const size_t size,
const size_t elem_size) {

int64_t count;
#ifdef USEAVX2
#ifdef USEAVX512
count = bshuf_trans_bit_elem_AVX512(in, out, size, elem_size);
#elif defined USEAVX2
count = bshuf_trans_bit_elem_AVX(in, out, size, elem_size);
#elif defined(USESSE2)
count = bshuf_trans_bit_elem_SSE(in, out, size, elem_size);
Expand All @@ -1649,7 +1818,9 @@ int64_t bshuf_untrans_bit_elem(const void* in, void* out, const size_t size,
const size_t elem_size) {

int64_t count;
#ifdef USEAVX2
#ifdef USEAVX512
count = bshuf_untrans_bit_elem_AVX512(in, out, size, elem_size);
#elif defined USEAVX2
count = bshuf_untrans_bit_elem_AVX(in, out, size, elem_size);
#elif defined(USESSE2)
count = bshuf_untrans_bit_elem_SSE(in, out, size, elem_size);
Expand Down
13 changes: 13 additions & 0 deletions src/bitshuffle_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
* -11 : Missing SSE.
* -12 : Missing AVX.
* -13 : Missing Arm Neon.
* -14 : Missing AVX512.
* -80 : Input size not a multiple of 8.
* -81 : block_size not multiple of 8.
* -91 : Decompression error, wrong number of bytes processed.
Expand Down Expand Up @@ -91,6 +92,18 @@ int bshuf_using_NEON(void);
int bshuf_using_AVX2(void);


/* ---- bshuf_using_AVX512 ----
 *
 * Whether routines were compiled with the AVX512 instruction set.
 *
 * Returns
 * -------
 *  1 if using AVX512, 0 otherwise.
 *
 */
int bshuf_using_AVX512(void);


/* ---- bshuf_default_block_size ----
*
* The default block size as function of element size.
Expand Down
Loading