Skip to content

Commit

Permalink
Enable AVX512FP16 vmul vbinary microkernels
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 632653372
  • Loading branch information
fbarchard authored and xnnpack-bot committed May 11, 2024
1 parent 0d273d4 commit 48bf404
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 0 deletions.
143 changes: 143 additions & 0 deletions src/amalgam/gen/avx512fp16.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/reduce.h>
#include <xnnpack/vbinary.h>


void xnn_f16_rmax_ukernel__avx512fp16_u128_acc4(
Expand Down Expand Up @@ -149,3 +151,144 @@ void xnn_f16_rminmax_ukernel__avx512fp16_u128_acc4(
*((uint16_t*) output + 1) = (uint16_t) _mm_extract_epi16(_mm_castph_si128(vmax), 0);
#endif // defined(__AVX512FP16__)
}

// Elementwise half-precision multiply with clamping: for each 16-bit element,
// output = min(max(a * b, params->fp16arith.min), params->fp16arith.max).
//
//   batch    - size of each input in BYTES; must be a non-zero multiple of
//              sizeof(uint16_t).
//   input_a  - first operand array (read as 16-bit elements).
//   input_b  - second operand array (read as 16-bit elements).
//   output   - destination array (written as 16-bit elements).
//   params   - supplies the 16-bit min/max clamp bounds via fp16arith.
//
// The whole computation is compiled only when __AVX512FP16__ is defined;
// otherwise the function body reduces to the argument assertions.
void xnn_f16_vmul_minmax_ukernel__avx512fp16_u64(
    size_t batch,
    const void* restrict input_a,
    const void* restrict input_b,
    void* restrict output,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(uint16_t) == 0);
  assert(input_a != NULL);
  assert(input_b != NULL);
  assert(output != NULL);

#if defined(__AVX512FP16__)
  const uint16_t* lhs = (const uint16_t*) input_a;
  const uint16_t* rhs = (const uint16_t*) input_b;
  uint16_t* dst = (uint16_t*) output;

  // Broadcast the 16-bit clamp bounds into every lane.
  const __m512h vlower = _mm512_castsi512_ph(_mm512_set1_epi16((params->fp16arith.min)));
  const __m512h vupper = _mm512_castsi512_ph(_mm512_set1_epi16((params->fp16arith.max)));

  // Main loop: 64 elements (two 512-bit vectors) per iteration.
  while (batch >= 64 * sizeof(uint16_t)) {
    const __m512h va_lo = _mm512_loadu_ph(lhs);
    const __m512h va_hi = _mm512_loadu_ph(lhs + 32);
    const __m512h vb_lo = _mm512_loadu_ph(rhs);
    const __m512h vb_hi = _mm512_loadu_ph(rhs + 32);
    lhs += 64;
    rhs += 64;

    __m512h vprod_lo = _mm512_mul_ph(va_lo, vb_lo);
    __m512h vprod_hi = _mm512_mul_ph(va_hi, vb_hi);

    // Clamp from below first, then from above; the bound is always the
    // first operand of max/min.
    vprod_lo = _mm512_min_ph(vupper, _mm512_max_ph(vlower, vprod_lo));
    vprod_hi = _mm512_min_ph(vupper, _mm512_max_ph(vlower, vprod_hi));

    _mm512_storeu_ph(dst, vprod_lo);
    _mm512_storeu_ph(dst + 32, vprod_hi);
    dst += 64;

    batch -= 64 * sizeof(uint16_t);
  }

  // At most one full 32-element vector can remain after the loop above,
  // so a single conditional step is sufficient here.
  if (batch >= 32 * sizeof(uint16_t)) {
    const __m512h va = _mm512_loadu_ph(lhs);
    const __m512h vb = _mm512_loadu_ph(rhs);
    lhs += 32;
    rhs += 32;

    __m512h vprod = _mm512_mul_ph(va, vb);
    vprod = _mm512_min_ph(vupper, _mm512_max_ph(vlower, vprod));

    _mm512_storeu_ph(dst, vprod);
    dst += 32;

    batch -= 32 * sizeof(uint16_t);
  }

  // Tail: 1..31 leftover elements handled with lane masking.
  if XNN_UNLIKELY(batch != 0) {
    assert(batch >= 1 * sizeof(uint16_t));
    assert(batch <= 31 * sizeof(uint16_t));

    // Convert the byte count to an element count and set that many low mask bits.
    const size_t remaining = batch >> XNN_LOG2_SIZEOF_HALF;
    const __mmask32 vtail = _cvtu32_mask32((uint32_t) ((UINT32_C(1) << remaining) - UINT32_C(1)));

    // Masked-off lanes are zero-filled on load and never written on store.
    const __m512h va = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vtail, lhs));
    const __m512h vb = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vtail, rhs));

    __m512h vprod = _mm512_maskz_mul_ph(vtail, va, vb);
    vprod = _mm512_maskz_max_ph(vtail, vlower, vprod);
    vprod = _mm512_maskz_min_ph(vtail, vupper, vprod);

    _mm512_mask_storeu_epi16(dst, vtail, _mm512_castph_si512(vprod));
  }
#endif  // defined(__AVX512FP16__)
}

// Half-precision multiply-by-constant with clamping: for each 16-bit element,
// output = min(max(a * b[0], params->fp16arith.min), params->fp16arith.max).
//
//   batch    - size of input_a in BYTES; must be a non-zero multiple of
//              sizeof(uint16_t).
//   input_a  - operand array (read as 16-bit elements).
//   input_b  - pointer to the multiplier; only its first 16-bit element is
//              read and it is broadcast to all lanes.
//   output   - destination array (written as 16-bit elements).
//   params   - supplies the 16-bit min/max clamp bounds via fp16arith.
//
// The whole computation is compiled only when __AVX512FP16__ is defined;
// otherwise the function body reduces to the argument assertions.
void xnn_f16_vmulc_minmax_ukernel__avx512fp16_u64(
    size_t batch,
    const void* restrict input_a,
    const void* restrict input_b,
    void* restrict output,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(uint16_t) == 0);
  assert(input_a != NULL);
  assert(input_b != NULL);
  assert(output != NULL);

#if defined(__AVX512FP16__)
  const uint16_t* src = (const uint16_t*) input_a;
  const uint16_t* scalar = (const uint16_t*) input_b;
  uint16_t* dst = (uint16_t*) output;

  // Broadcast the clamp bounds and the single multiplier into every lane.
  const __m512h vlower = _mm512_castsi512_ph(_mm512_set1_epi16(params->fp16arith.min));
  const __m512h vupper = _mm512_castsi512_ph(_mm512_set1_epi16(params->fp16arith.max));
  const __m512h vscalar = _mm512_castsi512_ph(_mm512_set1_epi16(*scalar));

  // Main loop: 64 elements (two 512-bit vectors) per iteration.
  while (batch >= 64 * sizeof(uint16_t)) {
    __m512h vprod_lo = _mm512_mul_ph(_mm512_loadu_ph(src), vscalar);
    __m512h vprod_hi = _mm512_mul_ph(_mm512_loadu_ph(src + 32), vscalar);
    src += 64;

    // Clamp from below first, then from above; the bound is always the
    // first operand of max/min.
    vprod_lo = _mm512_min_ph(vupper, _mm512_max_ph(vlower, vprod_lo));
    vprod_hi = _mm512_min_ph(vupper, _mm512_max_ph(vlower, vprod_hi));

    _mm512_storeu_ph(dst, vprod_lo);
    _mm512_storeu_ph(dst + 32, vprod_hi);
    dst += 64;

    batch -= 64 * sizeof(uint16_t);
  }

  // At most one full 32-element vector can remain after the loop above,
  // so a single conditional step is sufficient here.
  if (batch >= 32 * sizeof(uint16_t)) {
    __m512h vprod = _mm512_mul_ph(_mm512_loadu_ph(src), vscalar);
    src += 32;

    vprod = _mm512_min_ph(vupper, _mm512_max_ph(vlower, vprod));

    _mm512_storeu_ph(dst, vprod);
    dst += 32;

    batch -= 32 * sizeof(uint16_t);
  }

  // Tail: 1..31 leftover elements handled with lane masking.
  if XNN_UNLIKELY(batch != 0) {
    assert(batch >= 1 * sizeof(uint16_t));
    assert(batch <= 31 * sizeof(uint16_t));

    // Convert the byte count to an element count and set that many low mask bits.
    const size_t remaining = batch >> XNN_LOG2_SIZEOF_HALF;
    const __mmask32 vtail = _cvtu32_mask32((uint32_t) ((UINT32_C(1) << remaining) - UINT32_C(1)));

    // Masked-off lanes are zero-filled on load and never written on store.
    const __m512h va = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vtail, src));

    __m512h vprod = _mm512_maskz_mul_ph(vtail, va, vscalar);
    vprod = _mm512_maskz_max_ph(vtail, vlower, vprod);
    vprod = _mm512_maskz_min_ph(vtail, vupper, vprod);

    _mm512_mask_storeu_epi16(dst, vtail, _mm512_castph_si512(vprod));
  }
#endif  // defined(__AVX512FP16__)
}
9 changes: 9 additions & 0 deletions src/configs/binary-elementwise-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,15 @@ static void init_f16_vmul_config(void) {
#elif (XNN_ARCH_X86 || XNN_ARCH_X86_64) && !XNN_PLATFORM_MOBILE
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
assert(hardware_config != NULL);
#if XNN_ENABLE_AVX512FP16
if (hardware_config->use_x86_avx512fp16) {
f16_vmul_config.minmax.op_ukernel = (xnn_vbinary_ukernel_fn) xnn_f16_vmul_minmax_ukernel__avx512fp16_u64;
f16_vmul_config.minmax.opc_ukernel = (xnn_vbinary_ukernel_fn) xnn_f16_vmulc_minmax_ukernel__avx512fp16_u64;
f16_vmul_config.minmax.ropc_ukernel = (xnn_vbinary_ukernel_fn) xnn_f16_vmulc_minmax_ukernel__avx512fp16_u64;
f16_vmul_config.init.f16_minmax = xnn_init_f16_minmax_fp16arith_params;
f16_vmul_config.minmax.element_tile = 64;
} else
#endif
if (hardware_config->use_x86_avx2) {
f16_vmul_config.minmax.op_ukernel = (xnn_vbinary_ukernel_fn) xnn_f16_vmul_minmax_ukernel__f16c_u16;
f16_vmul_config.minmax.opc_ukernel = (xnn_vbinary_ukernel_fn) xnn_f16_vmulc_minmax_ukernel__f16c_u16;
Expand Down

0 comments on commit 48bf404

Please sign in to comment.