diff --git a/src/amalgam/gen/avx512fp16.c b/src/amalgam/gen/avx512fp16.c
index fd202ee7f9e..4422f0e7e06 100644
--- a/src/amalgam/gen/avx512fp16.c
+++ b/src/amalgam/gen/avx512fp16.c
@@ -8,7 +8,9 @@
 #include
 #include
+#include
 #include
+#include
 
 void xnn_f16_rmax_ukernel__avx512fp16_u128_acc4(
@@ -149,3 +151,144 @@ void xnn_f16_rminmax_ukernel__avx512fp16_u128_acc4(
     *((uint16_t*) output + 1) = (uint16_t) _mm_extract_epi16(_mm_castph_si128(vmax), 0);
 #endif  // defined(__AVX512FP16__)
 }
+
+void xnn_f16_vmul_minmax_ukernel__avx512fp16_u64(
+    size_t batch,
+    const void* restrict input_a,
+    const void* restrict input_b,
+    void* restrict output,
+    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
+{
+  assert(batch != 0);
+  assert(batch % sizeof(uint16_t) == 0);
+  assert(input_a != NULL);
+  assert(input_b != NULL);
+  assert(output != NULL);
+
+#if defined(__AVX512FP16__)
+  const uint16_t* a = (const uint16_t*) input_a;
+  const uint16_t* b = (const uint16_t*) input_b;
+  uint16_t* o = (uint16_t*) output;
+
+  const __m512h voutput_min = _mm512_castsi512_ph(_mm512_set1_epi16((params->fp16arith.min)));
+  const __m512h voutput_max = _mm512_castsi512_ph(_mm512_set1_epi16((params->fp16arith.max)));
+
+  for (; batch >= 64 * sizeof(uint16_t); batch -= 64 * sizeof(uint16_t)) {
+    __m512h vacc0 = _mm512_loadu_ph(a);
+    __m512h vacc1 = _mm512_loadu_ph(a + 32);
+    a += 64;
+
+    vacc0 = _mm512_mul_ph(vacc0, _mm512_loadu_ph(b));
+    vacc1 = _mm512_mul_ph(vacc1, _mm512_loadu_ph(b + 32));
+    b += 64;
+
+
+    vacc0 = _mm512_max_ph(voutput_min, vacc0);
+    vacc1 = _mm512_max_ph(voutput_min, vacc1);
+
+    vacc0 = _mm512_min_ph(voutput_max, vacc0);
+    vacc1 = _mm512_min_ph(voutput_max, vacc1);
+
+    _mm512_storeu_ph(o, vacc0);
+    _mm512_storeu_ph(o + 32, vacc1);
+    o += 64;
+  }
+  for (; batch >= 32 * sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) {
+    __m512h vacc = _mm512_loadu_ph(a);
+    a += 32;
+
+    vacc = _mm512_mul_ph(vacc, _mm512_loadu_ph(b));
+    b += 32;
+
+    vacc = _mm512_max_ph(voutput_min, vacc);
+    vacc = _mm512_min_ph(voutput_max, vacc);
+
+    _mm512_storeu_ph(o, vacc);
+    o += 32;
+  }
+  if XNN_UNLIKELY(batch != 0) {
+    assert(batch >= 1 * sizeof(uint16_t));
+    assert(batch <= 31 * sizeof(uint16_t));
+    // Prepare mask for valid 16-bit elements (depends on batch).
+    batch >>= XNN_LOG2_SIZEOF_HALF;
+    const __mmask32 vmask = _cvtu32_mask32((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1)));
+
+    __m512h vacc = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, a));
+
+    vacc = _mm512_maskz_mul_ph(vmask, vacc, _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, b)));
+
+    vacc = _mm512_maskz_max_ph(vmask, voutput_min, vacc);
+    vacc = _mm512_maskz_min_ph(vmask, voutput_max, vacc);
+    _mm512_mask_storeu_epi16(o, vmask, _mm512_castph_si512(vacc));
+  }
+#endif  // defined(__AVX512FP16__)
+}
+
+void xnn_f16_vmulc_minmax_ukernel__avx512fp16_u64(
+    size_t batch,
+    const void* restrict input_a,
+    const void* restrict input_b,
+    void* restrict output,
+    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
+{
+  assert(batch != 0);
+  assert(batch % sizeof(uint16_t) == 0);
+  assert(input_a != NULL);
+  assert(input_b != NULL);
+  assert(output != NULL);
+
+#if defined(__AVX512FP16__)
+  const uint16_t* a = (const uint16_t*) input_a;
+  const uint16_t* b = (const uint16_t*) input_b;
+  uint16_t* o = (uint16_t*) output;
+
+  const __m512h voutput_min = _mm512_castsi512_ph(_mm512_set1_epi16(params->fp16arith.min));
+  const __m512h voutput_max = _mm512_castsi512_ph(_mm512_set1_epi16(params->fp16arith.max));
+  const __m512h vb = _mm512_castsi512_ph(_mm512_set1_epi16(*b));
+
+  for (; batch >= 64 * sizeof(uint16_t); batch -= 64 * sizeof(uint16_t)) {
+    __m512h vacc0 = _mm512_loadu_ph(a);
+    __m512h vacc1 = _mm512_loadu_ph(a + 32);
+    a += 64;
+
+    vacc0 = _mm512_mul_ph(vacc0, vb);
+    vacc1 = _mm512_mul_ph(vacc1, vb);
+
+
+    vacc0 = _mm512_max_ph(voutput_min, vacc0);
+    vacc1 = _mm512_max_ph(voutput_min, vacc1);
+
+    vacc0 = _mm512_min_ph(voutput_max, vacc0);
+    vacc1 = _mm512_min_ph(voutput_max, vacc1);
+
+    _mm512_storeu_ph(o, vacc0);
+    _mm512_storeu_ph(o + 32, vacc1);
+    o += 64;
+  }
+  for (; batch >= 32 * sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) {
+    __m512h vacc = _mm512_loadu_ph(a);
+    a += 32;
+
+    vacc = _mm512_mul_ph(vacc, vb);
+    vacc = _mm512_max_ph(voutput_min, vacc);
+    vacc = _mm512_min_ph(voutput_max, vacc);
+
+    _mm512_storeu_ph(o, vacc);
+    o += 32;
+  }
+  if XNN_UNLIKELY(batch != 0) {
+    assert(batch >= 1 * sizeof(uint16_t));
+    assert(batch <= 31 * sizeof(uint16_t));
+    // Prepare mask for valid 16-bit elements (depends on batch).
+    batch >>= XNN_LOG2_SIZEOF_HALF;
+    const __mmask32 vmask = _cvtu32_mask32((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1)));
+
+    __m512h vacc = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, a));
+
+    vacc = _mm512_maskz_mul_ph(vmask, vacc, vb);
+    vacc = _mm512_maskz_max_ph(vmask, voutput_min, vacc);
+    vacc = _mm512_maskz_min_ph(vmask, voutput_max, vacc);
+    _mm512_mask_storeu_epi16(o, vmask, _mm512_castph_si512(vacc));
+  }
+#endif  // defined(__AVX512FP16__)
+}
diff --git a/src/configs/binary-elementwise-config.c b/src/configs/binary-elementwise-config.c
index fe955d85b37..77e1c20f517 100644
--- a/src/configs/binary-elementwise-config.c
+++ b/src/configs/binary-elementwise-config.c
@@ -235,6 +235,15 @@ static void init_f16_vmul_config(void) {
   #elif (XNN_ARCH_X86 || XNN_ARCH_X86_64) && !XNN_PLATFORM_MOBILE
     const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
     assert(hardware_config != NULL);
+    #if XNN_ENABLE_AVX512FP16
+      if (hardware_config->use_x86_avx512fp16) {
+        f16_vmul_config.minmax.op_ukernel = (xnn_vbinary_ukernel_fn) xnn_f16_vmul_minmax_ukernel__avx512fp16_u64;
+        f16_vmul_config.minmax.opc_ukernel = (xnn_vbinary_ukernel_fn) xnn_f16_vmulc_minmax_ukernel__avx512fp16_u64;
+        f16_vmul_config.minmax.ropc_ukernel = (xnn_vbinary_ukernel_fn) xnn_f16_vmulc_minmax_ukernel__avx512fp16_u64;
+        f16_vmul_config.init.f16_minmax = xnn_init_f16_minmax_fp16arith_params;
+        f16_vmul_config.minmax.element_tile = 64;
+      } else
+    #endif
     if (hardware_config->use_x86_avx2) {
       f16_vmul_config.minmax.op_ukernel = (xnn_vbinary_ukernel_fn) xnn_f16_vmul_minmax_ukernel__f16c_u16;
       f16_vmul_config.minmax.opc_ukernel = (xnn_vbinary_ukernel_fn) xnn_f16_vmulc_minmax_ukernel__f16c_u16;