Skip to content

Commit

Permalink
Translate my AVX512BW C++ code to C
Browse files Browse the repository at this point in the history
Unit tests do not pass now.
  • Loading branch information
WojciechMula committed Apr 20, 2018
1 parent 7e63ba3 commit d15f456
Show file tree
Hide file tree
Showing 5 changed files with 271 additions and 3 deletions.
14 changes: 11 additions & 3 deletions Makefile
Expand Up @@ -3,24 +3,32 @@
.PHONY: all clean
#
.SUFFIXES: .cpp .o .c .h
CFLAGS= -fPIC -std=c99 -Wall -Wextra -Wshadow -Wpsabi
CFLAGS= -fPIC -std=c99 -Wall -Wextra -Wshadow -Wpsabi -Wfatal-errors
ifeq ($(DEBUG),1)
CFLAGS += -ggdb -fsanitize=undefined -fno-omit-frame-pointer -fsanitize=address
else
CFLAGS += -O3 -march=native -mavx2
CFLAGS += -O3 -march=native
endif # debug

# Instruction-set selection. Building with `make AVX512BW=1` turns on AVX-512BW
# code generation and defines HAVE_AVX512BW, the macro the C sources test to
# enable their AVX-512BW code paths. Any other value builds for AVX2 only.
ifeq ($(AVX512BW),1)
CFLAGS += -march=native -mavx512bw -DHAVE_AVX512BW
else
CFLAGS += -march=native -mavx2
endif # AVX512BW

all: unit basic_benchmark


HEADERS=./include/chromiumbase64.h \
./include/klompavxbase64.h \
./include/quicktimebase64.h \
./include/scalarbase64.h \
./include/fastavxbase64.h
./include/fastavxbase64.h \
./include/fastavx512bwbase64.h \

OBJECTS=chromiumbase64.o \
fastavxbase64.o \
fastavx512bwbase64.o \
klompavxbase64.o \
quicktimebase64.o \
scalarbase64.o
Expand Down
9 changes: 9 additions & 0 deletions benchmarks/basic_benchmark.c
Expand Up @@ -17,6 +17,9 @@
#include "chromiumbase64.h"
#include "quicktimebase64.h"
#include "linuxbase64.h"
// The AVX-512BW codec is declared in its own header; it is only pulled in
// when the build defines HAVE_AVX512BW.
#ifdef HAVE_AVX512BW
#include "fastavx512bwbase64.h"
#endif // HAVE_AVX512BW

static const int repeat = 50;

Expand All @@ -36,6 +39,9 @@ void testencode(const char * data, size_t datalength, bool verbose) {
assert(outputlength == expected);
BEST_TIME_CHECK(scalar_base64_encode(data,datalength,buffer,&outputlength),(outputlength == avxexpected), , repeat, datalength,verbose);
BEST_TIME_CHECK(fast_avx2_base64_encode(buffer, data, datalength), (int) expected, , repeat, datalength,verbose);
// Benchmark the AVX-512BW encoder only when the build enables it
// (HAVE_AVX512BW defined), never when it is absent.
#ifdef HAVE_AVX512BW
BEST_TIME_CHECK(fast_avx512bw_base64_encode(buffer, data, datalength), (int) expected, , repeat, datalength,verbose);
#endif // HAVE_AVX512BW
free(buffer);
if(verbose) printf("\n");
}
Expand All @@ -61,6 +67,9 @@ void testdecode(const char * data, size_t datalength, bool verbose) {
BEST_TIME(scalar_base64_decode(data,datalength,buffer,&outputlength), avxexpected, , repeat, datalength,verbose);
BEST_TIME(klomp_avx2_base64_decode(data,datalength,buffer,&outputlength), avxexpected, , repeat, datalength,verbose);
BEST_TIME(fast_avx2_base64_decode(buffer, data, datalength), (int) expected, , repeat, datalength,verbose);
// Benchmark the AVX-512BW decoder only when the build enables it
// (HAVE_AVX512BW defined), never when it is absent.
#ifdef HAVE_AVX512BW
BEST_TIME(fast_avx512bw_base64_decode(buffer, data, datalength), (int) expected, , repeat, datalength,verbose);
#endif // HAVE_AVX512BW

free(buffer);
if(verbose) printf("\n");
Expand Down
23 changes: 23 additions & 0 deletions include/fastavx512bwbase64.h
@@ -0,0 +1,23 @@
#ifndef FASTBASE64_AVX512BW_H_
#define FASTBASE64_AVX512BW_H_

/**
* Base64 encoder/decoder using AVX512BW vector instructions, with the
* chromium (modp) scalar codec handling whatever the vector loop leaves over.
*
* Assumes recent x64 hardware with AVX512BW instructions.
*/

#include <stddef.h>
#include <stdint.h>
#include "chromiumbase64.h"

#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */

/**
* Decode `srclen` base64 characters from `src` into `out`.
*
* The vector loop stores a full 64-byte vector for every 64 input characters
* although only 48 of those bytes are decoded payload, so `out` must be large
* enough for that store; the unit tests size it with
* chromium_base64_decode_len(srclen).
*
* Returns the number of bytes produced by the vector loop plus the value
* returned by chromium_base64_decode for the remaining tail, or
* MODP_B64_ERROR when a vector block fails the alphabet check or the scalar
* decoder reports an error.
*/
size_t fast_avx512bw_base64_decode(char *out, const char *src, size_t srclen);

/**
* Encode `len` bytes from `str` as base64 into `dest`.
*
* `dest` must have room for the whole encoded output; the unit tests size it
* with chromium_base64_encode_len(len).
*
* Returns the number of characters produced by the vector loop plus the value
* returned by chromium_base64_encode for the remaining tail, or
* MODP_B64_ERROR when the scalar encoder reports an error.
*/
size_t fast_avx512bw_base64_encode(char* dest, const char* str, size_t len);

#ifdef __cplusplus
}
#endif /* __cplusplus */

#endif
196 changes: 196 additions & 0 deletions src/fastavx512bwbase64.c
@@ -0,0 +1,196 @@
#include "fastavx512bwbase64.h"

#include <x86intrin.h>
#include <stdbool.h>

/**
 * Turn the low 48 bytes of `input` into 64 six-bit indices, one per byte.
 *
 * Every 3 input bytes (24 bits) become 4 output bytes, each holding a value
 * in 0..63, in the order in which the base64 characters must be emitted.
 * The upper 16 bytes of `input` are ignored.
 *
 * The result is what enc_translate() expects as its argument.
 */
static inline __m512i enc_reshuffle(const __m512i input) {

    // from https://github.com/WojciechMula/base64simd/blob/master/encode/encode.avx512bw.cpp

    // place each 12-byte subarray in a separate 128-bit lane
    // tmp1 = [?? ?? ?? ??|D2 D1 D0 C2|C1 C0 B2 B1|B0 A2 A1 A0] x 4
    //          ignored
    const __m512i tmp1 = _mm512_permutexvar_epi32(
        _mm512_set_epi32(-1, 11, 10, 9, -1, 8, 7, 6, -1, 5, 4, 3, -1, 2, 1, 0),
        input
    );

    // reshuffle bytes within 128-bit lanes to the format required by the
    // unpack step below
    // tmp2 = [D1 D2 D0 D1|C1 C2 C0 C1|B1 B2 B0 B1|A1 A2 A0 A1] x 4
    //         10 11 9  10 7  8  6  7  4  5  3  4  1  2  0  1
    const __m512i tmp2 = _mm512_shuffle_epi8(
        tmp1,
        _mm512_set4_epi32(0x0a0b090a, 0x07080607, 0x04050304, 0x01020001)
    );

    // Unpack: split every 24-bit group into four 6-bit fields. Without this
    // step the lookup in enc_translate() would be fed raw bytes instead of
    // indices in 0..63.
    //
    // With the three source bytes written as
    //     byte0 = aaaaaabb, byte1 = bbbbcccc, byte2 = ccdddddd
    // each 32-bit element of tmp2 is (most significant byte first)
    //     [bbbbcccc|ccdddddd|aaaaaabb|bbbbcccc]

    // t0 = [0000cccc|cc000000|aaaaaa00|00000000]
    const __m512i t0 = _mm512_and_si512(tmp2, _mm512_set1_epi32(0x0fc0fc00));

    // The high half of a 16-bit multiply acts as a per-word right shift:
    // upper word >> 6, lower word >> 10.
    // t1 = [00000000|00cccccc|00000000|00aaaaaa]
    const __m512i t1 = _mm512_mulhi_epu16(t0, _mm512_set1_epi32(0x04000040));

    // t2 = [00000000|00dddddd|000000bb|bbbb0000]
    const __m512i t2 = _mm512_and_si512(tmp2, _mm512_set1_epi32(0x003f03f0));

    // The low half of a 16-bit multiply acts as a per-word left shift:
    // upper word << 8, lower word << 4.
    // t3 = [00dddddd|00000000|00bbbbbb|00000000]
    const __m512i t3 = _mm512_mullo_epi16(t2, _mm512_set1_epi32(0x01000010));

    // result = [00dddddd|00cccccc|00bbbbbb|00aaaaaa]
    return _mm512_or_si512(t1, t3);
}

/**
 * Map 64 six-bit indices (one per byte, each in 0..63) to their base64
 * ASCII characters by adding a per-range offset fetched from a 16-entry
 * table.
 */
static inline __m512i enc_translate(const __m512i input) {

    // from https://github.com/WojciechMula/base64simd/blob/master/encode/lookup.avx512bw.cpp

    // Saturating subtraction collapses the index into a table slot:
    //      0 .. 51 -> 0
    //     52 .. 61 -> 1 .. 10
    //     62       -> 11
    //     63       -> 12
    const __m512i reduced = _mm512_subs_epu8(input, _mm512_set1_epi8(51));

    // Slot 0 still covers both letter ranges, so split them:
    //      0 .. 25 -> moved to slot 13 (upper-case letters)
    //     26 .. 51 -> stays in slot 0  (lower-case letters)
    const __mmask64 below_26 = _mm512_cmpgt_epi8_mask(_mm512_set1_epi8(26), input);
    const __m512i slot = _mm512_mask_mov_epi8(reduced, below_26, _mm512_set1_epi8(13));

    /* Offsets per slot, i.e. (first character of range) - (first index of range).
       The equivalent SSE table is
            const __m128i shift_LUT = _mm_setr_epi8(
                'a' - 26, '0' - 52, '0' - 52, '0' - 52, '0' - 52, '0' - 52,
                '0' - 52, '0' - 52, '0' - 52, '0' - 52, '0' - 52, '+' - 62,
                '/' - 63, 'A', 0, 0
            );
       which is:
            0x47, 0xfc, 0xfc, 0xfc,
            0xfc, 0xfc, 0xfc, 0xfc,
            0xfc, 0xfc, 0xfc, 0xed,
            0xf0, 0x41, 0x00, 0x00
       _mm_setr_epi8 lists bytes lowest first while _mm512_set4_epi32 takes
       its dwords highest first, hence the reversed-looking constants below.
    */
    const __m512i offsets = _mm512_set4_epi32(
        0x000041f0,
        0xedfcfcfc,
        0xfcfcfcfc,
        0xfcfcfc47
    );

    // fetch the offset belonging to each slot and apply it to the index
    const __m512i shift = _mm512_shuffle_epi8(offsets, slot);

    return _mm512_add_epi8(shift, input);
}

/**
 * Pack four 6-bit values (one per byte of every 32-bit element) into a
 * single 24-bit value held in the low three bytes of that element.
 */
static inline __m512i dec_reshuffle(__m512i input) {

    // from https://github.com/WojciechMula/base64simd/blob/master/decode/pack.avx512bw.cpp

    // per 16-bit word: (even byte * 0x40) + (odd byte * 0x01)
    const __m512i byte_weights = _mm512_set1_epi32(0x01400140);
    // per 32-bit element: (low word * 0x1000) + (high word * 0x0001)
    const __m512i word_weights = _mm512_set1_epi32(0x00011000);

    const __m512i twelve_bit_pairs = _mm512_maddubs_epi16(input, byte_weights);

    return _mm512_madd_epi16(twelve_bit_pairs, word_weights);
}


size_t fast_avx512bw_base64_encode(char* dest, const char* str, size_t len) {
const char* const dest_orig = dest;
__m512i inputvector;
while (len >= 64) {
inputvector = _mm512_loadu_si512((__m512i *)(str));
inputvector = enc_reshuffle(inputvector);
inputvector = enc_translate(inputvector);
_mm512_storeu_si512((__m512i *)dest, inputvector);
str += 48;
dest += 64;
len -= 48;
}
size_t scalarret = chromium_base64_encode(dest, str, len);
if(scalarret == MODP_B64_ERROR) return MODP_B64_ERROR;
return (dest - dest_orig) + scalarret;
}

/* Combine four values, each truncated to 8 bits, into one 32-bit word.
 * b0 ends up in bits 0..7 (the lowest byte), b3 in bits 24..31. */
#define build_dword(b0, b1, b2, b3)          \
    (  (uint32_t)((uint8_t)(b0))             \
    | ((uint32_t)((uint8_t)(b1)) <<  8)      \
    | ((uint32_t)((uint8_t)(b2)) << 16)      \
    | ((uint32_t)((uint8_t)(b3)) << 24))

/* Build a 512-bit vector whose four 128-bit lanes all hold the same 16 bytes;
 * b0 is the lowest byte of each lane and b15 the highest. */
#define _mm512_set4lanes_epi8(b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15) \
    _mm512_setr4_epi32(                      \
        build_dword( b0,  b1,  b2,  b3),     \
        build_dword( b4,  b5,  b6,  b7),     \
        build_dword( b8,  b9, b10, b11),     \
        build_dword(b12, b13, b14, b15))

/**
 * Decode `srclen` base64 characters from `src` into `out`.
 *
 * Input is processed 64 characters (48 output bytes) at a time with
 * AVX512BW instructions; the remaining tail is handed to
 * chromium_base64_decode().
 *
 * Returns the number of bytes written by the vector loop plus the scalar
 * decoder's return value, or MODP_B64_ERROR if a vector block contains a
 * character outside the base64 alphabet or the scalar decoder reports an
 * error.
 */
size_t fast_avx512bw_base64_decode(char *out, const char *src, size_t srclen) {
    char* out_orig = out;

    // Each iteration consumes 64 characters and yields 48 bytes, but stores a
    // whole 64-byte vector whose upper 16 bytes are filler. The loop therefore
    // runs only while at least 88 (= 64 + 24) characters remain:
    //  - the 24 or more characters left after the block decode to at least
    //    16 bytes for well-formed input, so the filler lands on memory that
    //    belongs to the output and is overwritten later, instead of running
    //    past the end of the decoded data;
    //  - '=' padding, which can only occupy the last 4 characters, never
    //    falls inside a vector block, so padded input is not rejected by the
    //    alphabet check below but reaches the scalar decoder.
    while (srclen >= 88) {

        // load
        const __m512i input = _mm512_loadu_si512((const __m512i*)(src));

        // translate from ASCII
        // and https://github.com/WojciechMula/base64simd/blob/master/decode/lookup.avx512bw.cpp
        const __m512i higher_nibble = _mm512_and_si512(_mm512_srli_epi32(input, 4), _mm512_set1_epi8(0x0f));
        const __m512i lower_nibble  = _mm512_and_si512(input, _mm512_set1_epi8(0x0f));

        // value to add to a character to obtain its 6-bit index, selected by
        // the character's higher nibble ('/' is special-cased below)
        const __m512i shiftLUT = _mm512_set4lanes_epi8(
            0,   0,  19,   4, -65, -65, -71, -71,
            0,   0,   0,   0,   0,   0,   0,   0);

        // for every lower nibble: bit k is set when higher nibble k forms a
        // valid base64 character together with that lower nibble
        const __m512i maskLUT  = _mm512_set4lanes_epi8(
            /* 0        : 0b1010_1000*/ 0xa8,
            /* 1 .. 9   : 0b1111_1000*/ 0xf8, 0xf8, 0xf8, 0xf8,
                                        0xf8, 0xf8, 0xf8, 0xf8,
                                        0xf8,
            /* 10       : 0b1111_0000*/ 0xf0,
            /* 11       : 0b0101_0100*/ 0x54,
            /* 12 .. 14 : 0b0101_0000*/ 0x50, 0x50, 0x50,
            /* 15       : 0b0101_0100*/ 0x54
        );

        // 1 << higher_nibble for nibbles 0..7, zero (never valid) for 8..15
        const __m512i bitposLUT = _mm512_set4lanes_epi8(
            0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
        );

        const __m512i   sh      = _mm512_shuffle_epi8(shiftLUT, higher_nibble);
        // '/' (0x2f) shares its higher nibble with '+' but needs shift 16
        const __mmask64 eq_2f   = _mm512_cmpeq_epi8_mask(input, _mm512_set1_epi8(0x2f));
        const __m512i   shift   = _mm512_mask_mov_epi8(sh, eq_2f, _mm512_set1_epi8(16));

        const __m512i M         = _mm512_shuffle_epi8(maskLUT, lower_nibble);
        const __m512i bit       = _mm512_shuffle_epi8(bitposLUT, higher_nibble);

        // one mask bit per input byte, set when the byte is a valid character
        const uint64_t match    = _mm512_test_epi8_mask(M, bit);

        if (match != (uint64_t)(-1)) {
            // some characters do not match the valid range
            return MODP_B64_ERROR;
        }

        const __m512i translated = _mm512_add_epi8(input, shift);

        // four 6-bit indices -> one 24-bit value per 32-bit element
        const __m512i packed = dec_reshuffle(translated);

        // and https://github.com/WojciechMula/base64simd/blob/master/decode/decode.avx512bw.cpp
        //
        // within each 128-bit lane: put the three payload bytes of every
        // element into output order and zero the last four bytes
        const __m512i t1 = _mm512_shuffle_epi8(
            packed,
            _mm512_set4lanes_epi8(
                 2,  1,  0,
                 6,  5,  4,
                10,  9,  8,
                14, 13, 12,
                -1, -1, -1, -1)
        );

        // gather the three payload dwords of every lane at the bottom of the
        // vector, giving 48 contiguous output bytes
        const __m512i s6 = _mm512_setr_epi32(
             0,  1,  2,
             4,  5,  6,
             8,  9, 10,
            12, 13, 14,
            // unused
             0,  0,  0,  0);

        const __m512i t2 = _mm512_permutexvar_epi32(s6, t1);

        // 64 bytes are written, of which only the first 48 are payload
        _mm512_storeu_si512((__m512i*)(out), t2);

        srclen -= 64;
        src += 64;
        out += 48;
    }
    size_t scalarret = chromium_base64_decode(out, src, srclen);
    if(scalarret == MODP_B64_ERROR) return MODP_B64_ERROR;
    return (out - out_orig) + scalarret;
}

32 changes: 32 additions & 0 deletions tests/unit.c
Expand Up @@ -10,6 +10,9 @@
#include "chromiumbase64.h"
#include "klompavxbase64.h"
#include "fastavxbase64.h"
#ifdef HAVE_AVX512BW
#include "fastavx512bwbase64.h"
#endif

#include "scalarbase64.h"

Expand Down Expand Up @@ -134,6 +137,29 @@ void fast_avx2_checkExample(const char * source, const char * coded) {
free(dest3);
}

#ifdef HAVE_AVX512BW
/**
 * Round-trip check of the AVX512BW codec on one example.
 *
 * Encodes `source` and compares the result with `coded`, then decodes both
 * `coded` and the freshly encoded buffer and compares each with `source`.
 * Any mismatch trips an assert.
 */
void fast_avx512bw_checkExample(const char * source, const char * coded) {
    printf("fast_avx512bw codec check.\n");
    const size_t sourcelen = strlen(source);
    size_t len;

    // encode the reference text and compare with the expected encoding
    char * encoded = (char*) malloc(chromium_base64_encode_len(sourcelen));
    const size_t codedlen = fast_avx512bw_base64_encode(encoded, source, sourcelen);
    assert(strncmp(encoded, coded, codedlen) == 0);

    // decode the expected encoding
    char * from_reference = (char*) malloc(chromium_base64_decode_len(codedlen));
    len = fast_avx512bw_base64_decode(from_reference, coded, codedlen);
    assert(len == sourcelen);
    assert(strncmp(from_reference, source, sourcelen) == 0);

    // decode what the encoder produced
    char * from_encoder = (char*) malloc(chromium_base64_decode_len(codedlen));
    len = fast_avx512bw_base64_decode(from_encoder, encoded, codedlen);
    assert(len == sourcelen);
    assert(strncmp(from_encoder, source, sourcelen) == 0);

    free(encoded);
    free(from_reference);
    free(from_encoder);
}
#endif // HAVE_AVX512BW

static const uint8_t base64_table_enc[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
Expand Down Expand Up @@ -223,6 +249,12 @@ ZSBzaG9ydCB2ZWhlbWVuY2Ugb2YgYW55IGNhcm5hbCBwbGVhc3VyZS4=";
fast_avx2_checkExample(gosource,gocoded);
fast_avx2_checkExample(tutosource,tutocoded);

#ifdef HAVE_AVX512BW
fast_avx512bw_checkExample(wikipediasource,wikipediacoded);
fast_avx512bw_checkExample(gosource,gocoded);
fast_avx512bw_checkExample(tutosource,tutocoded);
#endif // HAVE_AVX512BW

scalar_checkExample(wikipediasource,wikipediacoded);
scalar_checkExample(gosource,gocoded);
scalar_checkExample(tutosource,tutocoded);
Expand Down

0 comments on commit d15f456

Please sign in to comment.