Skip to content

Commit

Permalink
Improve untrans with AVX512
Browse files Browse the repository at this point in the history
Signed-off-by: Wu, Kaiqiang <kaiqiang.wu@intel.com>
  • Loading branch information
HackToday committed Apr 12, 2022
1 parent b2cbc1b commit f354ace
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 94 deletions.
10 changes: 10 additions & 0 deletions bitshuffle/ext.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ cdef extern int bshuf_shuffle_bit_eightelem_scal(void *A, void *B, int size, int
cdef extern int bshuf_shuffle_bit_eightelem_SSE(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_shuffle_bit_eightelem_NEON(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_shuffle_bit_eightelem_AVX(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_shuffle_bit_eightelem_AVX512(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_untrans_bit_elem_SSE(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_untrans_bit_elem_NEON(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_untrans_bit_elem_AVX(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_untrans_bit_elem_AVX512(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_untrans_bit_elem_scal(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_trans_bit_elem(void *A, void *B, int size, int elem_size)
cdef extern int bshuf_untrans_bit_elem(void *A, void *B, int size, int elem_size)
Expand Down Expand Up @@ -259,6 +261,10 @@ def shuffle_bit_eightelem_AVX(np.ndarray arr not None):
return _wrap_C_fun(&bshuf_shuffle_bit_eightelem_AVX, arr)


def shuffle_bit_eightelem_AVX512(np.ndarray arr not None):
return _wrap_C_fun(&bshuf_shuffle_bit_eightelem_AVX512, arr)


def untrans_bit_elem_SSE(np.ndarray arr not None):
return _wrap_C_fun(&bshuf_untrans_bit_elem_SSE, arr)

Expand All @@ -271,6 +277,10 @@ def untrans_bit_elem_AVX(np.ndarray arr not None):
return _wrap_C_fun(&bshuf_untrans_bit_elem_AVX, arr)


def untrans_bit_elem_AVX512(np.ndarray arr not None):
return _wrap_C_fun(&bshuf_untrans_bit_elem_AVX512, arr)


def untrans_bit_elem_scal(np.ndarray arr not None):
return _wrap_C_fun(&bshuf_untrans_bit_elem_scal, arr)

Expand Down
260 changes: 166 additions & 94 deletions src/bitshuffle_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -1384,99 +1384,6 @@ int64_t bshuf_shuffle_bit_eightelem_SSE(const void* in, void* out, const size_t
#endif // #ifdef USESSE2


#ifdef USEAVX512

/* Transpose bits within bytes. */
int64_t bshuf_trans_bit_byte_AVX512(const void* in, void* out, const size_t size,
const size_t elem_size) {

size_t ii, kk;
const char* in_b = (const char*) in;
char* out_b = (char*) out;
size_t nbyte = elem_size * size;
int64_t count;

int64_t* out_i64;
__m512i zmm;
__mmask64 bt;
if (nbyte >= 64) {
const __m512i mask = _mm512_set_epi8(0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0);

for (ii = 0; ii + 63 < nbyte; ii += 64) {
zmm = _mm512_loadu_si512((__m512i *) &in_b[ii]);
for (kk = 0; kk < 8; kk++) {
bt = _mm512_cmp_epi8_mask(zmm, mask, 1);
zmm = _mm512_slli_epi16(zmm, 1);
out_i64 = (int64_t*) &out_b[((7 - kk) * nbyte + ii) / 8];
*out_i64 = (int64_t)bt;
}
}
}

__m256i ymm;
int32_t bt32;
int32_t* out_i32;
size_t start = nbyte - nbyte % 64;
for (ii = start; ii + 31 < nbyte; ii += 32) {
ymm = _mm256_loadu_si256((__m256i *) &in_b[ii]);
for (kk = 0; kk < 8; kk++) {
bt32 = _mm256_movemask_epi8(ymm);
ymm = _mm256_slli_epi16(ymm, 1);
out_i32 = (int32_t*) &out_b[((7 - kk) * nbyte + ii) / 8];
*out_i32 = bt32;
}
}


count = bshuf_trans_bit_byte_remainder(in, out, size, elem_size,
nbyte - nbyte % 64 % 32);

return count;
}


/* Transpose bits within elements. */
int64_t bshuf_trans_bit_elem_AVX512(const void* in, void* out, const size_t size,
const size_t elem_size) {

int64_t count;

CHECK_MULT_EIGHT(size);

void* tmp_buf = malloc(size * elem_size);
if (tmp_buf == NULL) return -1;

count = bshuf_trans_byte_elem_SSE(in, out, size, elem_size);
CHECK_ERR_FREE(count, tmp_buf);
count = bshuf_trans_bit_byte_AVX512(out, tmp_buf, size, elem_size);
CHECK_ERR_FREE(count, tmp_buf);
count = bshuf_trans_bitrow_eight(tmp_buf, out, size, elem_size);

free(tmp_buf);

return count;

}

#else // #ifdef USEAVX512

int64_t bshuf_trans_bit_byte_AVX512(const void* in, void* out, const size_t size,
const size_t elem_size) {

return -14;
}

int64_t bshuf_trans_bit_elem_AVX512(const void* in, void* out, const size_t size,
const size_t elem_size) {
return -14;

}


#endif
/* ---- Code that requires AVX2. Intel Haswell (2013) and later. ---- */

/* ---- Worker code that uses AVX2 ----
Expand Down Expand Up @@ -1729,6 +1636,169 @@ int64_t bshuf_untrans_bit_elem_AVX(const void* in, void* out, const size_t size,

#endif // #ifdef USEAVX2

#ifdef USEAVX512

/* Transpose bits within bytes. */
int64_t bshuf_trans_bit_byte_AVX512(const void* in, void* out, const size_t size,
const size_t elem_size) {

size_t ii, kk;
const char* in_b = (const char*) in;
char* out_b = (char*) out;
size_t nbyte = elem_size * size;
int64_t count;

int64_t* out_i64;
__m512i zmm;
__mmask64 bt;
if (nbyte >= 64) {
const __m512i mask = _mm512_set_epi8(0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0);

for (ii = 0; ii + 63 < nbyte; ii += 64) {
zmm = _mm512_loadu_si512((__m512i *) &in_b[ii]);
for (kk = 0; kk < 8; kk++) {
bt = _mm512_cmp_epi8_mask(zmm, mask, 1);
zmm = _mm512_slli_epi16(zmm, 1);
out_i64 = (int64_t*) &out_b[((7 - kk) * nbyte + ii) / 8];
*out_i64 = (int64_t)bt;
}
}
}

__m256i ymm;
int32_t bt32;
int32_t* out_i32;
size_t start = nbyte - nbyte % 64;
for (ii = start; ii + 31 < nbyte; ii += 32) {
ymm = _mm256_loadu_si256((__m256i *) &in_b[ii]);
for (kk = 0; kk < 8; kk++) {
bt32 = _mm256_movemask_epi8(ymm);
ymm = _mm256_slli_epi16(ymm, 1);
out_i32 = (int32_t*) &out_b[((7 - kk) * nbyte + ii) / 8];
*out_i32 = bt32;
}
}


count = bshuf_trans_bit_byte_remainder(in, out, size, elem_size,
nbyte - nbyte % 64 % 32);

return count;
}


/* Transpose bits within elements. */
int64_t bshuf_trans_bit_elem_AVX512(const void* in, void* out, const size_t size,
const size_t elem_size) {

int64_t count;

CHECK_MULT_EIGHT(size);

void* tmp_buf = malloc(size * elem_size);
if (tmp_buf == NULL) return -1;

count = bshuf_trans_byte_elem_SSE(in, out, size, elem_size);
CHECK_ERR_FREE(count, tmp_buf);
count = bshuf_trans_bit_byte_AVX512(out, tmp_buf, size, elem_size);
CHECK_ERR_FREE(count, tmp_buf);
count = bshuf_trans_bitrow_eight(tmp_buf, out, size, elem_size);

free(tmp_buf);

return count;

}

/* Shuffle bits within the bytes of eight element blocks. */
int64_t bshuf_shuffle_bit_eightelem_AVX512(const void* in, void* out, const size_t size,
const size_t elem_size) {

CHECK_MULT_EIGHT(size);

// With a bit of care, this could be written such that such that it is
// in_buf = out_buf safe.
const char* in_b = (const char*) in;
char* out_b = (char*) out;

size_t ii, jj, kk;
size_t nbyte = elem_size * size;

__m512i zmm;
__mmask64 bt;
int64_t* out_i64;

if (elem_size % 8) {
return bshuf_shuffle_bit_eightelem_AVX(in, out, size, elem_size);
} else {
const __m512i mask = _mm512_set_epi8(0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0);
for (jj = 0; jj + 63 < 8 * elem_size; jj += 64) {
for (ii = 0; ii + 8 * elem_size - 1 < nbyte;
ii += 8 * elem_size) {
zmm = _mm512_loadu_si512((__m512i *) &in_b[ii + jj]);
for (kk = 0; kk < 8; kk++) {
bt = _mm512_cmp_epi8_mask(zmm, mask, 1);
zmm = _mm512_slli_epi16(zmm, 1);
size_t ind = (ii + jj / 8 + (7 - kk) * elem_size);
* (int64_t *) &out_b[ind] = bt;
}
}
}

}
return size * elem_size;
}

/* Untranspose bits within elements. */
int64_t bshuf_untrans_bit_elem_AVX512(const void* in, void* out, const size_t size,
const size_t elem_size) {

int64_t count;

CHECK_MULT_EIGHT(size);

void* tmp_buf = malloc(size * elem_size);
if (tmp_buf == NULL) return -1;

count = bshuf_trans_byte_bitrow_AVX(in, tmp_buf, size, elem_size);
CHECK_ERR_FREE(count, tmp_buf);
count = bshuf_shuffle_bit_eightelem_AVX512(tmp_buf, out, size, elem_size);

free(tmp_buf);
return count;
}

#else // #ifdef USEAVX512

int64_t bshuf_trans_bit_byte_AVX512(const void* in, void* out, const size_t size,
const size_t elem_size) {

return -14;
}

int64_t bshuf_trans_bit_elem_AVX512(const void* in, void* out, const size_t size,
const size_t elem_size) {
return -14;

}

int64_t bshuf_shuffle_bit_eightelem_AVX512(const void* in, void* out, const size_t size,
const size_t elem_size) {
return -14;
}

int64_t bshuf_untrans_bit_elem_AVX512(const void* in, void* out, const size_t size,
const size_t elem_size) {
return -14;
}

#endif

/* ---- Drivers selecting best instruction set at compile time. ---- */

Expand All @@ -1755,7 +1825,9 @@ int64_t bshuf_untrans_bit_elem(const void* in, void* out, const size_t size,
const size_t elem_size) {

int64_t count;
#ifdef USEAVX2
#ifdef USEAVX512
count = bshuf_untrans_bit_elem_AVX512(in, out, size, elem_size);
#elif defined USEAVX2
count = bshuf_untrans_bit_elem_AVX(in, out, size, elem_size);
#elif defined(USESSE2)
count = bshuf_untrans_bit_elem_SSE(in, out, size, elem_size);
Expand Down
49 changes: 49 additions & 0 deletions tests/test_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,20 @@ def test_06g_untrans_bit_elem_64(self):
self.fun = ext.untrans_bit_elem_scal
self.check_data = pre_trans

def test_06h_untrans_bit_elem_32(self):
self.case = "bit U elem AVX512 32"
pre_trans = self.data.view(np.float32)
self.data = trans_bit_elem(pre_trans)
self.fun = ext.untrans_bit_elem_AVX512
self.check_data = pre_trans

def test_06i_untrans_bit_elem_64(self):
self.case = "bit U elem AVX512 64"
pre_trans = self.data.view(np.float64)
self.data = trans_bit_elem(pre_trans)
self.fun = ext.untrans_bit_elem_AVX512
self.check_data = pre_trans

def test_07a_trans_byte_bitrow_64(self):
self.case = "byte T row scal 64"
self.data = self.data.view(np.float64)
Expand Down Expand Up @@ -352,6 +366,30 @@ def test_08f_shuffle_bit_eight_AVX_128(self):
self.fun = ext.shuffle_bit_eightelem_AVX
self.check = ext.shuffle_bit_eightelem_scal

def test_08g_shuffle_bit_eight_AVX512_32(self):
self.case = "bit S eight AVX 32"
self.data = self.data.view(np.float32)
self.fun = ext.shuffle_bit_eightelem_AVX512
self.check = ext.shuffle_bit_eightelem_scal

def test_08h_shuffle_bit_eight_AVX512_64(self):
self.case = "bit S eight AVX512 64"
self.data = self.data.view(np.float64)
self.fun = ext.shuffle_bit_eightelem_AVX512
self.check = ext.shuffle_bit_eightelem_scal

def test_08i_shuffle_bit_eight_AVX512_16(self):
self.case = "bit S eight AVX512 16"
self.data = self.data.view(np.int16)
self.fun = ext.shuffle_bit_eightelem_AVX512
self.check = ext.shuffle_bit_eightelem_scal

def test_08i_shuffle_bit_eight_AVX512_128(self):
self.case = "bit S eight AVX512 128"
self.data = self.data.view(np.complex128)
self.fun = ext.shuffle_bit_eightelem_AVX512
self.check = ext.shuffle_bit_eightelem_scal

def test_09a_trans_bit_elem_scal_64(self):
self.case = "bit T elem scal 64"
self.data = self.data.view(np.float64)
Expand Down Expand Up @@ -391,6 +429,13 @@ def test_09f_untrans_bit_elem_AVX_64(self):
self.fun = ext.untrans_bit_elem_AVX
self.check_data = pre_trans

def test_09g_untrans_bit_elem_AVX_64(self):
self.case = "bit U elem AVX512 64"
pre_trans = self.data.view(np.float64)
self.data = trans_bit_elem(pre_trans)
self.fun = ext.untrans_bit_elem_AVX512
self.check_data = pre_trans

def test_10a_bitshuffle_64(self):
self.case = "bitshuffle 64"
self.data = self.data.view(np.float64)
Expand Down Expand Up @@ -527,6 +572,10 @@ def test_untrans_bit_elem_AVX(self):
self.fun = lambda x: ext.untrans_bit_elem_SSE(ext.trans_bit_elem(x))
self.check = lambda x: x

def test_untrans_bit_elem_AVX512(self):
self.fun = lambda x: ext.untrans_bit_elem_SSE(ext.trans_bit_elem(x))
self.check = lambda x: x

def test_trans_bit_elem_scal(self):
self.fun = ext.trans_bit_elem_scal
self.check = trans_bit_elem
Expand Down

0 comments on commit f354ace

Please sign in to comment.