diff --git a/include/xnnpack.h b/include/xnnpack.h index 15915ae47bf..1626562dd48 100644 --- a/include/xnnpack.h +++ b/include/xnnpack.h @@ -4862,13 +4862,10 @@ enum xnn_status xnn_reshape_mean_nd_f16( const size_t* reduction_axes, size_t num_input_dims, const size_t* input_shape, - size_t* workspace_size, - size_t* workspace_alignment, pthreadpool_t threadpool); enum xnn_status xnn_setup_mean_nd_f16( xnn_operator_t mean_op, - void* workspace, const void* input, void* output); @@ -4882,13 +4879,10 @@ enum xnn_status xnn_reshape_mean_nd_f32( const size_t* reduction_axes, size_t num_input_dims, const size_t* input_shape, - size_t* workspace_size, - size_t* workspace_alignment, pthreadpool_t threadpool); enum xnn_status xnn_setup_mean_nd_f32( xnn_operator_t mean_op, - void* workspace, const float* input, float* output); diff --git a/src/amalgam/gen/avx.c b/src/amalgam/gen/avx.c index dc395ef7c47..44c74d5f809 100644 --- a/src/amalgam/gen/avx.c +++ b/src/amalgam/gen/avx.c @@ -3642,6 +3642,249 @@ void xnn_f32_qu8_vcvt_ukernel__avx_u32( } } +void xnn_f32_rdsum_ukernel_7p7x__avx_c32( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m256 vscale = _mm256_set1_ps(params->avx.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 32; channels -= 32) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + __m256 vacc2 = _mm256_setzero_ps(); + __m256 vacc3 = _mm256_setzero_ps(); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m256 vin0; + __m256 vin1; + __m256 vin2; + __m256 vin3; + vin0 = _mm256_loadu_ps(&i0[0]); + vin1 = _mm256_loadu_ps(&i0[8]); + vin2 = _mm256_loadu_ps(&i0[16]); + vin3 = _mm256_loadu_ps(&i0[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i1[0]); + vin1 = _mm256_loadu_ps(&i1[8]); + vin2 = _mm256_loadu_ps(&i1[16]); + vin3 = _mm256_loadu_ps(&i1[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i2[0]); + vin1 = _mm256_loadu_ps(&i2[8]); + vin2 = _mm256_loadu_ps(&i2[16]); + vin3 = _mm256_loadu_ps(&i2[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i3[0]); + vin1 = _mm256_loadu_ps(&i3[8]); + vin2 = _mm256_loadu_ps(&i3[16]); + vin3 = 
_mm256_loadu_ps(&i3[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i4[0]); + vin1 = _mm256_loadu_ps(&i4[8]); + vin2 = _mm256_loadu_ps(&i4[16]); + vin3 = _mm256_loadu_ps(&i4[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i5[0]); + vin1 = _mm256_loadu_ps(&i5[8]); + vin2 = _mm256_loadu_ps(&i5[16]); + vin3 = _mm256_loadu_ps(&i5[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_loadu_ps(&i6[0]); + vin1 = _mm256_loadu_ps(&i6[8]); + vin2 = _mm256_loadu_ps(&i6[16]); + vin3 = _mm256_loadu_ps(&i6[24]); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm256_mul_ps(vacc0, vscale); + vacc1 = _mm256_mul_ps(vacc1, vscale); + vacc2 = _mm256_mul_ps(vacc2, vscale); + vacc3 = _mm256_mul_ps(vacc3, vscale); + + const float* o = output; + __m256 vo0 = _mm256_loadu_ps(o); o += 8; + __m256 vo1 = _mm256_loadu_ps(o); o += 8; + __m256 vo2 = _mm256_loadu_ps(o); o += 8; + __m256 vo3 = _mm256_loadu_ps(o); o += 8; + vacc0 = _mm256_add_ps(vo0, vacc0); + vacc1 = _mm256_add_ps(vo1, vacc1); + vacc2 = _mm256_add_ps(vo2, vacc2); + vacc3 = _mm256_add_ps(vo3, vacc3); + _mm256_storeu_ps(output, vacc0); output += 8; + _mm256_storeu_ps(output, vacc1); output += 8; + _mm256_storeu_ps(output, vacc2); output += 8; + _mm256_storeu_ps(output, vacc3); output += 8; + + input = (const float*) ((uintptr_t) input + 32 * sizeof(float)); + } + __m256i vmask; + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + __m256 vacc[4]; + vacc[0] = _mm256_setzero_ps(); + vacc[1] = _mm256_setzero_ps(); + vacc[2] = _mm256_setzero_ps(); + vacc[3] = _mm256_setzero_ps(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 0x7; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i0[i*8]), vacc[i]); + vacc[i] = 
_mm256_add_ps(_mm256_loadu_ps(&i1[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i2[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i3[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i4[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i5[i*8]), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_loadu_ps(&i6[i*8]), vacc[i]); + } + + if (remainder) { + vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx.mask_table[7] - (channels & 0x7) * sizeof(float))); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i0[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i1[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i2[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i3[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i4[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i5[num_full_chunks*8], vmask), vacc[num_full_chunks]); + vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i6[num_full_chunks*8], vmask), vacc[num_full_chunks]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm256_mul_ps(vacc[i], vscale); + } + + __m256 vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 3; ++i) { + vo[i] = _mm256_loadu_ps(o); o += 8; + } + for (int i = 0; i < channels >> 3; ++i) { + vacc[i] = _mm256_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 3; ++i) { + _mm256_storeu_ps(output, vacc[i]); output += 8; + } + if (remainder) { + const size_t pos = num_full_chunks; + __m256 vout = vacc[pos]; + const __m256 vdata = _mm256_maskload_ps(output, vmask); + vout = _mm256_add_ps(vout, vdata); + __m128 vout_lo = _mm256_castps256_ps128(vout); + if (channels & 4) { + _mm_storeu_ps(output, vout_lo); + vout_lo = _mm256_extractf128_ps(vout, 1); + output += 4; + } + if (channels & 2) { + _mm_storel_pi((__m64*) output, vout_lo); + vout_lo = _mm_movehl_ps(vout_lo, vout_lo); + output += 2; + } + if (channels & 1) { + _mm_store_ss(output, vout_lo); + } + } + } +} + void xnn_f32_rmax_ukernel__avx_u32_acc4( size_t batch, const float* input, diff --git a/src/amalgam/gen/avx512f.c b/src/amalgam/gen/avx512f.c index 647974fe9b8..d49b8a2eba4 100644 --- a/src/amalgam/gen/avx512f.c +++ b/src/amalgam/gen/avx512f.c @@ -1990,6 +1990,240 @@ void xnn_f32_prelu_ukernel__avx512f_2x16( } while (rows != 0); } +void xnn_f32_rdsum_ukernel_7p7x__avx512f_c64( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m512 vscale = _mm512_set1_ps(params->scalar.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 64; channels -= 64) { + const float* i0 = input; + const 
float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + __m512 vacc2 = _mm512_setzero_ps(); + __m512 vacc3 = _mm512_setzero_ps(); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m512 vin0; + __m512 vin1; + __m512 vin2; + __m512 vin3; + vin0 = _mm512_loadu_ps(&i0[0]); + vin1 = _mm512_loadu_ps(&i0[16]); + vin2 = _mm512_loadu_ps(&i0[32]); + vin3 = _mm512_loadu_ps(&i0[48]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_loadu_ps(&i1[0]); + vin1 = _mm512_loadu_ps(&i1[16]); + vin2 = _mm512_loadu_ps(&i1[32]); + vin3 = _mm512_loadu_ps(&i1[48]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_loadu_ps(&i2[0]); + vin1 = _mm512_loadu_ps(&i2[16]); + vin2 = _mm512_loadu_ps(&i2[32]); + vin3 = _mm512_loadu_ps(&i2[48]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_loadu_ps(&i3[0]); + vin1 = _mm512_loadu_ps(&i3[16]); + vin2 = _mm512_loadu_ps(&i3[32]); + vin3 = _mm512_loadu_ps(&i3[48]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_loadu_ps(&i4[0]); + vin1 = _mm512_loadu_ps(&i4[16]); + vin2 = _mm512_loadu_ps(&i4[32]); + vin3 = _mm512_loadu_ps(&i4[48]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_loadu_ps(&i5[0]); + vin1 = _mm512_loadu_ps(&i5[16]); + vin2 = _mm512_loadu_ps(&i5[32]); + vin3 = _mm512_loadu_ps(&i5[48]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_loadu_ps(&i6[0]); + vin1 = _mm512_loadu_ps(&i6[16]); + vin2 = _mm512_loadu_ps(&i6[32]); + vin3 = _mm512_loadu_ps(&i6[48]); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm512_mul_ps(vacc0, vscale); + vacc1 = _mm512_mul_ps(vacc1, vscale); + vacc2 = _mm512_mul_ps(vacc2, 
vscale); + vacc3 = _mm512_mul_ps(vacc3, vscale); + + const float* o = output; + const __m512 vo0 = _mm512_loadu_ps(o); o += 16; + const __m512 vo1 = _mm512_loadu_ps(o); o += 16; + const __m512 vo2 = _mm512_loadu_ps(o); o += 16; + const __m512 vo3 = _mm512_loadu_ps(o); o += 16; + vacc0 = _mm512_add_ps(vo0, vacc0); + vacc1 = _mm512_add_ps(vo1, vacc1); + vacc2 = _mm512_add_ps(vo2, vacc2); + vacc3 = _mm512_add_ps(vo3, vacc3); + _mm512_storeu_ps(output, vacc0); output += 16; + _mm512_storeu_ps(output, vacc1); output += 16; + _mm512_storeu_ps(output, vacc2); output += 16; + _mm512_storeu_ps(output, vacc3); output += 16; + + input = (const float*) ((uintptr_t) input + 64 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + __m512 vacc[4]; + vacc[0] = _mm512_setzero_ps(); + vacc[1] = _mm512_setzero_ps(); + vacc[2] = _mm512_setzero_ps(); + vacc[3] = _mm512_setzero_ps(); + + const size_t num_full_chunks = channels >> 4; + const size_t num_chunks = round_up_po2(channels, 16) >> 4; + const size_t remainder = channels & 0xF; + const size_t batch = channels & 0xF; + __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + if (remainder) { + assert(batch >= 1); + assert(batch <= 15); + vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + } + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i0[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i1[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i2[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i3[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i4[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i5[i*16]), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_loadu_ps(&i6[i*16]), vacc[i]); + } + + if (remainder) { + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i0[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i1[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i2[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i3[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i4[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_maskz_loadu_ps(vmask, &i5[num_full_chunks*16])); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], 
_mm512_maskz_loadu_ps(vmask, &i6[num_full_chunks*16])); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm512_mul_ps(vacc[i], vscale); + } + + __m512 vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 4; ++i) { + vo[i] = _mm512_loadu_ps(o); o += 16; + } + for (int i = 0; i < channels >> 4; ++i) { + vacc[i] = _mm512_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 4; ++i) { + _mm512_storeu_ps(output, vacc[i]); output += 16; + } + if (remainder) { + const size_t pos = num_full_chunks; + __m512 vout = vacc[pos]; + vout = _mm512_maskz_add_ps(vmask, vout, _mm512_maskz_loadu_ps(vmask, output)); + _mm512_mask_storeu_ps(output, vmask, vout); + } + } +} + void xnn_f32_rmax_ukernel__avx512f_u64_acc4( size_t batch, const float* input, diff --git a/src/amalgam/gen/avx512skx.c b/src/amalgam/gen/avx512skx.c index 5198496fced..874fdd99190 100644 --- a/src/amalgam/gen/avx512skx.c +++ b/src/amalgam/gen/avx512skx.c @@ -54,6 +54,239 @@ void xnn_f16_f32_vcvt_ukernel__avx512skx_u16( } } +void xnn_f16_f32acc_rdsum_ukernel_7p7x__avx512skx_c64( + size_t rows, + size_t channels, + const void* input, + size_t input_stride, + const void* zero, + void* output, + const union xnn_f16_f32acc_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m512 vscale = _mm512_set1_ps(params->scalar.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 64; channels -= 64) { + const uint16_t* i0 = input; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input + 1 * input_stride); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input + 2 * input_stride); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input + 3 * input_stride); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input + 4 * input_stride); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input + 5 * input_stride); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input + 6 * input_stride); + + __m512 vacc0 = _mm512_setzero_ps(); + __m512 vacc1 = _mm512_setzero_ps(); + __m512 vacc2 = _mm512_setzero_ps(); + __m512 vacc3 = _mm512_setzero_ps(); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m512 vin0; + __m512 vin1; + __m512 vin2; + __m512 vin3; + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i0[48]))); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[0]))); + vin1 = 
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i1[48]))); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i2[48]))); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i3[48]))); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i4[48]))); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i5[48]))); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + vin0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[0]))); + vin1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[16]))); + vin2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[32]))); + vin3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i6[48]))); + vacc0 = _mm512_add_ps(vin0, vacc0); + vacc1 = _mm512_add_ps(vin1, vacc1); + vacc2 = _mm512_add_ps(vin2, vacc2); + vacc3 = _mm512_add_ps(vin3, vacc3); + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm512_mul_ps(vacc0, vscale); + vacc1 = _mm512_mul_ps(vacc1, vscale); + vacc2 = _mm512_mul_ps(vacc2, vscale); + vacc3 = _mm512_mul_ps(vacc3, vscale); + + const uint16_t* o = output; + __m512 vo0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) o)); o = (const void*) ((uintptr_t) o + 16 * sizeof(uint16_t)); + __m512 vo1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) o)); o = (const void*) ((uintptr_t) o + 16 * sizeof(uint16_t)); + __m512 vo2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) 
o)); o = (const void*) ((uintptr_t) o + 16 * sizeof(uint16_t)); + __m512 vo3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) o)); o = (const void*) ((uintptr_t) o + 16 * sizeof(uint16_t)); + vacc0 = _mm512_add_ps(vo0, vacc0); + vacc1 = _mm512_add_ps(vo1, vacc1); + vacc2 = _mm512_add_ps(vo2, vacc2); + vacc3 = _mm512_add_ps(vo3, vacc3); + _mm256_storeu_si256((__m256i*) output, _mm512_cvtps_ph(vacc0, _MM_FROUND_TO_NEAREST_INT)); output = (void*) ((uintptr_t) output + 16 * sizeof(uint16_t)); + _mm256_storeu_si256((__m256i*) output, _mm512_cvtps_ph(vacc1, _MM_FROUND_TO_NEAREST_INT)); output = (void*) ((uintptr_t) output + 16 * sizeof(uint16_t)); + _mm256_storeu_si256((__m256i*) output, _mm512_cvtps_ph(vacc2, _MM_FROUND_TO_NEAREST_INT)); output = (void*) ((uintptr_t) output + 16 * sizeof(uint16_t)); + _mm256_storeu_si256((__m256i*) output, _mm512_cvtps_ph(vacc3, _MM_FROUND_TO_NEAREST_INT)); output = (void*) ((uintptr_t) output + 16 * sizeof(uint16_t)); + + input = (const uint16_t*) ((uintptr_t) input + 64 * sizeof(uint16_t)); + } + if (channels != 0) { + input_increment = 7 * input_stride; + const uint16_t* i0 = input; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input + 1 * input_stride); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input + 2 * input_stride); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input + 3 * input_stride); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input + 4 * input_stride); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input + 5 * input_stride); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input + 6 * input_stride); + __m512 vacc[4]; + vacc[0] = _mm512_setzero_ps(); + vacc[1] = _mm512_setzero_ps(); + vacc[2] = _mm512_setzero_ps(); + vacc[3] = _mm512_setzero_ps(); + + const size_t num_full_chunks = channels >> 4; + const size_t num_chunks = round_up_po2(channels, 16) >> 4; + const size_t remainder = channels & 0xF; + const size_t batch = channels & 0xF; + __mmask16 vmask; + if (remainder) { + assert(batch >= 1); + assert(batch <= 15); + vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); + } + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm512_add_ps(_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i0[i*16])), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i1[i*16])), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i2[i*16])), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i3[i*16])), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i4[i*16])), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i5[i*16])), vacc[i]); + vacc[i] = _mm512_add_ps(_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) &i6[i*16])), vacc[i]); + } + + if (remainder) { + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i0[num_full_chunks*16]))); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i1[num_full_chunks*16]))); + 
vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i2[num_full_chunks*16]))); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i3[num_full_chunks*16]))); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i4[num_full_chunks*16]))); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i5[num_full_chunks*16]))); + vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i6[num_full_chunks*16]))); + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm512_mul_ps(vacc[i], vscale); + } + + __m512 vo[4]; + const uint16_t* o = output; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) o)); o = (const void*) ((uintptr_t) o + 16 * sizeof(uint16_t)); + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm512_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + _mm256_storeu_si256((__m256i*) output, _mm512_cvtps_ph(vacc[i], _MM_FROUND_TO_NEAREST_INT)); output = (void*) ((uintptr_t) output + 16 * sizeof(uint16_t)); + } + if (remainder) { + __m512 vout = vacc[num_full_chunks]; + vout = _mm512_maskz_add_ps(vmask, vout, _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, output))); + _mm256_mask_storeu_epi16(output, vmask, _mm512_cvtps_ph(vout, _MM_FROUND_TO_NEAREST_INT)); + } + } +} + void xnn_f16_f32acc_rsum_ukernel__avx512skx_u64_acc4( size_t batch, const void* input, diff --git a/src/amalgam/gen/f16c.c b/src/amalgam/gen/f16c.c index d3529752ead..2fb4c310fcf 100644 --- a/src/amalgam/gen/f16c.c +++ b/src/amalgam/gen/f16c.c @@ -584,6 +584,250 @@ void xnn_f16_f32_vcvt_ukernel__f16c_u16( } } +void xnn_f16_f32acc_rdsum_ukernel_7p7x__f16c_c32( + size_t rows, + size_t channels, + const void* input, + size_t input_stride, + const void* zero, + void* output, + const union xnn_f16_f32acc_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m256 vscale = _mm256_set1_ps(params->avx.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 32; channels -= 32) { + const uint16_t* i0 = input; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input + 1 * input_stride); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input + 2 * input_stride); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input + 3 * input_stride); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input + 4 * input_stride); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input + 5 * input_stride); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input + 6 * input_stride); + + __m256 vacc0 = _mm256_setzero_ps(); + __m256 vacc1 = _mm256_setzero_ps(); + __m256 vacc2 = _mm256_setzero_ps(); + __m256 vacc3 = _mm256_setzero_ps(); + 
+ for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m256 vin0; + __m256 vin1; + __m256 vin2; + __m256 vin3; + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i0[24]))); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i1[24]))); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i2[24]))); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i3[24]))); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i4[24]))); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i5[24]))); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + vin0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[0]))); + vin1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[8]))); + vin2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[16]))); + vin3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (&i6[24]))); + vacc0 = _mm256_add_ps(vin0, vacc0); + vacc1 = _mm256_add_ps(vin1, vacc1); + vacc2 = _mm256_add_ps(vin2, vacc2); + vacc3 = _mm256_add_ps(vin3, vacc3); + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); 
+ i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm256_mul_ps(vacc0, vscale); + vacc1 = _mm256_mul_ps(vacc1, vscale); + vacc2 = _mm256_mul_ps(vacc2, vscale); + vacc3 = _mm256_mul_ps(vacc3, vscale); + + const uint16_t* o = output; + __m256 vo0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) o)); o = (const void*) ((uintptr_t) o + 8 * sizeof(uint16_t)); + __m256 vo1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) o)); o = (const void*) ((uintptr_t) o + 8 * sizeof(uint16_t)); + __m256 vo2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) o)); o = (const void*) ((uintptr_t) o + 8 * sizeof(uint16_t)); + __m256 vo3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) o)); o = (const void*) ((uintptr_t) o + 8 * sizeof(uint16_t)); + vacc0 = _mm256_add_ps(vo0, vacc0); + vacc1 = _mm256_add_ps(vo1, vacc1); + vacc2 = _mm256_add_ps(vo2, vacc2); + vacc3 = _mm256_add_ps(vo3, vacc3); + _mm_storeu_si128((__m128i*) output, _mm256_cvtps_ph(vacc0, _MM_FROUND_TO_NEAREST_INT)); output = (void*) ((uintptr_t) output + 8 * sizeof(uint16_t)); + _mm_storeu_si128((__m128i*) output, _mm256_cvtps_ph(vacc1, _MM_FROUND_TO_NEAREST_INT)); output = (void*) ((uintptr_t) output + 8 * sizeof(uint16_t)); + _mm_storeu_si128((__m128i*) output, _mm256_cvtps_ph(vacc2, _MM_FROUND_TO_NEAREST_INT)); output = (void*) ((uintptr_t) output + 8 * sizeof(uint16_t)); + _mm_storeu_si128((__m128i*) output, _mm256_cvtps_ph(vacc3, _MM_FROUND_TO_NEAREST_INT)); output = (void*) ((uintptr_t) output + 8 * sizeof(uint16_t)); + + input = (const uint16_t*) ((uintptr_t) input + 32 * sizeof(uint16_t)); + } + if (channels != 0) { + input_increment = 7 * input_stride; + const uint16_t* i0 = input; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input + 1 * input_stride); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input + 2 * input_stride); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input + 3 * input_stride); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input + 4 * input_stride); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input + 5 * input_stride); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input + 6 * input_stride); + __m256 vacc[4]; + vacc[0] = _mm256_setzero_ps(); + vacc[1] = _mm256_setzero_ps(); + vacc[2] = _mm256_setzero_ps(); + vacc[3] = _mm256_setzero_ps(); + + const size_t num_full_chunks = channels >> 3; + const size_t num_chunks = round_up_po2(channels, 8) >> 3; + const size_t remainder = channels & 0x7; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm256_add_ps(_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i0[i*8])), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i1[i*8])), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i2[i*8])), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i3[i*8])), vacc[i]); + vacc[i] = 
_mm256_add_ps(_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i4[i*8])), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i5[i*8])), vacc[i]); + vacc[i] = _mm256_add_ps(_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i6[i*8])), vacc[i]); + } + + if (remainder) { + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i0[num_full_chunks*8]))); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i1[num_full_chunks*8]))); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i2[num_full_chunks*8]))); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i3[num_full_chunks*8]))); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i4[num_full_chunks*8]))); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i5[num_full_chunks*8]))); + vacc[num_full_chunks] = _mm256_add_ps(vacc[num_full_chunks], _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) &i6[num_full_chunks*8]))); + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + for (size_t i = 0; i < num_chunks; ++i) { + vacc[i] = _mm256_mul_ps(vacc[i], vscale); + } + + __m256 vo[4]; + const uint16_t* o = output; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) o)); o = (const void*) ((uintptr_t) o + 8 * sizeof(uint16_t)); + } + for (int i = 0; i < num_full_chunks; ++i) { + vacc[i] = _mm256_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < num_full_chunks; ++i) { + _mm_storeu_si128((__m128i*) output, _mm256_cvtps_ph(vacc[i], _MM_FROUND_TO_NEAREST_INT)); output = (void*) ((uintptr_t) output + 8 * sizeof(uint16_t)); + } + if (remainder) { + __m256 vout = vacc[num_full_chunks]; + __m128 vout_low = _mm256_castps256_ps128(vout); + if (channels & 4) { + __m128 vo = _mm_cvtph_ps(_mm_loadl_epi64((__m128i*) output)); + vo = _mm_add_ps(vout_low, vo); + _mm_storel_epi64((__m128i*) output, _mm_cvtps_ph(vo, _MM_FROUND_TO_NEAREST_INT)); + vout_low = _mm256_castps256_ps128(_mm256_permute2f128_ps(vout, vout, 1)); + output = (void*) ((uintptr_t) output + 4 * sizeof(uint16_t)); + } + if (channels & 2) { + __m128 vo = _mm_cvtph_ps(_mm_loadu_si32(output)); + vo = _mm_add_ps(vout_low, vo); + _mm_storeu_si32(output, _mm_cvtps_ph(vo, _MM_FROUND_TO_NEAREST_INT)); + vout_low = _mm_movehl_ps(vout_low, vout_low); + output = (void*) ((uintptr_t) output + 2 * sizeof(uint16_t)); + } + if (channels & 1) { + __m128 vo = _mm_cvtph_ps(_mm_loadu_si16(output)); + vo = _mm_add_ps(vout_low, vo); + _mm_storeu_si16(output, _mm_cvtps_ph(vo, _MM_FROUND_TO_NEAREST_INT)); + } + } + } +} + void xnn_f16_f32acc_rsum_ukernel__f16c_u32_acc4( size_t batch, const void* input, diff --git a/src/amalgam/gen/neon.c b/src/amalgam/gen/neon.c index bb6fec6ce8c..508d360b331 100644 --- a/src/amalgam/gen/neon.c +++ b/src/amalgam/gen/neon.c @@ -8033,6 +8033,225 @@ void 
xnn_f32_raddstoreexpminusmax_ukernel__neon_rr2_lut64_p2_u8( #endif } +void xnn_f32_rdsum_ukernel_7p7x__neon_c16( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const float32x4_t vscale = vdupq_n_f32(params->scalar.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 16; channels -= 16) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + float32x4_t vacc0 = vdupq_n_f32(0.f); + float32x4_t vacc1 = vdupq_n_f32(0.f); + float32x4_t vacc2 = vdupq_n_f32(0.f); + float32x4_t vacc3 = vdupq_n_f32(0.f); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + float32x4_t vin0; + float32x4_t vin1; + float32x4_t vin2; + float32x4_t vin3; + vin0 = vld1q_f32(&i0[0]); + vin1 = vld1q_f32(&i0[4]); + vin2 = vld1q_f32(&i0[8]); + vin3 = vld1q_f32(&i0[12]); + vacc0 = vaddq_f32(vin0, vacc0); + vacc1 = vaddq_f32(vin1, vacc1); + vacc2 = vaddq_f32(vin2, vacc2); + vacc3 = vaddq_f32(vin3, vacc3); + vin0 = vld1q_f32(&i1[0]); + vin1 = vld1q_f32(&i1[4]); + vin2 = vld1q_f32(&i1[8]); + vin3 = vld1q_f32(&i1[12]); + vacc0 = vaddq_f32(vin0, vacc0); + vacc1 = vaddq_f32(vin1, vacc1); + vacc2 = vaddq_f32(vin2, vacc2); + vacc3 = vaddq_f32(vin3, vacc3); + vin0 = vld1q_f32(&i2[0]); + vin1 = vld1q_f32(&i2[4]); + vin2 = vld1q_f32(&i2[8]); + vin3 = vld1q_f32(&i2[12]); + vacc0 = vaddq_f32(vin0, vacc0); + vacc1 = vaddq_f32(vin1, vacc1); + vacc2 = vaddq_f32(vin2, vacc2); + vacc3 = vaddq_f32(vin3, vacc3); + vin0 = vld1q_f32(&i3[0]); + vin1 = vld1q_f32(&i3[4]); + vin2 = vld1q_f32(&i3[8]); + vin3 = vld1q_f32(&i3[12]); + vacc0 = vaddq_f32(vin0, vacc0); + vacc1 = vaddq_f32(vin1, vacc1); + vacc2 = vaddq_f32(vin2, vacc2); + vacc3 = vaddq_f32(vin3, vacc3); + vin0 = vld1q_f32(&i4[0]); + vin1 = vld1q_f32(&i4[4]); + vin2 = vld1q_f32(&i4[8]); + vin3 = vld1q_f32(&i4[12]); + vacc0 = vaddq_f32(vin0, vacc0); + vacc1 = vaddq_f32(vin1, vacc1); + vacc2 = vaddq_f32(vin2, vacc2); + vacc3 = vaddq_f32(vin3, vacc3); + vin0 = vld1q_f32(&i5[0]); + vin1 = vld1q_f32(&i5[4]); + vin2 = vld1q_f32(&i5[8]); + vin3 = vld1q_f32(&i5[12]); + vacc0 = vaddq_f32(vin0, vacc0); + vacc1 = vaddq_f32(vin1, vacc1); + vacc2 = vaddq_f32(vin2, vacc2); + vacc3 = vaddq_f32(vin3, vacc3); + vin0 = vld1q_f32(&i6[0]); + vin1 = vld1q_f32(&i6[4]); + vin2 = vld1q_f32(&i6[8]); + vin3 = vld1q_f32(&i6[12]); + vacc0 = vaddq_f32(vin0, vacc0); + vacc1 = vaddq_f32(vin1, vacc1); + vacc2 = vaddq_f32(vin2, vacc2); + vacc3 = vaddq_f32(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) 
((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = vmulq_f32(vacc0, vscale); + vacc1 = vmulq_f32(vacc1, vscale); + vacc2 = vmulq_f32(vacc2, vscale); + vacc3 = vmulq_f32(vacc3, vscale); + + const float* o = output; + float32x4_t vo0 = vld1q_f32(o); o += 4; + float32x4_t vo1 = vld1q_f32(o); o += 4; + float32x4_t vo2 = vld1q_f32(o); o += 4; + float32x4_t vo3 = vld1q_f32(o); o += 4; + vacc0 = vaddq_f32(vo0, vacc0); + vacc1 = vaddq_f32(vo1, vacc1); + vacc2 = vaddq_f32(vo2, vacc2); + vacc3 = vaddq_f32(vo3, vacc3); + vst1q_f32(output, vacc0); output += 4; + vst1q_f32(output, vacc1); output += 4; + vst1q_f32(output, vacc2); output += 4; + vst1q_f32(output, vacc3); output += 4; + + input = (const float*) ((uintptr_t) input + 16 * sizeof(float)); + } + if (channels != 0) { + size_t input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + float32x4_t vacc[4]; + vacc[0] = vdupq_n_f32(0.f); + vacc[1] = vdupq_n_f32(0.f); + vacc[2] = vdupq_n_f32(0.f); + vacc[3] = vdupq_n_f32(0.f); + + size_t num_chunks = round_up_po2(channels, 4) >> 2; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_chunks; ++i) { + vacc[i] = vaddq_f32(vld1q_f32(&i0[i*4]), vacc[i]); + vacc[i] = vaddq_f32(vld1q_f32(&i1[i*4]), vacc[i]); + vacc[i] = vaddq_f32(vld1q_f32(&i2[i*4]), vacc[i]); + vacc[i] = vaddq_f32(vld1q_f32(&i3[i*4]), vacc[i]); + vacc[i] = vaddq_f32(vld1q_f32(&i4[i*4]), vacc[i]); + vacc[i] = vaddq_f32(vld1q_f32(&i5[i*4]), vacc[i]); + vacc[i] = vaddq_f32(vld1q_f32(&i6[i*4]), vacc[i]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (int i = 0; i < (channels + 4) >> 2; ++i) { + vacc[i] = vmulq_f32(vacc[i], vscale); + } + + float32x4_t vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = vld1q_f32(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = vaddq_f32(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + vst1q_f32(output, vacc[i]); output += 4; + } + size_t pos = channels >> 2; + channels &= 0x3; + float32x2_t vacc_low = vget_low_f32(vacc[pos]); + if (channels & 2) { + vst1_f32(output, vadd_f32(vld1_f32(output), vacc_low)); output += 2; + vacc_low = vget_high_f32(vacc[pos]); + } + if (channels & 1) { + vst1_lane_f32(output, vadd_f32(vld1_dup_f32(output), vacc_low), 
0); + } + } +} + void xnn_f32_rmax_ukernel__neon_u16_acc4( size_t batch, const float* input, diff --git a/src/amalgam/gen/neonfp16arith.c b/src/amalgam/gen/neonfp16arith.c index 7bf3a5aaa01..d7bb9c2e8b3 100644 --- a/src/amalgam/gen/neonfp16arith.c +++ b/src/amalgam/gen/neonfp16arith.c @@ -4259,6 +4259,228 @@ void xnn_f16_dwconv2d_chw_ukernel_5x5s2p2__neonfp16arith_1x8( } while (output_height != 0); } +void xnn_f16_f32acc_rdsum_ukernel_7p7x__neonfp16arith_c16( + size_t rows, + size_t channels, + const void* input, + size_t input_stride, + const void* zero, + void* output, + const union xnn_f16_f32acc_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const float32x4_t vscale = vdupq_n_f32(params->scalar.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 16; channels -= 16) { + const uint16_t* i0 = input; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input + 1 * input_stride); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input + 2 * input_stride); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input + 3 * input_stride); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input + 4 * input_stride); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input + 5 * input_stride); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input + 6 * input_stride); + + float32x4_t vacc0 = vdupq_n_f32(0.f); + float32x4_t vacc1 = vdupq_n_f32(0.f); + float32x4_t vacc2 = vdupq_n_f32(0.f); + float32x4_t vacc3 = vdupq_n_f32(0.f); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + float32x4_t vin0; + float32x4_t vin1; + float32x4_t vin2; + float32x4_t vin3; + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[12]))); + vacc0 = vaddq_f32(vin0, vacc0); + vacc1 = vaddq_f32(vin1, vacc1); + vacc2 = vaddq_f32(vin2, vacc2); + vacc3 = vaddq_f32(vin3, vacc3); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[12]))); + vacc0 = vaddq_f32(vin0, vacc0); + vacc1 = vaddq_f32(vin1, vacc1); + vacc2 = vaddq_f32(vin2, vacc2); + vacc3 = vaddq_f32(vin3, vacc3); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[12]))); + vacc0 = vaddq_f32(vin0, vacc0); + vacc1 = vaddq_f32(vin1, vacc1); + vacc2 = vaddq_f32(vin2, vacc2); + vacc3 = vaddq_f32(vin3, vacc3); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[12]))); + vacc0 = vaddq_f32(vin0, vacc0); + vacc1 = vaddq_f32(vin1, vacc1); + vacc2 = vaddq_f32(vin2, vacc2); + vacc3 = 
vaddq_f32(vin3, vacc3); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[12]))); + vacc0 = vaddq_f32(vin0, vacc0); + vacc1 = vaddq_f32(vin1, vacc1); + vacc2 = vaddq_f32(vin2, vacc2); + vacc3 = vaddq_f32(vin3, vacc3); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[12]))); + vacc0 = vaddq_f32(vin0, vacc0); + vacc1 = vaddq_f32(vin1, vacc1); + vacc2 = vaddq_f32(vin2, vacc2); + vacc3 = vaddq_f32(vin3, vacc3); + vin0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[0]))); + vin1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[4]))); + vin2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[8]))); + vin3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[12]))); + vacc0 = vaddq_f32(vin0, vacc0); + vacc1 = vaddq_f32(vin1, vacc1); + vacc2 = vaddq_f32(vin2, vacc2); + vacc3 = vaddq_f32(vin3, vacc3); + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + vacc0 = vmulq_f32(vacc0, vscale); + vacc1 = vmulq_f32(vacc1, vscale); + vacc2 = vmulq_f32(vacc2, vscale); + vacc3 = vmulq_f32(vacc3, vscale); + + const uint16_t* o = (const uint16_t*) output; + float16x4_t vo0 = vreinterpret_f16_u16(vld1_u16(o)); o += 4; + float16x4_t vo1 = vreinterpret_f16_u16(vld1_u16(o)); o += 4; + float16x4_t vo2 = vreinterpret_f16_u16(vld1_u16(o)); o += 4; + float16x4_t vo3 = vreinterpret_f16_u16(vld1_u16(o)); o += 4; + float16x4_t vfp16_out0 = vadd_f16(vo0, vcvt_f16_f32(vacc0)); + float16x4_t vfp16_out1 = vadd_f16(vo1, vcvt_f16_f32(vacc1)); + float16x4_t vfp16_out2 = vadd_f16(vo2, vcvt_f16_f32(vacc2)); + float16x4_t vfp16_out3 = vadd_f16(vo3, vcvt_f16_f32(vacc3)); + vst1_u16(output, vreinterpret_u16_f16(vfp16_out0)); output = (void*) ((uintptr_t) output + 4 * sizeof(uint16_t)); + vst1_u16(output, vreinterpret_u16_f16(vfp16_out1)); output = (void*) ((uintptr_t) output + 4 * sizeof(uint16_t)); + vst1_u16(output, vreinterpret_u16_f16(vfp16_out2)); output = (void*) ((uintptr_t) output + 4 * sizeof(uint16_t)); + vst1_u16(output, vreinterpret_u16_f16(vfp16_out3)); output = (void*) ((uintptr_t) output + 4 * sizeof(uint16_t)); + + input = (const uint16_t*) ((uintptr_t) input + 16 * sizeof(uint16_t)); + } + if (channels != 0) { + input_increment = 7 * input_stride; + const uint16_t* i0 = input; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input + 1 * input_stride); + const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input + 2 * input_stride); + const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input + 3 * input_stride); + const uint16_t* i4 = (const uint16_t*) ((uintptr_t) input + 4 * input_stride); + const uint16_t* i5 = (const uint16_t*) ((uintptr_t) input + 5 * input_stride); + const uint16_t* i6 = (const uint16_t*) ((uintptr_t) input + 6 * input_stride); + float32x4_t vacc[4]; + vacc[0] = vdupq_n_f32(0.f); + vacc[1] = vdupq_n_f32(0.f); + vacc[2] = vdupq_n_f32(0.f); + 
vacc[3] = vdupq_n_f32(0.f); + + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + const size_t num_full_chunks = channels >> 2; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_chunks; ++i) { + vacc[i] = vaddq_f32(vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i0[i*4]))), vacc[i]); + vacc[i] = vaddq_f32(vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i1[i*4]))), vacc[i]); + vacc[i] = vaddq_f32(vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i2[i*4]))), vacc[i]); + vacc[i] = vaddq_f32(vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i3[i*4]))), vacc[i]); + vacc[i] = vaddq_f32(vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i4[i*4]))), vacc[i]); + vacc[i] = vaddq_f32(vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i5[i*4]))), vacc[i]); + vacc[i] = vaddq_f32(vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(&i6[i*4]))), vacc[i]); + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment); + i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment); + i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment); + i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment); + i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment); + i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment); + } + for (int i = 0; i < (channels + 4) >> 2; ++i) { + vacc[i] = vmulq_f32(vacc[i], vscale); + } + + float16x4_t vo[4]; + const uint16_t* o = (const uint16_t*) output; + for (int i = 0; i < num_full_chunks; ++i) { + vo[i] = vreinterpret_f16_u16(vld1_u16(o)); o += 4; + } + float16x4_t vfp16_out[4]; + for (int i = 0; i < num_full_chunks; ++i) { + vfp16_out[i] = vadd_f16(vo[i], vcvt_f16_f32(vacc[i])); + } + for (int i = 0; i < num_full_chunks; ++i) { + vst1_u16(output, vreinterpret_u16_f16(vfp16_out[i])); output = (void*) ((uintptr_t) output + 4 * sizeof(uint16_t)); + } + + const size_t pos = channels >> 2; + channels &= 0x3; + float16x4_t vacc_low = vcvt_f16_f32(vacc[pos]); + if (channels & 2) { + vst1_lane_u32(output, vreinterpret_u32_f16(vadd_f16(vacc_low, vreinterpret_f16_u32(vld1_dup_u32(output)))), 0); output = (void*) ((uintptr_t) output + 2 * sizeof(uint16_t)); + vacc_low = vext_f16(vacc_low, vacc_low, 2); + } + if (channels & 1) { + vst1_lane_u16(output, vreinterpret_u16_f16(vadd_f16(vacc_low, vreinterpret_f16_u16(vld1_dup_u16(output)))), 0); + } + } +} + void xnn_f16_f32acc_rsum_ukernel__neonfp16arith_u32_acc4( size_t batch, const void* input, diff --git a/src/amalgam/gen/scalar.c b/src/amalgam/gen/scalar.c index 252573b4e7b..4cf264af9ef 100644 --- a/src/amalgam/gen/scalar.c +++ b/src/amalgam/gen/scalar.c @@ -10482,6 +10482,179 @@ void xnn_f32_raddstoreexpminusmax_ukernel__scalar_rr2_p5_u4_acc2( *sum = vacc; } +void xnn_f32_rdsum_ukernel_7p7x__scalar_c4( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const float vscale = params->scalar.scale; + + size_t input_increment = 7 * input_stride; + for (; channels >= 4; channels -= 4) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) i0 + 
input_stride); + const float* i2 = (const float*) ((uintptr_t) i1 + input_stride); + const float* i3 = (const float*) ((uintptr_t) i2 + input_stride); + const float* i4 = (const float*) ((uintptr_t) i3 + input_stride); + const float* i5 = (const float*) ((uintptr_t) i4 + input_stride); + const float* i6 = (const float*) ((uintptr_t) i5 + input_stride); + float vacc0 = 0.f; + float vacc1 = 0.f; + float vacc2 = 0.f; + float vacc3 = 0.f; + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + vacc0 += i0[0]; + vacc1 += i0[1]; + vacc2 += i0[2]; + vacc3 += i0[3]; + vacc0 += i1[0]; + vacc1 += i1[1]; + vacc2 += i1[2]; + vacc3 += i1[3]; + vacc0 += i2[0]; + vacc1 += i2[1]; + vacc2 += i2[2]; + vacc3 += i2[3]; + vacc0 += i3[0]; + vacc1 += i3[1]; + vacc2 += i3[2]; + vacc3 += i3[3]; + vacc0 += i4[0]; + vacc1 += i4[1]; + vacc2 += i4[2]; + vacc3 += i4[3]; + vacc0 += i5[0]; + vacc1 += i5[1]; + vacc2 += i5[2]; + vacc3 += i5[3]; + vacc0 += i6[0]; + vacc1 += i6[1]; + vacc2 += i6[2]; + vacc3 += i6[3]; + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = vacc0 * vscale; + vacc1 = vacc1 * vscale; + vacc2 = vacc2 * vscale; + vacc3 = vacc3 * vscale; + + *output++ += vacc0; + *output++ += vacc1; + *output++ += vacc2; + *output++ += vacc3; + + input = (const float*) ((uintptr_t) input + 4 * sizeof(float)); + } + if (channels != 0) { + size_t input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) i0 + input_stride); + const float* i2 = (const float*) ((uintptr_t) i1 + input_stride); + const float* i3 = (const float*) ((uintptr_t) i2 + input_stride); + const float* i4 = (const float*) ((uintptr_t) i3 + input_stride); + const float* i5 = (const float*) ((uintptr_t) i4 + input_stride); + const float* i6 = (const float*) ((uintptr_t) i5 + input_stride); + float vacc0 = 0.f; + float vacc1 = 0.f; + float vacc2 = 0.f; + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + vacc0 += i0[0]; + vacc1 += i0[1]; + vacc2 += i0[2]; + vacc0 += i1[0]; + vacc1 += i1[1]; + vacc2 += i1[2]; + vacc0 += i2[0]; + vacc1 += i2[1]; + vacc2 += i2[2]; + vacc0 += i3[0]; + vacc1 += i3[1]; + vacc2 += i3[2]; + vacc0 += i4[0]; + vacc1 += i4[1]; + vacc2 += i4[2]; + vacc0 += i5[0]; + vacc1 += i5[1]; + vacc2 += i5[2]; + vacc0 += i6[0]; + vacc1 += i6[1]; + vacc2 += i6[2]; + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) 
((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = vacc0 * vscale; + vacc1 = vacc1 * vscale; + vacc2 = vacc2 * vscale; + + if (channels & 2) { + *output++ += vacc0; + *output++ += vacc1; + vacc0 = vacc2; + } + if (channels & 1) { + *output++ += vacc0; + } + } +} + void xnn_f32_rmax_ukernel__scalar_u4_acc4( size_t batch, const float* input, diff --git a/src/amalgam/gen/sse.c b/src/amalgam/gen/sse.c index 36244e02241..d6da10a9ae3 100644 --- a/src/amalgam/gen/sse.c +++ b/src/amalgam/gen/sse.c @@ -7372,6 +7372,228 @@ void xnn_f32_pavgpool_minmax_ukernel_9x__sse_c4( } while (--output_pixels != 0); } +void xnn_f32_rdsum_ukernel_7p7x__sse_c16( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const __m128 vscale = _mm_load_ps(params->sse.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 16; channels -= 16) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + __m128 vacc0 = _mm_setzero_ps(); + __m128 vacc1 = _mm_setzero_ps(); + __m128 vacc2 = _mm_setzero_ps(); + __m128 vacc3 = _mm_setzero_ps(); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + __m128 vin0; + __m128 vin1; + __m128 vin2; + __m128 vin3; + vin0 = _mm_loadu_ps(&i0[0]); + vin1 = _mm_loadu_ps(&i0[4]); + vin2 = _mm_loadu_ps(&i0[8]); + vin3 = _mm_loadu_ps(&i0[12]); + vacc0 = _mm_add_ps(vin0, vacc0); + vacc1 = _mm_add_ps(vin1, vacc1); + vacc2 = _mm_add_ps(vin2, vacc2); + vacc3 = _mm_add_ps(vin3, vacc3); + vin0 = _mm_loadu_ps(&i1[0]); + vin1 = _mm_loadu_ps(&i1[4]); + vin2 = _mm_loadu_ps(&i1[8]); + vin3 = _mm_loadu_ps(&i1[12]); + vacc0 = _mm_add_ps(vin0, vacc0); + vacc1 = _mm_add_ps(vin1, vacc1); + vacc2 = _mm_add_ps(vin2, vacc2); + vacc3 = _mm_add_ps(vin3, vacc3); + vin0 = _mm_loadu_ps(&i2[0]); + vin1 = _mm_loadu_ps(&i2[4]); + vin2 = _mm_loadu_ps(&i2[8]); + vin3 = _mm_loadu_ps(&i2[12]); + vacc0 = _mm_add_ps(vin0, vacc0); + vacc1 = _mm_add_ps(vin1, vacc1); + vacc2 = _mm_add_ps(vin2, vacc2); + vacc3 = _mm_add_ps(vin3, vacc3); + vin0 = _mm_loadu_ps(&i3[0]); + vin1 = _mm_loadu_ps(&i3[4]); + vin2 = _mm_loadu_ps(&i3[8]); + vin3 = _mm_loadu_ps(&i3[12]); + vacc0 = _mm_add_ps(vin0, vacc0); + vacc1 = _mm_add_ps(vin1, vacc1); + vacc2 = _mm_add_ps(vin2, vacc2); + vacc3 = _mm_add_ps(vin3, vacc3); + vin0 = _mm_loadu_ps(&i4[0]); + vin1 = _mm_loadu_ps(&i4[4]); + vin2 = _mm_loadu_ps(&i4[8]); + vin3 = _mm_loadu_ps(&i4[12]); + vacc0 = _mm_add_ps(vin0, vacc0); + vacc1 = _mm_add_ps(vin1, vacc1); + vacc2 = _mm_add_ps(vin2, vacc2); + vacc3 = _mm_add_ps(vin3, vacc3); + vin0 = _mm_loadu_ps(&i5[0]); + vin1 = _mm_loadu_ps(&i5[4]); + vin2 = 
_mm_loadu_ps(&i5[8]); + vin3 = _mm_loadu_ps(&i5[12]); + vacc0 = _mm_add_ps(vin0, vacc0); + vacc1 = _mm_add_ps(vin1, vacc1); + vacc2 = _mm_add_ps(vin2, vacc2); + vacc3 = _mm_add_ps(vin3, vacc3); + vin0 = _mm_loadu_ps(&i6[0]); + vin1 = _mm_loadu_ps(&i6[4]); + vin2 = _mm_loadu_ps(&i6[8]); + vin3 = _mm_loadu_ps(&i6[12]); + vacc0 = _mm_add_ps(vin0, vacc0); + vacc1 = _mm_add_ps(vin1, vacc1); + vacc2 = _mm_add_ps(vin2, vacc2); + vacc3 = _mm_add_ps(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = _mm_mul_ps(vacc0, vscale); + vacc1 = _mm_mul_ps(vacc1, vscale); + vacc2 = _mm_mul_ps(vacc2, vscale); + vacc3 = _mm_mul_ps(vacc3, vscale); + + const float* o = output; + __m128 vo0 = _mm_loadu_ps(o); o += 4; + __m128 vo1 = _mm_loadu_ps(o); o += 4; + __m128 vo2 = _mm_loadu_ps(o); o += 4; + __m128 vo3 = _mm_loadu_ps(o); o += 4; + vacc0 = _mm_add_ps(vo0, vacc0); + vacc1 = _mm_add_ps(vo1, vacc1); + vacc2 = _mm_add_ps(vo2, vacc2); + vacc3 = _mm_add_ps(vo3, vacc3); + _mm_storeu_ps(output, vacc0); output += 4; + _mm_storeu_ps(output, vacc1); output += 4; + _mm_storeu_ps(output, vacc2); output += 4; + _mm_storeu_ps(output, vacc3); output += 4; + + input = (const float*) ((uintptr_t) input + 16 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + __m128 vacc[4]; + vacc[0] = _mm_setzero_ps(); + vacc[1] = _mm_setzero_ps(); + vacc[2] = _mm_setzero_ps(); + vacc[3] = _mm_setzero_ps(); + + size_t num_chunks = round_up_po2(channels, 4) >> 2; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_chunks; ++i) { + vacc[i] = _mm_add_ps(_mm_loadu_ps(&i0[i*4]), vacc[i]); + vacc[i] = _mm_add_ps(_mm_loadu_ps(&i1[i*4]), vacc[i]); + vacc[i] = _mm_add_ps(_mm_loadu_ps(&i2[i*4]), vacc[i]); + vacc[i] = _mm_add_ps(_mm_loadu_ps(&i3[i*4]), vacc[i]); + vacc[i] = _mm_add_ps(_mm_loadu_ps(&i4[i*4]), vacc[i]); + vacc[i] = _mm_add_ps(_mm_loadu_ps(&i5[i*4]), vacc[i]); + vacc[i] = _mm_add_ps(_mm_loadu_ps(&i6[i*4]), vacc[i]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (int i = 0; i < num_chunks; ++i) 
{ + vacc[i] = _mm_mul_ps(vacc[i], vscale); + } + + __m128 vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = _mm_loadu_ps(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = _mm_add_ps(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + _mm_storeu_ps(output, vacc[i]); output += 4; + } + const size_t pos = channels >> 2; + channels &= 0x3; + __m128 vout = vacc[pos]; + if (channels & 2) { + __m128 vo = _mm_loadl_pi(vscale, (__m64*) output); + _mm_storel_pi((__m64*) output, _mm_add_ps(vo, vout)); + vout = _mm_movehl_ps(vout, vout); + output += 2; + } + if (channels & 1) { + __m128 vo = _mm_load_ss(output); + _mm_store_ss(output, _mm_add_ps(vo, vout)); + } + } +} + void xnn_f32_rmax_ukernel__sse_u16_acc4( size_t batch, const float* input, diff --git a/src/amalgam/gen/wasmsimd.c b/src/amalgam/gen/wasmsimd.c index 12f7b67e0dd..88e6e5d853a 100644 --- a/src/amalgam/gen/wasmsimd.c +++ b/src/amalgam/gen/wasmsimd.c @@ -22387,6 +22387,226 @@ void xnn_f32_raddstoreexpminusmax_ukernel__wasmsimd_rr2_p5_u16_acc2( *sum = vsum; } +void xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16( + size_t rows, + size_t channels, + const float* input, + size_t input_stride, + const float* zero, + float* output, + const union xnn_f32_scale_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(rows != 0); + assert(channels != 0); + assert(input != NULL); + assert(output != NULL); + + const v128_t vscale = wasm_v128_load32_splat(¶ms->scalar.scale); + + size_t input_increment = 7 * input_stride; + for (; channels >= 16; channels -= 16) { + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + + v128_t vacc0 = wasm_i32x4_const_splat(0.f); + v128_t vacc1 = wasm_i32x4_const_splat(0.f); + v128_t vacc2 = wasm_i32x4_const_splat(0.f); + v128_t vacc3 = wasm_i32x4_const_splat(0.f); + + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + v128_t vin0; + v128_t vin1; + v128_t vin2; + v128_t vin3; + vin0 = wasm_v128_load(&i0[0]); + vin1 = wasm_v128_load(&i0[4]); + vin2 = wasm_v128_load(&i0[8]); + vin3 = wasm_v128_load(&i0[12]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vin0 = wasm_v128_load(&i1[0]); + vin1 = wasm_v128_load(&i1[4]); + vin2 = wasm_v128_load(&i1[8]); + vin3 = wasm_v128_load(&i1[12]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vin0 = wasm_v128_load(&i2[0]); + vin1 = wasm_v128_load(&i2[4]); + vin2 = wasm_v128_load(&i2[8]); + vin3 = wasm_v128_load(&i2[12]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vin0 = 
wasm_v128_load(&i3[0]); + vin1 = wasm_v128_load(&i3[4]); + vin2 = wasm_v128_load(&i3[8]); + vin3 = wasm_v128_load(&i3[12]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vin0 = wasm_v128_load(&i4[0]); + vin1 = wasm_v128_load(&i4[4]); + vin2 = wasm_v128_load(&i4[8]); + vin3 = wasm_v128_load(&i4[12]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vin0 = wasm_v128_load(&i5[0]); + vin1 = wasm_v128_load(&i5[4]); + vin2 = wasm_v128_load(&i5[8]); + vin3 = wasm_v128_load(&i5[12]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + vin0 = wasm_v128_load(&i6[0]); + vin1 = wasm_v128_load(&i6[4]); + vin2 = wasm_v128_load(&i6[8]); + vin3 = wasm_v128_load(&i6[12]); + vacc0 = wasm_f32x4_add(vin0, vacc0); + vacc1 = wasm_f32x4_add(vin1, vacc1); + vacc2 = wasm_f32x4_add(vin2, vacc2); + vacc3 = wasm_f32x4_add(vin3, vacc3); + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + vacc0 = wasm_f32x4_mul(vacc0, vscale); + vacc1 = wasm_f32x4_mul(vacc1, vscale); + vacc2 = wasm_f32x4_mul(vacc2, vscale); + vacc3 = wasm_f32x4_mul(vacc3, vscale); + + const float* o = output; + v128_t vo0 = wasm_v128_load(o); o += 4; + v128_t vo1 = wasm_v128_load(o); o += 4; + v128_t vo2 = wasm_v128_load(o); o += 4; + v128_t vo3 = wasm_v128_load(o); o += 4; + vacc0 = wasm_f32x4_add(vo0, vacc0); + vacc1 = wasm_f32x4_add(vo1, vacc1); + vacc2 = wasm_f32x4_add(vo2, vacc2); + vacc3 = wasm_f32x4_add(vo3, vacc3); + wasm_v128_store(output, vacc0); output += 4; + wasm_v128_store(output, vacc1); output += 4; + wasm_v128_store(output, vacc2); output += 4; + wasm_v128_store(output, vacc3); output += 4; + + input = (const float*) ((uintptr_t) input + 16 * sizeof(float)); + } + if (channels != 0) { + input_increment = 7 * input_stride; + const float* i0 = input; + const float* i1 = (const float*) ((uintptr_t) input + 1 * input_stride); + const float* i2 = (const float*) ((uintptr_t) input + 2 * input_stride); + const float* i3 = (const float*) ((uintptr_t) input + 3 * input_stride); + const float* i4 = (const float*) ((uintptr_t) input + 4 * input_stride); + const float* i5 = (const float*) ((uintptr_t) input + 5 * input_stride); + const float* i6 = (const float*) ((uintptr_t) input + 6 * input_stride); + v128_t vacc[4]; + vacc[0] = wasm_i32x4_const_splat(0.f); + vacc[1] = wasm_i32x4_const_splat(0.f); + vacc[2] = wasm_i32x4_const_splat(0.f); + vacc[3] = wasm_i32x4_const_splat(0.f); + + const size_t num_chunks = round_up_po2(channels, 4) >> 2; + for (int r = rows; r > 0; r -= 7) { + if XNN_UNPREDICTABLE(r < 2) { + i1 = zero; + } + if XNN_UNPREDICTABLE(r <= 2) { + i2 = zero; + } + if XNN_UNPREDICTABLE(r < 4) { + i3 = zero; + } + if XNN_UNPREDICTABLE(r <= 4) { + i4 = zero; + } + if XNN_UNPREDICTABLE(r < 6) { + i5 = zero; + } + if XNN_UNPREDICTABLE(r <= 6) { + i6 = zero; + } + for (int i = 0; i < num_chunks; ++i) { + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i0[i*4]), vacc[i]); + 
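        // i1..i6 contribute the remaining rows of this 7-row block; row pointers
        // that would run past the end of the batch were redirected to `zero`
        // above, so they add nothing to the accumulators.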
vacc[i] = wasm_f32x4_add(wasm_v128_load(&i1[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i2[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i3[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i4[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i5[i*4]), vacc[i]); + vacc[i] = wasm_f32x4_add(wasm_v128_load(&i6[i*4]), vacc[i]); + } + i0 = (const float*) ((uintptr_t) i0 + input_increment); + i1 = (const float*) ((uintptr_t) i1 + input_increment); + i2 = (const float*) ((uintptr_t) i2 + input_increment); + i3 = (const float*) ((uintptr_t) i3 + input_increment); + i4 = (const float*) ((uintptr_t) i4 + input_increment); + i5 = (const float*) ((uintptr_t) i5 + input_increment); + i6 = (const float*) ((uintptr_t) i6 + input_increment); + } + for (int i = 0; i < num_chunks; ++i) { + vacc[i] = wasm_f32x4_mul(vacc[i], vscale); + } + + v128_t vo[4]; + const float* o = output; + for (int i = 0; i < channels >> 2; ++i) { + vo[i] = wasm_v128_load(o); o += 4; + } + for (int i = 0; i < channels >> 2; ++i) { + vacc[i] = wasm_f32x4_add(vo[i], vacc[i]); + } + for (int i = 0; i < channels >> 2; ++i) { + wasm_v128_store(output, vacc[i]); output += 4; + } + const size_t pos = channels / 4; + v128_t vout = vacc[pos]; + if (channels & 2) { + v128_t vo = wasm_f32x4_make(output[0], output[1], 0.f, 0.f); + wasm_v128_store64_lane(output, wasm_f32x4_add(vo, vout), 0); + vout = wasm_v64x2_shuffle(vout, vout, 1, 1); + output += 2; + } + if (channels & 1) { + *output += wasm_f32x4_extract_lane(vout, 0); + } + } +} + void xnn_f32_rmax_ukernel__wasmsimd_pminmax_u16_acc4( size_t batch, const float* input, diff --git a/src/configs/reduce-config.c b/src/configs/reduce-config.c index c016d755961..635352dc768 100644 --- a/src/configs/reduce-config.c +++ b/src/configs/reduce-config.c @@ -19,20 +19,26 @@ #include static struct xnn_reduce_config f16_f32acc_rsum_config = {0}; +static struct xnn_reduce_config f16_f32acc_rdsum_config = {0}; static struct xnn_reduce_config f16_rminmax_config = {0}; static struct xnn_reduce_config f32_rminmax_config = {0}; static struct xnn_reduce_config f32_rsum_config = {0}; +static struct xnn_reduce_config f32_rdsum_config = {0}; #if XNN_PLATFORM_WINDOWS static INIT_ONCE init_guard_f16_f32acc_rsum = INIT_ONCE_STATIC_INIT; + static INIT_ONCE init_guard_f16_f32acc_rdsum = INIT_ONCE_STATIC_INIT; static INIT_ONCE init_guard_f16_rminmax = INIT_ONCE_STATIC_INIT; static INIT_ONCE init_guard_f32_rminmax = INIT_ONCE_STATIC_INIT; static INIT_ONCE init_guard_f32_rsum = INIT_ONCE_STATIC_INIT; + static INIT_ONCE init_guard_f32_rdsum = INIT_ONCE_STATIC_INIT; #else static pthread_once_t init_guard_f16_f32acc_rsum = PTHREAD_ONCE_INIT; + static pthread_once_t init_guard_f16_f32acc_rdsum = PTHREAD_ONCE_INIT; static pthread_once_t init_guard_f16_rminmax = PTHREAD_ONCE_INIT; static pthread_once_t init_guard_f32_rminmax = PTHREAD_ONCE_INIT; static pthread_once_t init_guard_f32_rsum = PTHREAD_ONCE_INIT; + static pthread_once_t init_guard_f32_rdsum = PTHREAD_ONCE_INIT; #endif static void init_f16_f32acc_rsum_config(void) { @@ -217,12 +223,107 @@ static void init_f32_rsum_config(void) { #endif } +static void init_f16_f32acc_rdsum_config(void) { + #if (XNN_ARCH_ARM || XNN_ARCH_ARM64) && XNN_ENABLE_ARM_FP16_VECTOR + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + if (hardware_config->use_arm_neon_fp16_arith) { + f16_f32acc_rdsum_config = (struct xnn_reduce_config) { + .rd_ukernel = (xnn_rdsum_ukernel_fn) 
xnn_f16_f32acc_rdsum_ukernel_7p7x__neonfp16arith_c16, + .init.f16_f32acc_scale = xnn_init_f16_f32acc_scale_scalar_params, + .element_tile = 16, + }; + } + #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512skx) { + f16_f32acc_rdsum_config = (struct xnn_reduce_config) { + .rd_ukernel = (xnn_rdsum_ukernel_fn) xnn_f16_f32acc_rdsum_ukernel_7p7x__avx512skx_c64, + .init.f16_f32acc_scale = xnn_init_f16_f32acc_scale_scalar_params, + .element_tile = 64, + }; + } else if (hardware_config->use_x86_f16c) { + f16_f32acc_rdsum_config = (struct xnn_reduce_config) { + .rd_ukernel = (xnn_rdsum_ukernel_fn) xnn_f16_f32acc_rdsum_ukernel_7p7x__f16c_c32, + .init.f16_f32acc_scale = xnn_init_f16_f32acc_scale_avx_params, + .element_tile = 32, + }; + } + #endif +} + +static void init_f32_rdsum_config(void) { + #if XNN_ARCH_ARM + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + if (hardware_config->use_arm_neon) { + f32_rdsum_config = (struct xnn_reduce_config) { + .rd_ukernel = (xnn_rdsum_ukernel_fn) xnn_f32_rdsum_ukernel_7p7x__neon_c16, + .init.f32_scale = xnn_init_f32_scale_scalar_params, + .element_tile = 16, + }; + } else { + f32_rdsum_config = (struct xnn_reduce_config) { + .rd_ukernel = (xnn_rdsum_ukernel_fn) xnn_f32_rdsum_ukernel_7p7x__scalar_c4, + .init.f32_scale = xnn_init_f32_scale_scalar_params, + .element_tile = 4, + }; + } + #elif XNN_ARCH_ARM64 + f32_rdsum_config = (struct xnn_reduce_config) { + .rd_ukernel = (xnn_rdsum_ukernel_fn) xnn_f32_rdsum_ukernel_7p7x__neon_c16, + .init.f32_scale = xnn_init_f32_scale_scalar_params, + .element_tile = 16, + }; + #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512f) { + f32_rdsum_config = (struct xnn_reduce_config) { + .rd_ukernel = (xnn_rdsum_ukernel_fn) xnn_f32_rdsum_ukernel_7p7x__avx512f_c64, + .init.f32_scale = xnn_init_f32_scale_scalar_params, + .element_tile = 64, + }; + } else if (hardware_config->use_x86_avx) { + f32_rdsum_config = (struct xnn_reduce_config) { + .rd_ukernel = (xnn_rdsum_ukernel_fn) xnn_f32_rdsum_ukernel_7p7x__avx_c32, + .init.f32_scale = xnn_init_f32_scale_avx_params, + .element_tile = 32, + }; + } else { + f32_rdsum_config = (struct xnn_reduce_config) { + .rd_ukernel = (xnn_rdsum_ukernel_fn) xnn_f32_rdsum_ukernel_7p7x__sse_c16, + .init.f32_scale = xnn_init_f32_scale_sse_params, + .element_tile = 16, + }; + } + #elif XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD + f32_rdsum_config = (struct xnn_reduce_config) { + .rd_ukernel = (xnn_rdsum_ukernel_fn) xnn_f32_rdsum_ukernel_7p7x__wasmsimd_c16, + .init.f32_scale = xnn_init_f32_scale_scalar_params, + .element_tile = 16, + }; + #else + f32_rdsum_config = (struct xnn_reduce_config) { + .rd_ukernel = (xnn_rdsum_ukernel_fn) xnn_f32_rdsum_ukernel_7p7x__scalar_c4, + .init.f32_scale = xnn_init_f32_scale_scalar_params, + .element_tile = 4, + }; + #endif +} + #if XNN_PLATFORM_WINDOWS static BOOL CALLBACK init_f16_f32acc_rsum_config_windows(PINIT_ONCE init_once, PVOID parameter, PVOID* context) { init_f16_f32acc_rsum_config(); return TRUE; } + static BOOL CALLBACK init_f16_f32acc_rdsum_config_windows(PINIT_ONCE init_once, PVOID parameter, PVOID* context) { + init_f16_f32acc_rdsum_config(); + return TRUE; + } + 
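// Usage sketch of the new rdsum configs (illustrative only; `rows`, `channels`,
// `input`, `zero_buf` and `out` are assumed names, not part of this file).
// For an f32 matrix of `rows` x `channels` stored row-major, the dispatched
// kernel accumulates the column-wise mean into a zero-initialized `out`:
//
//   const struct xnn_reduce_config* rdsum = xnn_init_f32_rdsum_config();
//   union xnn_f32_scale_params params;
//   rdsum->init.f32_scale(&params, 1.0f / (float) rows);
//   // `zero_buf` must hold at least `channels` zeroed floats, and input rows are
//   // assumed to be padded so reads may round up to the kernel's vector width.
//   rdsum->rd_ukernel(rows, channels, input, channels * sizeof(float),
//                     zero_buf, out, &params);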
static BOOL CALLBACK init_f16_rminmax_config_windows(PINIT_ONCE init_once, PVOID parameter, PVOID* context) { init_f16_rminmax_config(); return TRUE; @@ -236,6 +337,11 @@ static void init_f32_rsum_config(void) { init_f32_rsum_config(); return TRUE; } + + static BOOL CALLBACK init_f32_rdsum_config_windows(PINIT_ONCE init_once, PVOID parameter, PVOID* context) { + init_f32_rdsum_config(); + return TRUE; + } #endif const struct xnn_reduce_config* xnn_init_f16_f32acc_rsum_config() { @@ -289,3 +395,29 @@ const struct xnn_reduce_config* xnn_init_f32_rsum_config() { #endif return &f32_rsum_config; } + +const struct xnn_reduce_config* xnn_init_f16_f32acc_rdsum_config() { + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + if (hardware_config == NULL || !xnn_is_f16_compatible_config(hardware_config)) { + return NULL; + } + #if XNN_PLATFORM_WINDOWS + InitOnceExecuteOnce(&init_guard_f16_f32acc_rdsum, &init_f16_f32acc_rdsum_config_windows, NULL, NULL); + #else + pthread_once(&init_guard_f16_f32acc_rdsum, &init_f16_f32acc_rdsum_config); + #endif + return &f16_f32acc_rdsum_config; +} + +const struct xnn_reduce_config* xnn_init_f32_rdsum_config() { + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + if (hardware_config == NULL) { + return NULL; + } + #if XNN_PLATFORM_WINDOWS + InitOnceExecuteOnce(&init_guard_f32_rdsum, &init_f32_rdsum_config_windows, NULL, NULL); + #else + pthread_once(&init_guard_f32_rdsum, &init_f32_rdsum_config); + #endif + return &f32_rdsum_config; +} diff --git a/src/f16-f32acc-rdsum/avx512skx.c.in b/src/f16-f32acc-rdsum/avx512skx.c.in index 4c3125b5e76..39691d1e06b 100644 --- a/src/f16-f32acc-rdsum/avx512skx.c.in +++ b/src/f16-f32acc-rdsum/avx512skx.c.in @@ -33,19 +33,19 @@ void xnn_f16_f32acc_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__avx512skx_c$ size_t input_increment = ${ACCUMULATORS} * input_stride; for (; channels >= ${CHANNELS_BATCH}; channels -= ${CHANNELS_BATCH}) { const uint16_t* i0 = input; - $for i in range(1, ACCUMULATORS): - const uint16_t* i${i} = (const uint16_t*) ((uintptr_t) input + ${i} * input_stride); + $for ACC in range(1, ACCUMULATORS): + const uint16_t* i${ACC} = (const uint16_t*) ((uintptr_t) input + ${ACC} * input_stride); $for i in range(UNROLL): __m512 vacc${i} = _mm512_setzero_ps(); for (int r = rows; r > 0; r -= ${ACCUMULATORS}) { - $for N in range(1, ACCUMULATORS, 2): - if XNN_UNPREDICTABLE(r < ${N+1}) { - i${N} = zero; + $for ACC in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${ACC+1}) { + i${ACC} = zero; } - if XNN_UNPREDICTABLE(r <= ${N+1}) { - i${N+1} = zero; + if XNN_UNPREDICTABLE(r <= ${ACC+1}) { + i${ACC+1} = zero; } $for c in range(UNROLL): __m512 vin${c}; @@ -54,8 +54,8 @@ void xnn_f16_f32acc_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__avx512skx_c$ vin${c} = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) (&i${j}[${c*16}]))); $for c in range(UNROLL): vacc${c} = _mm512_add_ps(vin${c}, vacc${c}); - $for N in range(0, ACCUMULATORS): - i${N} = (const uint16_t*) ((uintptr_t) i${N} + input_increment); + $for ACC in range(0, ACCUMULATORS): + i${ACC} = (const uint16_t*) ((uintptr_t) i${ACC} + input_increment); } $for i in range(UNROLL): vacc${i} = _mm512_mul_ps(vacc${i}, vscale); @@ -90,12 +90,12 @@ void xnn_f16_f32acc_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__avx512skx_c$ vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << batch) - UINT32_C(1))); } for (int r = rows; r > 0; r -= ${ACCUMULATORS}) { - $for N in range(1, ACCUMULATORS, 2): - if 
XNN_UNPREDICTABLE(r < ${N+1}) { - i${N} = zero; + $for ACC in range(1, ACCUMULATORS, 2): + if XNN_UNPREDICTABLE(r < ${ACC+1}) { + i${ACC} = zero; } - if XNN_UNPREDICTABLE(r <= ${N+1}) { - i${N+1} = zero; + if XNN_UNPREDICTABLE(r <= ${ACC+1}) { + i${ACC+1} = zero; } for (int i = 0; i < num_full_chunks; ++i) { $for c in range(ACCUMULATORS): @@ -106,8 +106,8 @@ void xnn_f16_f32acc_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__avx512skx_c$ $for c in range(ACCUMULATORS): vacc[num_full_chunks] = _mm512_maskz_add_ps(vmask, vacc[num_full_chunks], _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, &i${c}[num_full_chunks*16]))); } - $for N in range(ACCUMULATORS): - i${N} = (const uint16_t*) ((uintptr_t) i${N} + input_increment); + $for ACC in range(ACCUMULATORS): + i${ACC} = (const uint16_t*) ((uintptr_t) i${ACC} + input_increment); } for (size_t i = 0; i < num_chunks; ++i) { vacc[i] = _mm512_mul_ps(vacc[i], vscale); diff --git a/src/operator-run.c b/src/operator-run.c index 65d01cb653c..681818b451a 100644 --- a/src/operator-run.c +++ b/src/operator-run.c @@ -2115,7 +2115,7 @@ void xnn_compute_contiguous_reduce( const size_t* input_stride = context->input_stride; const size_t* output_stride = context->output_stride; - // input dimensions 1, 3 & 5 are reduced so the entireity of these dimensions + // input dimensions 1, 3 & 5 are reduced so the entirety of these dimensions // are processed so their indices are always 0. size_t input_offset = input_stride[0] * output_idx0 + input_stride[2] * output_idx1 + input_stride[4] * output_idx2; size_t output_offset = output_stride[0] * output_idx0 + output_stride[1] * output_idx1 + output_stride[2] * output_idx2; @@ -2135,7 +2135,7 @@ void xnn_compute_contiguous_reduce( // output2_block_size output elements are written. for (size_t k = 0; k < output2_block_size; ++k) { // The microkernel reduces input dimension 5. - context->ukernel(context->scaled_elements, input_row, output, &context->params); + context->ukernel.rsum(context->scaled_elements, input_row, output, &context->params); // input_stride[4] is the number of bytes of input which have been // processed by the microkernel call. input_row = (const void*) ((uintptr_t) input_row + input_stride[4]); @@ -2153,6 +2153,50 @@ void xnn_compute_contiguous_reduce( } } +void xnn_compute_discontiguous_reduce( + const struct reduce_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t output_idx0, + size_t output_idx1, + size_t output_idx2, + size_t output1_block_size, + size_t output2_block_size) +{ + assert(output1_block_size == 1); + const size_t* input_stride = context->input_stride; + const size_t* output_stride = context->output_stride; + + // input dimensions 0, 2 & 4 are reduced so the entirety of these dimensions + // are processed so their indices are always 0. + size_t input_offset = input_stride[1] * output_idx0 + input_stride[3] * output_idx1 + input_stride[5] * output_idx2; + size_t output_offset = output_stride[0] * output_idx0 + output_stride[1] * output_idx1 + output_stride[2] * output_idx2; + int input_shape0 = context->input_shape[0]; + int input_shape2 = context->input_shape[2]; + + void* output = (void*) ((uintptr_t) context->output + output_offset); + // RDsum microkernels accumulate into the output buffer. + memset(output, 0, context->element_size * output2_block_size); + + // Input dimension 0 is reduced. + for (size_t i = 0; i < input_shape0; ++i) { + const void* input = (const void*) ((uintptr_t) context->input + input_offset); + // Input dimension 2 is reduced. 
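    // Each (i, j) iteration hands the kernel one slice: `scaled_elements` rows of
    // dimension 4 are summed and `output2_block_size` channels of dimension 5 are
    // written. The kernel adds into `output`, which was zeroed above, so the
    // contributions from dimensions 0 and 2 accumulate across iterations.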
+ for (size_t j = 0; j < input_shape2; ++j) { + const void* input_row = input; + // The microkernel reduces input dimension 4 and iterates over output_block_size elements of dimension 5. + context->ukernel.rdsum(context->scaled_elements, output2_block_size, input_row, input_stride[4], context->zero, output, &context->params); + // input_stride[4] is the number of bytes of input which have been + // processed by the microkernel call. + input_row = (const void*) ((uintptr_t) input_row + input_stride[4]); + // Reset the output pointer. + output = (void*) ((uintptr_t) context->output + output_offset); + // Iterating over input_shape[2]. + input = (const void*) ((uintptr_t) input + input_stride[2]); + } + // Iterating over input_shape[0]. + input_offset += input_stride[0]; + } +} + void xnn_compute_pad_qd8_params( const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], size_t batch_index) diff --git a/src/operators/reduce-nd.c b/src/operators/reduce-nd.c index b7d3cd3acfa..d336bfda61c 100644 --- a/src/operators/reduce-nd.c +++ b/src/operators/reduce-nd.c @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -31,8 +32,8 @@ static enum xnn_status create_mean_nd( uint32_t flags, uint32_t log2_element_size, enum xnn_operator_type operator_type, - const struct xnn_gavgpool_config* gavgpool_config, - const struct xnn_reduce_config* reduce_config, + const struct xnn_reduce_config* rdsum_config, + const struct xnn_reduce_config* rsum_config, const void* params, size_t params_size, xnn_operator_t* mean_op_out) @@ -58,8 +59,8 @@ static enum xnn_status create_mean_nd( mean_op->type = operator_type; mean_op->flags = flags; - mean_op->gavgpool_config = gavgpool_config; - mean_op->reduce_config = reduce_config; + mean_op->rdsum_config = rdsum_config; + mean_op->rsum_config = rsum_config; if (params_size != 0) { memcpy(&mean_op->params, params, params_size); } @@ -78,29 +79,20 @@ enum xnn_status xnn_create_mean_nd_f16( uint32_t flags, xnn_operator_t* mean_op_out) { - const struct xnn_gavgpool_config* gavgpool_config = xnn_init_f16_gavgpool_config(); const struct xnn_reduce_config* rsum_config = xnn_init_f16_f32acc_rsum_config(); - if (gavgpool_config == NULL || rsum_config == NULL) { + const struct xnn_reduce_config* rdsum_config = xnn_init_f16_f32acc_rdsum_config(); + if (rdsum_config == NULL || rsum_config == NULL) { xnn_log_error("failed to create %s operator: unsupported hardware configuration", xnn_operator_type_to_string(xnn_operator_type_mean_nd_f16)); return xnn_status_unsupported_hardware; } - - struct { - union xnn_f16_f32acc_scale_params scale; - union xnn_f16_scaleminmax_params scaleminmax; - } params; - gavgpool_config->init.f16(¶ms.scaleminmax, - /*scale=*/UINT16_C(0x3C00) /* 1.0h */, - /*output_min=*/UINT16_C(0xFC00) /* -inf */, - /*output_max=*/UINT16_C(0x7C00) /* +inf */); - rsum_config->init.f16_f32acc_scale(¶ms.scale, - /*scale=*/1.0f); + union xnn_f16_f32acc_scale_params params; + rsum_config->init.f16_f32acc_scale(¶ms, /*scale=*/1.0f); return create_mean_nd( flags, /*log2_element_size=*/XNN_LOG2_SIZEOF_HALF, xnn_operator_type_mean_nd_f16, - gavgpool_config, rsum_config, + rdsum_config, rsum_config, ¶ms, sizeof(params), mean_op_out); } @@ -109,27 +101,21 @@ enum xnn_status xnn_create_mean_nd_f32( uint32_t flags, xnn_operator_t* mean_op_out) { - const struct xnn_gavgpool_config* gavgpool_config = xnn_init_f32_gavgpool_config(); const struct xnn_reduce_config* rsum_config = xnn_init_f32_rsum_config(); - if (gavgpool_config == NULL || rsum_config == 
NULL) { + const struct xnn_reduce_config* rdsum_config = xnn_init_f32_rdsum_config(); + if (rdsum_config == NULL || rsum_config == NULL) { xnn_log_error("failed to create %s operator: unsupported hardware configuration", xnn_operator_type_to_string(xnn_operator_type_mean_nd_f32)); return xnn_status_unsupported_hardware; } - struct { - union xnn_f32_scale_params scale; - union xnn_f32_scaleminmax_params scaleminmax; - } params; - gavgpool_config->init.f32(¶ms.scaleminmax, - /*scale=*/1.0f, /*output_min=*/-INFINITY, /*output_max=*/INFINITY); - rsum_config->init.f32_scale(¶ms.scale, - /*scale=*/1.0f); + union xnn_f32_scale_params params; + rsum_config->init.f32_scale(¶ms, /*scale=*/1.0f); return create_mean_nd( flags, /*log2_element_size=*/XNN_LOG2_SIZEOF_FLOAT, xnn_operator_type_mean_nd_f32, - gavgpool_config, rsum_config, + rdsum_config, rsum_config, ¶ms, sizeof(params), mean_op_out); } @@ -140,13 +126,9 @@ static enum xnn_status reshape_mean_nd( const size_t* reduction_axes, size_t num_input_dims, const size_t* input_shape, - size_t* workspace_size, - size_t* workspace_alignment, size_t log2_data_element_size, size_t log2_accumulator_element_size, enum xnn_operator_type expected_operator_type, - const void* scaleminmax_params, - size_t scaleminmax_params_size, const void* scale_params, size_t scale_params_size, void (*update_params)(xnn_operator_t, size_t), @@ -230,16 +212,14 @@ static enum xnn_status reshape_mean_nd( return xnn_status_success; } - *workspace_size = 0; - *workspace_alignment = 1; - + memmove(&normalized_input_shape[XNN_MAX_TENSOR_DIMS - num_input_dims], &normalized_input_shape[0], sizeof(size_t) * num_input_dims); + for (int i = 0; i < XNN_MAX_TENSOR_DIMS - num_input_dims; ++i) { + normalized_input_shape[i] = 1; + } + mean_op->compute[0].type = xnn_parallelization_type_3d_tile_2d; + mean_op->ukernel.type = xnn_microkernel_type_mean; // Reduction along the innermost dimension. 
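  // After the memmove above the normalized shape is right-aligned into a 6-D
  // layout padded with leading 1s, so the two compute paths see a fixed pattern:
  // the contiguous path reduces dimensions 1, 3 and 5, while the discontiguous
  // path reduces dimensions 0, 2 and 4. For example (assuming normalization
  // leaves the shape unchanged), a 3-D input reduced over axes {0, 2} has its
  // innermost dimension reduced and takes the rsum path below, whereas the same
  // input reduced over axis {1} takes the rdsum path in the else branch.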
if (normalized_reduction_axes[num_reduction_axes - 1] == num_input_dims - 1) { - memmove(&normalized_input_shape[XNN_MAX_TENSOR_DIMS - num_input_dims], &normalized_input_shape[0], sizeof(size_t) * num_input_dims); - for (int i = 0; i < XNN_MAX_TENSOR_DIMS - num_input_dims; ++i) { - normalized_input_shape[i] = 1; - } - const size_t scale_dim = normalized_input_shape[1] * normalized_input_shape[3] * normalized_input_shape[5]; const size_t axis_dim = normalized_input_shape[5]; @@ -249,43 +229,30 @@ static enum xnn_status reshape_mean_nd( mean_op->context.reduce = (struct reduce_context) { .scaled_elements = axis_dim << log2_data_element_size, - .ukernel = mean_op->reduce_config->ukernel, + .ukernel.rsum = mean_op->rsum_config->ukernel, .element_size = UINT32_C(1) << log2_data_element_size, }; memcpy(&mean_op->context.reduce.params, scale_params, scale_params_size); - mean_op->context.reduce.input_stride[XNN_MAX_TENSOR_DIMS - 1] = (1 << log2_data_element_size); - for (int i = XNN_MAX_TENSOR_DIMS - 2; i >= 0; --i) { - mean_op->context.reduce.input_stride[i] = (mean_op->context.reduce.input_stride[i + 1] * normalized_input_shape[i + 1]); - } - memcpy(mean_op->context.reduce.input_shape, normalized_input_shape, XNN_MAX_TENSOR_DIMS * sizeof(size_t)); - mean_op->context.reduce.output_stride[XNN_MAX_TENSOR_DIMS / 2 - 1] = (1 << log2_data_element_size); - for (int i = XNN_MAX_TENSOR_DIMS / 2 - 2; i >= 0; --i) { - mean_op->context.reduce.output_stride[i] = (mean_op->context.reduce.output_stride[i + 1] * normalized_input_shape[(i + 1) * 2]); - } mean_op->compute[0].task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_contiguous_reduce; - mean_op->compute[0].type = xnn_parallelization_type_3d_tile_2d; mean_op->compute[0].range[0] = normalized_input_shape[0]; mean_op->compute[0].range[1] = normalized_input_shape[2]; mean_op->compute[0].range[2] = normalized_input_shape[4]; mean_op->compute[0].tile[0] = 1; mean_op->compute[0].tile[1] = 2; - mean_op->ukernel.type = xnn_microkernel_type_mean; + mean_op->context.reduce.output_stride[XNN_MAX_TENSOR_DIMS / 2 - 1] = (1 << log2_data_element_size); + for (int i = XNN_MAX_TENSOR_DIMS / 2 - 2; i >= 0; --i) { + mean_op->context.reduce.output_stride[i] = (mean_op->context.reduce.output_stride[i + 1] * normalized_input_shape[(i + 1) * 2]); + } } else { // Reduction along the non-innermost dimension - if (num_reduction_axes != 1) { - xnn_log_error( - "failed to reshape %s operator with %zu normalized reduction axes: only a single post-normalization reduction axis is supported", - xnn_operator_type_to_string(mean_op->type), num_reduction_axes); - return xnn_status_invalid_parameter; - } - const size_t reduction_axis = normalized_reduction_axes[0]; - const size_t axis_dim = normalized_input_shape[reduction_axis]; - const size_t batch_like_dim = reduction_axis == 0 ? 
1 : normalized_input_shape[0]; - const size_t channel_like_dim = normalized_input_shape[num_input_dims - 1]; + const size_t channel_like_dim = normalized_input_shape[XNN_MAX_TENSOR_DIMS - 1]; + + const size_t scale_dim = normalized_input_shape[0] * normalized_input_shape[2] * normalized_input_shape[4]; + const size_t axis_dim = normalized_input_shape[4]; if (update_params != NULL) { - update_params(mean_op, axis_dim); + update_params(mean_op, scale_dim); } if (mean_op->channels != channel_like_dim) { const size_t zero_size = (channel_like_dim << log2_data_element_size) + XNN_EXTRA_BYTES; @@ -301,36 +268,33 @@ static enum xnn_status reshape_mean_nd( mean_op->channels = channel_like_dim; } - mean_op->context.global_average_pooling_nwc = (struct global_average_pooling_nwc_context) { - .zero = mean_op->zero_buffer, - .input_pixel_stride = channel_like_dim << log2_data_element_size, - .input_batch_stride = (channel_like_dim * axis_dim) << log2_data_element_size, - .input_elements = axis_dim, - .channels = channel_like_dim, - .output_batch_stride = channel_like_dim << log2_data_element_size, + mean_op->context.reduce = (struct reduce_context) { + .zero = mean_op->zero_buffer, + .input_pixel_stride = channel_like_dim << log2_data_element_size, + .input_batch_stride = (channel_like_dim * axis_dim) << log2_data_element_size, + .scaled_elements = axis_dim, + .channels = channel_like_dim, + .output_batch_stride = channel_like_dim << log2_data_element_size, + .ukernel.rdsum = mean_op->rdsum_config->rd_ukernel, + .element_size = UINT32_C(1) << log2_data_element_size, }; - memcpy(&mean_op->context.global_average_pooling_nwc.params, scaleminmax_params, scaleminmax_params_size); - - mean_op->compute[0].range[0] = batch_like_dim; - if (axis_dim <= mean_op->gavgpool_config->row_tile) { - mean_op->compute[0].type = xnn_parallelization_type_1d; - mean_op->compute[0].task_1d = (pthreadpool_task_1d_t) xnn_compute_global_average_pooling_nwc_unipass; - mean_op->context.global_average_pooling_nwc.unipass_ukernel = mean_op->gavgpool_config->unipass; - } else { - const size_t multipass_batch_stride = round_up_po2( - (channel_like_dim + (XNN_MAX_SIMD_SIZE >> log2_data_element_size)) << log2_accumulator_element_size, - XNN_ALLOCATION_ALIGNMENT); - const size_t num_threads = pthreadpool_get_threads_count(threadpool); - *workspace_size = num_threads * multipass_batch_stride; - *workspace_alignment = XNN_ALLOCATION_ALIGNMENT; - mean_op->context.global_average_pooling_nwc.multipass_batch_stride = multipass_batch_stride; - mean_op->compute[0].type = xnn_parallelization_type_1d_with_thread; - mean_op->compute[0].task_1d_with_thread = - (pthreadpool_task_1d_with_thread_t) xnn_compute_global_average_pooling_nwc_multipass_with_thread; - mean_op->context.global_average_pooling_nwc.multipass_ukernel = mean_op->gavgpool_config->multipass; + memcpy(&mean_op->context.reduce.params, scale_params, scale_params_size); + mean_op->compute[0].task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_discontiguous_reduce; + mean_op->compute[0].range[0] = normalized_input_shape[1]; + mean_op->compute[0].range[1] = normalized_input_shape[3]; + mean_op->compute[0].range[2] = normalized_input_shape[5]; + mean_op->compute[0].tile[0] = 1; + mean_op->compute[0].tile[1] = normalized_input_shape[5]; + mean_op->context.reduce.output_stride[XNN_MAX_TENSOR_DIMS / 2 - 1] = (1 << log2_data_element_size); + for (int i = XNN_MAX_TENSOR_DIMS / 2 - 2; i >= 0; --i) { + mean_op->context.reduce.output_stride[i] = (mean_op->context.reduce.output_stride[i + 
1] * normalized_input_shape[(i * 2+3)]); } - mean_op->ukernel.type = xnn_microkernel_type_global_average_pooling; } + mean_op->context.reduce.input_stride[XNN_MAX_TENSOR_DIMS - 1] = (1 << log2_data_element_size); + for (int i = XNN_MAX_TENSOR_DIMS - 2; i >= 0; --i) { + mean_op->context.reduce.input_stride[i] = (mean_op->context.reduce.input_stride[i + 1] * normalized_input_shape[i + 1]); + } + memcpy(mean_op->context.reduce.input_shape, normalized_input_shape, XNN_MAX_TENSOR_DIMS * sizeof(size_t)); mean_op->state = xnn_run_state_needs_setup; return xnn_status_success; @@ -341,8 +305,8 @@ static void update_params_mean_f16( size_t num_elements) { const float scale = 1.0f / (float) (double) num_elements; - mean_op->reduce_config->init.f16_f32acc_scale(&mean_op->params.f16_f32acc_scale, scale); - mean_op->gavgpool_config->update.f16(&mean_op->params.f16_scale_minmax, fp16_ieee_from_fp32_value(scale)); + mean_op->rsum_config->init.f16_f32acc_scale(&mean_op->params.f16_f32acc_scale, scale); + mean_op->rdsum_config->init.f16_f32acc_scale(&mean_op->params.f16_f32acc_scale, scale); } enum xnn_status xnn_reshape_mean_nd_f16( @@ -351,20 +315,15 @@ enum xnn_status xnn_reshape_mean_nd_f16( const size_t* reduction_axes, size_t num_input_dims, const size_t* input_shape, - size_t* workspace_size, - size_t* workspace_alignment, pthreadpool_t threadpool) { return reshape_mean_nd( mean_op, num_reduction_axes, reduction_axes, num_input_dims, input_shape, - workspace_size, workspace_alignment, /*log2_data_element_size=*/XNN_LOG2_SIZEOF_HALF, /*log2_accumulator_element_size=*/XNN_LOG2_SIZEOF_HALF, xnn_operator_type_mean_nd_f16, - /*scaleminmax_params=*/&mean_op->params.f16_scale_minmax, - /*scaleminmax_params_size=*/sizeof(mean_op->params.f16_scale_minmax), /*scale_params=*/&mean_op->params.f16_f32acc_scale, /*scale_params_size=*/sizeof(mean_op->params.f16_f32acc_scale), update_params_mean_f16, @@ -376,8 +335,8 @@ static void update_params_mean_f32( size_t num_elements) { const float scale = 1.0f / (float) (double) num_elements; - mean_op->reduce_config->init.f32_scale(&mean_op->params.f32_scale, scale); - mean_op->gavgpool_config->update.f32(&mean_op->params.f32_scale_minmax, scale); + mean_op->rsum_config->init.f32_scale(&mean_op->params.f32_scale, scale); + mean_op->rdsum_config->init.f32_scale(&mean_op->params.f32_scale, scale); } enum xnn_status xnn_reshape_mean_nd_f32( @@ -386,20 +345,15 @@ enum xnn_status xnn_reshape_mean_nd_f32( const size_t* reduction_axes, size_t num_input_dims, const size_t* input_shape, - size_t* workspace_size, - size_t* workspace_alignment, pthreadpool_t threadpool) { return reshape_mean_nd( mean_op, num_reduction_axes, reduction_axes, num_input_dims, input_shape, - workspace_size, workspace_alignment, /*log2_data_element_size=*/XNN_LOG2_SIZEOF_FLOAT, /*log2_accumulator_element_size=*/XNN_LOG2_SIZEOF_FLOAT, xnn_operator_type_mean_nd_f32, - /*scaleminmax_params=*/&mean_op->params.f32_scale_minmax, - /*scaleminmax_params_size=*/sizeof(mean_op->params.f32_scale_minmax), /*scale_params=*/&mean_op->params.f32_scale, /*scale_params_size=*/sizeof(mean_op->params.f32_scale), update_params_mean_f32, @@ -408,7 +362,6 @@ enum xnn_status xnn_reshape_mean_nd_f32( static enum xnn_status setup_mean_nd( xnn_operator_t mean_op, - void* workspace, const float* input, float* output, enum xnn_operator_type expected_operator_type) @@ -435,23 +388,8 @@ static enum xnn_status setup_mean_nd( break; } - if (mean_op->ukernel.type == xnn_microkernel_type_mean) { - // Reduction along the innermost dimension 
- mean_op->context.reduce.input = input; - mean_op->context.reduce.output = output; - } else { - assert(mean_op->ukernel.type == xnn_microkernel_type_global_average_pooling); - mean_op->context.global_average_pooling_nwc.input = input; - mean_op->context.global_average_pooling_nwc.output = output; - - if (mean_op->context.global_average_pooling_nwc.multipass_batch_stride != 0 && workspace == NULL) { - xnn_log_error( - "failed to setup %s operator: workspace of size %zu required but workspace is NULL", - xnn_operator_type_to_string(mean_op->type), - mean_op->context.global_average_pooling_nwc.multipass_batch_stride); - } - mean_op->context.global_average_pooling_nwc.multipass_buffer = workspace; - } + mean_op->context.reduce.input = input; + mean_op->context.reduce.output = output; mean_op->state = xnn_run_state_ready; return xnn_status_success; @@ -459,26 +397,22 @@ static enum xnn_status setup_mean_nd( enum xnn_status xnn_setup_mean_nd_f16( xnn_operator_t mean_op, - void* workspace, const void* input, void* output) { return setup_mean_nd( mean_op, - workspace, input, output, xnn_operator_type_mean_nd_f16); } enum xnn_status xnn_setup_mean_nd_f32( xnn_operator_t mean_op, - void* workspace, const float* input, float* output) { return setup_mean_nd( mean_op, - workspace, input, output, xnn_operator_type_mean_nd_f32); } diff --git a/src/subgraph/static-mean.c b/src/subgraph/static-mean.c index 38d75c6430e..80dbe06d26a 100644 --- a/src/subgraph/static-mean.c +++ b/src/subgraph/static-mean.c @@ -71,7 +71,6 @@ static enum xnn_status reshape_mean_operator( assert(output_id != XNN_INVALID_VALUE_ID); assert(output_id < num_values); - const size_t old_workspace_size = opdata->workspace_size; enum xnn_status status = xnn_status_invalid_state; switch (opdata->operator_objects[0]->type) { case xnn_operator_type_mean_nd_f16: @@ -81,8 +80,6 @@ static enum xnn_status reshape_mean_operator( opdata->reduction_axes, input_value->shape.num_dims, input_value->shape.dim, - &opdata->workspace_size, - &opdata->workspace_alignment, threadpool); break; case xnn_operator_type_mean_nd_f32: @@ -92,8 +89,6 @@ static enum xnn_status reshape_mean_operator( opdata->reduction_axes, input_value->shape.num_dims, input_value->shape.dim, - &opdata->workspace_size, - &opdata->workspace_alignment, threadpool); break; default: @@ -136,7 +131,7 @@ static enum xnn_status reshape_mean_operator( output_value->shape.num_dims = input_value->shape.num_dims - num_skip_axis; } const size_t new_size = xnn_tensor_get_size(output_value); - if (new_size > output_value->size || opdata->workspace_size > old_workspace_size) { + if (new_size > output_value->size) { output_value->size = new_size; return xnn_status_reallocation_required; } @@ -171,12 +166,10 @@ static enum xnn_status setup_mean_operator( case xnn_operator_type_mean_nd_f16: return xnn_setup_mean_nd_f16( opdata->operator_objects[0], - opdata->workspace, input_data, output_data); case xnn_operator_type_mean_nd_f32: return xnn_setup_mean_nd_f32( opdata->operator_objects[0], - opdata->workspace, input_data, output_data); default: XNN_UNREACHABLE; @@ -261,24 +254,23 @@ enum xnn_status xnn_define_static_mean( return xnn_status_invalid_parameter; } - size_t last_axis = 0; for (size_t i = 0; i < num_reduction_axes; i++) { - const size_t axis = reduction_axes[i]; - if (axis > input_value->shape.num_dims) { + if (reduction_axes[i] > input_value->shape.num_dims) { xnn_log_error( "failed to define %s operator with #%zu reduction axis of %zu: the index is out of bounds for a %zuD input shape", - 
xnn_node_type_to_string(xnn_node_type_static_mean), i, axis, input_value->shape.num_dims); + xnn_node_type_to_string(xnn_node_type_static_mean), i, reduction_axes[i], input_value->shape.num_dims); return xnn_status_invalid_parameter; } - if (i != 0) { - if (axis != last_axis + 1) { - xnn_log_error( - "failed to define %s operator with #%zu reduction axis of %zu: the axis is disjoint with #%zu reduction axis of %zu", - xnn_node_type_to_string(xnn_node_type_static_mean), i, axis, i - 1, last_axis); - return xnn_status_invalid_parameter; - } + } + + for (size_t i = 1; i < num_reduction_axes; i++) { + if (reduction_axes[i] <= reduction_axes[i - 1]) { + xnn_log_error( + "failed to define %s operator with #%zu reduction axis of %zu: the reduction " + "axes must be in ascending order and unique", + xnn_node_type_to_string(xnn_node_type_static_mean), i, reduction_axes[i]); + return xnn_status_invalid_parameter; } - last_axis = axis; } struct xnn_node* node = xnn_subgraph_new_node(subgraph); diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h index 09dbdd7688a..a52413efcb5 100644 --- a/src/xnnpack/compute.h +++ b/src/xnnpack/compute.h @@ -1409,12 +1409,21 @@ struct univector_contiguous_context { struct reduce_context { const void* input; void* output; + const void* zero; size_t input_shape[XNN_MAX_TENSOR_DIMS]; size_t input_stride[XNN_MAX_TENSOR_DIMS]; size_t output_stride[XNN_MAX_TENSOR_DIMS]; size_t scaled_elements; + size_t channels; size_t element_size; - xnn_reduce_ukernel_fn ukernel; + size_t input_pixel_stride; + size_t output_pixel_stride; + size_t input_batch_stride; + size_t output_batch_stride; + union { + xnn_reduce_ukernel_fn rsum; + xnn_rdsum_ukernel_fn rdsum; + } ukernel; union { union xnn_f32_default_params f32_default; union xnn_f32_scale_params f32_scale; @@ -1433,6 +1442,18 @@ struct reduce_context { size_t output2_block_size); #endif +#ifndef __cplusplus +// Compute discontigous reduction over the 0st, 2rd and 4th dimensions of the input +// tensor. 
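// (The contiguous variant declared above reduces the odd-numbered dimensions
// 1, 3 and 5 and dispatches through `ukernel.rsum`; this variant reduces the
// even-numbered dimensions 0, 2 and 4 and dispatches through `ukernel.rdsum`.)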
+ XNN_PRIVATE void xnn_compute_discontiguous_reduce( + const struct reduce_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t output_idx0, + size_t output_idx1, + size_t output_idx2, + size_t output1_block_size, + size_t output2_block_size); +#endif + struct prelu_context { size_t n; const void* x; diff --git a/src/xnnpack/config.h b/src/xnnpack/config.h index 4b05cdb8ec5..8d6d9b5f138 100644 --- a/src/xnnpack/config.h +++ b/src/xnnpack/config.h @@ -308,6 +308,7 @@ XNN_INTERNAL const struct xnn_unary_elementwise_config* xnn_init_xx_copy_config( struct xnn_reduce_config { xnn_reduce_ukernel_fn ukernel; + xnn_rdsum_ukernel_fn rd_ukernel; union { xnn_init_f16_f32acc_scale_params_fn f16_f32acc_scale; xnn_init_f16_default_params_fn f16_default; @@ -320,9 +321,11 @@ struct xnn_reduce_config { size_t element_tile; }; XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f16_f32acc_rsum_config(); +XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f16_f32acc_rdsum_config(); XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f16_rminmax_config(); XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f32_rminmax_config(); XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f32_rsum_config(); +XNN_INTERNAL const struct xnn_reduce_config* xnn_init_f32_rdsum_config(); struct xnn_xx_fill_config { xnn_fill_ukernel_fn ukernel; diff --git a/src/xnnpack/microfnptr.h b/src/xnnpack/microfnptr.h index 7cb542b7b1a..d653929083d 100644 --- a/src/xnnpack/microfnptr.h +++ b/src/xnnpack/microfnptr.h @@ -1550,6 +1550,15 @@ typedef void (*xnn_u8_reduce_ukernel_fn)( // RDSUM: Discontiguous Reduce-Sum +typedef void (*xnn_rdsum_ukernel_fn)( + size_t rows, + size_t channels, + const void* input, + size_t input_stride, + const void* zero, + void* output, + const void* params); + typedef void (*xnn_f16_f32acc_rdsum_ukernel_fn)( size_t rows, size_t channels, diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h index 4dd5b878a46..25eee0c457d 100644 --- a/src/xnnpack/operator.h +++ b/src/xnnpack/operator.h @@ -340,7 +340,8 @@ struct xnn_operator { const struct xnn_avgpool_config* avgpool_config; const struct xnn_gavgpool_config* gavgpool_config; const struct xnn_pavgpool_config* pavgpool_config; - const struct xnn_reduce_config* reduce_config; + const struct xnn_reduce_config* rdsum_config; + const struct xnn_reduce_config* rsum_config; }; const struct xnn_gavgpool_cw_config* gavgpool_cw_config; const struct xnn_ibilinear_chw_config* ibilinear_chw_config; diff --git a/test/mean-nd.cc b/test/mean-nd.cc index 540ae4bf780..c52d83ac4f1 100644 --- a/test/mean-nd.cc +++ b/test/mean-nd.cc @@ -4,7 +4,6 @@ // LICENSE file in the root directory of this source tree. 
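// Reductions whose normalized form is not a single innermost axis are now
// handled by the rdsum kernels, so the tests below exercise those shapes
// directly instead of pre-normalizing the axes and skipping them.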
#include -#include #include #include @@ -83,19 +82,6 @@ TEST(MEAN_ND_F16, reduce_3d) { reduction_axes.push_back(2); } - size_t num_normalized_input_dims = input_shape.size(); - std::array normalized_input_shape; - std::copy(input_shape.cbegin(), input_shape.cend(), normalized_input_shape.begin()); - size_t num_normalized_reduction_axes = reduction_axes.size(); - std::array normalized_reduction_axes; - std::copy(reduction_axes.cbegin(), reduction_axes.cend(), normalized_reduction_axes.begin()); - xnn_normalize_reduction( - &num_normalized_reduction_axes, normalized_reduction_axes.data(), - &num_normalized_input_dims, normalized_input_shape.data()); - if (num_normalized_reduction_axes != 1) { - continue; // unsupported reduction configuration, will fail if we proceed - } - MeanOperatorTester() .input_shape(input_shape) .reduction_axes(reduction_axes) @@ -126,19 +112,6 @@ TEST(MEAN_ND_F16, reduce_4d) { reduction_axes.push_back(3); } - size_t num_normalized_input_dims = input_shape.size(); - std::array normalized_input_shape; - std::copy(input_shape.cbegin(), input_shape.cend(), normalized_input_shape.begin()); - size_t num_normalized_reduction_axes = reduction_axes.size(); - std::array normalized_reduction_axes; - std::copy(reduction_axes.cbegin(), reduction_axes.cend(), normalized_reduction_axes.begin()); - xnn_normalize_reduction( - &num_normalized_reduction_axes, normalized_reduction_axes.data(), - &num_normalized_input_dims, normalized_input_shape.data()); - if (num_normalized_reduction_axes != 1) { - continue; // unsupported reduction configuration, will fail if we proceed - } - MeanOperatorTester() .input_shape(input_shape) .reduction_axes(reduction_axes) @@ -173,19 +146,6 @@ TEST(MEAN_ND_F16, reduce_5d) { reduction_axes.push_back(4); } - size_t num_normalized_input_dims = input_shape.size(); - std::array normalized_input_shape; - std::copy(input_shape.cbegin(), input_shape.cend(), normalized_input_shape.begin()); - size_t num_normalized_reduction_axes = reduction_axes.size(); - std::array normalized_reduction_axes; - std::copy(reduction_axes.cbegin(), reduction_axes.cend(), normalized_reduction_axes.begin()); - xnn_normalize_reduction( - &num_normalized_reduction_axes, normalized_reduction_axes.data(), - &num_normalized_input_dims, normalized_input_shape.data()); - if (num_normalized_reduction_axes != 1) { - continue; // unsupported reduction configuration, will fail if we proceed - } - MeanOperatorTester() .input_shape(input_shape) .reduction_axes(reduction_axes) @@ -224,19 +184,6 @@ TEST(MEAN_ND_F16, reduce_6d) { reduction_axes.push_back(5); } - size_t num_normalized_input_dims = input_shape.size(); - std::array normalized_input_shape; - std::copy(input_shape.cbegin(), input_shape.cend(), normalized_input_shape.begin()); - size_t num_normalized_reduction_axes = reduction_axes.size(); - std::array normalized_reduction_axes; - std::copy(reduction_axes.cbegin(), reduction_axes.cend(), normalized_reduction_axes.begin()); - xnn_normalize_reduction( - &num_normalized_reduction_axes, normalized_reduction_axes.data(), - &num_normalized_input_dims, normalized_input_shape.data()); - if (num_normalized_reduction_axes != 1) { - continue; // unsupported reduction configuration, will fail if we proceed - } - MeanOperatorTester() .input_shape(input_shape) .reduction_axes(reduction_axes) @@ -275,19 +222,6 @@ TEST(MEAN_ND_F16, reduce_6d_multithreaded) { reduction_axes.push_back(5); } - size_t num_normalized_input_dims = input_shape.size(); - std::array normalized_input_shape; - 
-      std::copy(input_shape.cbegin(), input_shape.cend(), normalized_input_shape.begin());
-      size_t num_normalized_reduction_axes = reduction_axes.size();
-      std::array<size_t, XNN_MAX_TENSOR_DIMS> normalized_reduction_axes;
-      std::copy(reduction_axes.cbegin(), reduction_axes.cend(), normalized_reduction_axes.begin());
-      xnn_normalize_reduction(
-        &num_normalized_reduction_axes, normalized_reduction_axes.data(),
-        &num_normalized_input_dims, normalized_input_shape.data());
-      if (num_normalized_reduction_axes != 1) {
-        continue;  // unsupported reduction configuration, will fail if we proceed
-      }
-
       MeanOperatorTester()
         .input_shape(input_shape)
         .reduction_axes(reduction_axes)
@@ -310,90 +244,6 @@ TEST(MEAN_ND_F32, reduce_first_axis) {
     .TestF32();
 }
 
-TEST(MEAN_ND_F16, 1d_contig) {
-  MeanOperatorTester()
-    .input_shape({kDim3})
-    .reduction_axes({0})
-    .TestF16();
-}
-
-TEST(MEAN_ND_F16, 2d_contig) {
-  MeanOperatorTester()
-    .input_shape({kDim1, kDim2})
-    .reduction_axes({1})
-    .TestF16();
-}
-
-TEST(MEAN_ND_F16, 3d_contig) {
-  MeanOperatorTester()
-    .input_shape({kDim1, kDim2, kDim3})
-    .reduction_axes({0, 2})
-    .TestF16();
-}
-
-TEST(MEAN_ND_F16, 4d_contig) {
-  MeanOperatorTester()
-    .input_shape({kDim1, kDim2, kDim3, kDim4})
-    .reduction_axes({1, 3})
-    .TestF16();
-}
-
-TEST(MEAN_ND_F16, 5d_contig) {
-  MeanOperatorTester()
-    .input_shape({kDim1, kDim2, kDim3, kDim4, kDim5})
-    .reduction_axes({0, 2, 4})
-    .TestF16();
-}
-
-TEST(MEAN_ND_F16, 6d_contig) {
-  MeanOperatorTester()
-    .input_shape({kDim1, kDim2, kDim3, kDim4, kDim5, kDim6})
-    .reduction_axes({1, 3, 5})
-    .TestF16();
-}
-
-TEST(MEAN_ND_F32, 1d_contig) {
-  MeanOperatorTester()
-    .input_shape({kDim3})
-    .reduction_axes({0})
-    .TestF32();
-}
-
-TEST(MEAN_ND_F32, 2d_contig) {
-  MeanOperatorTester()
-    .input_shape({kDim1, kDim2})
-    .reduction_axes({1})
-    .TestF32();
-}
-
-TEST(MEAN_ND_F32, 3d_contig) {
-  MeanOperatorTester()
-    .input_shape({kDim1, kDim2, kDim3})
-    .reduction_axes({0, 2})
-    .TestF32();
-}
-
-TEST(MEAN_ND_F32, 4d_contig) {
-  MeanOperatorTester()
-    .input_shape({kDim1, kDim2, kDim3, kDim4})
-    .reduction_axes({1, 3})
-    .TestF32();
-}
-
-TEST(MEAN_ND_F32, 5d_contig) {
-  MeanOperatorTester()
-    .input_shape({kDim1, kDim2, kDim3, kDim4, kDim5})
-    .reduction_axes({0, 2, 4})
-    .TestF32();
-}
-
-TEST(MEAN_ND_F32, 6d_contig) {
-  MeanOperatorTester()
-    .input_shape({kDim1, kDim2, kDim3, kDim4, kDim5, kDim6})
-    .reduction_axes({1, 3, 5})
-    .TestF32();
-}
-
 TEST(MEAN_ND_F32, reduce_last_axis) {
   MeanOperatorTester()
     .input_shape({kDim1, kDim2, kDim3})
@@ -447,19 +297,6 @@ TEST(MEAN_ND_F32, reduce_3d) {
         reduction_axes.push_back(2);
       }
 
-      size_t num_normalized_input_dims = input_shape.size();
-      std::array<size_t, XNN_MAX_TENSOR_DIMS> normalized_input_shape;
-      std::copy(input_shape.cbegin(), input_shape.cend(), normalized_input_shape.begin());
-      size_t num_normalized_reduction_axes = reduction_axes.size();
-      std::array<size_t, XNN_MAX_TENSOR_DIMS> normalized_reduction_axes;
-      std::copy(reduction_axes.cbegin(), reduction_axes.cend(), normalized_reduction_axes.begin());
-      xnn_normalize_reduction(
-        &num_normalized_reduction_axes, normalized_reduction_axes.data(),
-        &num_normalized_input_dims, normalized_input_shape.data());
-      if (num_normalized_reduction_axes != 1) {
-        continue;  // unsupported reduction configuration, will fail if we proceed
-      }
-
       MeanOperatorTester()
         .input_shape(input_shape)
         .reduction_axes(reduction_axes)
@@ -490,19 +327,6 @@ TEST(MEAN_ND_F32, reduce_4d) {
         reduction_axes.push_back(3);
       }
 
-      size_t num_normalized_input_dims = input_shape.size();
-      std::array<size_t, XNN_MAX_TENSOR_DIMS> normalized_input_shape;
-      std::copy(input_shape.cbegin(), input_shape.cend(), normalized_input_shape.begin());
-      size_t num_normalized_reduction_axes = reduction_axes.size();
-      std::array<size_t, XNN_MAX_TENSOR_DIMS> normalized_reduction_axes;
-      std::copy(reduction_axes.cbegin(), reduction_axes.cend(), normalized_reduction_axes.begin());
-      xnn_normalize_reduction(
-        &num_normalized_reduction_axes, normalized_reduction_axes.data(),
-        &num_normalized_input_dims, normalized_input_shape.data());
-      if (num_normalized_reduction_axes != 1) {
-        continue;  // unsupported reduction configuration, will fail if we proceed
-      }
-
       MeanOperatorTester()
         .input_shape(input_shape)
         .reduction_axes(reduction_axes)
@@ -537,19 +361,6 @@ TEST(MEAN_ND_F32, reduce_5d) {
         reduction_axes.push_back(4);
       }
 
-      size_t num_normalized_input_dims = input_shape.size();
-      std::array<size_t, XNN_MAX_TENSOR_DIMS> normalized_input_shape;
-      std::copy(input_shape.cbegin(), input_shape.cend(), normalized_input_shape.begin());
-      size_t num_normalized_reduction_axes = reduction_axes.size();
-      std::array<size_t, XNN_MAX_TENSOR_DIMS> normalized_reduction_axes;
-      std::copy(reduction_axes.cbegin(), reduction_axes.cend(), normalized_reduction_axes.begin());
-      xnn_normalize_reduction(
-        &num_normalized_reduction_axes, normalized_reduction_axes.data(),
-        &num_normalized_input_dims, normalized_input_shape.data());
-      if (num_normalized_reduction_axes != 1) {
-        continue;  // unsupported reduction configuration, will fail if we proceed
-      }
-
       MeanOperatorTester()
         .input_shape(input_shape)
         .reduction_axes(reduction_axes)
@@ -588,19 +399,6 @@ TEST(MEAN_ND_F32, reduce_6d) {
         reduction_axes.push_back(5);
       }
 
-      size_t num_normalized_input_dims = input_shape.size();
-      std::array<size_t, XNN_MAX_TENSOR_DIMS> normalized_input_shape;
-      std::copy(input_shape.cbegin(), input_shape.cend(), normalized_input_shape.begin());
-      size_t num_normalized_reduction_axes = reduction_axes.size();
-      std::array<size_t, XNN_MAX_TENSOR_DIMS> normalized_reduction_axes;
-      std::copy(reduction_axes.cbegin(), reduction_axes.cend(), normalized_reduction_axes.begin());
-      xnn_normalize_reduction(
-        &num_normalized_reduction_axes, normalized_reduction_axes.data(),
-        &num_normalized_input_dims, normalized_input_shape.data());
-      if (num_normalized_reduction_axes != 1) {
-        continue;  // unsupported reduction configuration, will fail if we proceed
-      }
-
       MeanOperatorTester()
         .input_shape(input_shape)
         .reduction_axes(reduction_axes)
@@ -639,23 +437,10 @@ TEST(MEAN_ND_F32, reduce_6d_multithreaded) {
         reduction_axes.push_back(5);
       }
 
-      size_t num_normalized_input_dims = input_shape.size();
-      std::array<size_t, XNN_MAX_TENSOR_DIMS> normalized_input_shape;
-      std::copy(input_shape.cbegin(), input_shape.cend(), normalized_input_shape.begin());
-      size_t num_normalized_reduction_axes = reduction_axes.size();
-      std::array<size_t, XNN_MAX_TENSOR_DIMS> normalized_reduction_axes;
-      std::copy(reduction_axes.cbegin(), reduction_axes.cend(), normalized_reduction_axes.begin());
-      xnn_normalize_reduction(
-        &num_normalized_reduction_axes, normalized_reduction_axes.data(),
-        &num_normalized_input_dims, normalized_input_shape.data());
-      if (num_normalized_reduction_axes != 1) {
-        continue;  // unsupported reduction configuration, will fail if we proceed
-      }
-
       MeanOperatorTester()
         .input_shape(input_shape)
         .reduction_axes(reduction_axes)
         .multithreaded(true)
         .TestF32();
     }
-}
\ No newline at end of file
+}
diff --git a/test/mean-operator-tester.h b/test/mean-operator-tester.h
index 7f89125a4c4..afff30ee2c9 100644
--- a/test/mean-operator-tester.h
+++ b/test/mean-operator-tester.h
@@ -179,8 +179,6 @@ class MeanOperatorTester {
     // Smart pointer to automatically delete mean_op.
     std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_mean_op(mean_op, xnn_delete_operator);
 
-    size_t workspace_size = SIZE_MAX;
-    size_t workspace_alignment = SIZE_MAX;
     ASSERT_EQ(xnn_status_success,
       xnn_reshape_mean_nd_f16(
         mean_op,
@@ -188,17 +186,11 @@ class MeanOperatorTester {
         reduction_axes().data(),
         num_input_dims(),
         input_shape().data(),
-        &workspace_size, &workspace_alignment,
         auto_threadpool.get()));
-    ASSERT_NE(workspace_size, SIZE_MAX);
-    ASSERT_LE(workspace_alignment, XNN_ALLOCATION_ALIGNMENT);
-    std::vector<char, AlignedAllocator<char, XNN_ALLOCATION_ALIGNMENT>> workspace(workspace_size);
-
     ASSERT_EQ(xnn_status_success,
       xnn_setup_mean_nd_f16(
         mean_op,
-        workspace.data(),
         input.data(),
         output.data()));
 
     ASSERT_EQ(xnn_status_success,
@@ -305,8 +297,6 @@ class MeanOperatorTester {
     // Smart pointer to automatically delete mean_op.
     std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_mean_op(mean_op, xnn_delete_operator);
 
-    size_t workspace_size = SIZE_MAX;
-    size_t workspace_alignment = SIZE_MAX;
     ASSERT_EQ(xnn_status_success,
       xnn_reshape_mean_nd_f32(
         mean_op,
@@ -314,17 +304,11 @@ class MeanOperatorTester {
         reduction_axes().data(),
         num_input_dims(),
         input_shape().data(),
-        &workspace_size, &workspace_alignment,
         auto_threadpool.get()));
-    ASSERT_NE(workspace_size, SIZE_MAX);
-    ASSERT_LE(workspace_alignment, XNN_ALLOCATION_ALIGNMENT);
-    std::vector<char, AlignedAllocator<char, XNN_ALLOCATION_ALIGNMENT>> workspace(workspace_size);
-
     ASSERT_EQ(xnn_status_success,
       xnn_setup_mean_nd_f32(
         mean_op,
-        workspace.data(),
         input.data(),
         output.data()));
 
     ASSERT_EQ(xnn_status_success,
diff --git a/test/static-mean.cc b/test/static-mean.cc
index 63e94b8e06b..8fb165713d1 100644
--- a/test/static-mean.cc
+++ b/test/static-mean.cc
@@ -33,16 +33,15 @@ template class MeanTestBase : public ::testing::Test {
 
     auto num_input_dim_dist = std::uniform_int_distribution<size_t>(2, XNN_MAX_TENSOR_DIMS);
     const size_t num_input_dims = num_input_dim_dist(rng);
+    auto num_reduction_axes_dist = std::uniform_int_distribution<size_t>(1, num_input_dims);
+    const size_t num_reduction_axes = num_reduction_axes_dist(rng);
 
-    auto reduction_axes_seq_start_dist = std::uniform_int_distribution<size_t>(0, num_input_dims - 1);
-    const size_t reduction_axes_seq_start = reduction_axes_seq_start_dist(rng);
-    auto reduction_axes_seq_end_dist = std::uniform_int_distribution<size_t>(reduction_axes_seq_start + 1, num_input_dims);
-    const size_t reduction_axes_seq_end = reduction_axes_seq_end_dist(rng);
-
-    reduction_axes.clear();
-    for (size_t axis = reduction_axes_seq_start; axis < reduction_axes_seq_end; axis++) {
-      reduction_axes.push_back(axis);
-    }
+    auto axes_dist = std::uniform_int_distribution<size_t>(0, num_input_dims - 1);
+    reduction_axes.resize(num_reduction_axes);
+    std::generate(reduction_axes.begin(), reduction_axes.end(), [&]() { return axes_dist(rng); });
+    std::sort(reduction_axes.begin(), reduction_axes.end());
+    auto end = std::unique(reduction_axes.begin(), reduction_axes.end());
+    reduction_axes.erase(end, reduction_axes.end());
 
     auto shape_dist = std::uniform_int_distribution<size_t>(2, 15);
     input_shape.resize(num_input_dims);
@@ -181,18 +180,13 @@ TEST_F(MeanTestF16, matches_operator_api)
 
   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
 
-  size_t workspace_size = 0;
-  size_t workspace_alignment = 0;
   ASSERT_EQ(xnn_status_success,
             xnn_reshape_mean_nd_f16(op,
                                     reduction_axes.size(), reduction_axes.data(),
                                     input_shape.size(), input_shape.data(),
-                                    &workspace_size, &workspace_alignment,
                                     /*threadpool=*/nullptr));
-  ASSERT_LE(workspace_alignment, XNN_ALLOCATION_ALIGNMENT);
-  std::vector<char, AlignedAllocator<char, XNN_ALLOCATION_ALIGNMENT>> workspace(workspace_size);
-
-  ASSERT_EQ(xnn_status_success, xnn_setup_mean_nd_f16(op, workspace.data(), input.data(), operator_output.data()));
+  ASSERT_EQ(xnn_status_success, xnn_setup_mean_nd_f16(op, input.data(), operator_output.data()));
 
   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr));
 
@@ -258,18 +252,13 @@ TEST_F(MeanTestF32, matches_operator_api)
 
   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
 
-  size_t workspace_size = 0;
-  size_t workspace_alignment = 0;
   ASSERT_EQ(xnn_status_success,
             xnn_reshape_mean_nd_f32(op,
                                     reduction_axes.size(), reduction_axes.data(),
                                     input_shape.size(), input_shape.data(),
-                                    &workspace_size, &workspace_alignment,
                                     /*threadpool=*/nullptr));
-  ASSERT_LE(workspace_alignment, XNN_ALLOCATION_ALIGNMENT);
-  std::vector<char, AlignedAllocator<char, XNN_ALLOCATION_ALIGNMENT>> workspace(workspace_size);
-
-  ASSERT_EQ(xnn_status_success, xnn_setup_mean_nd_f32(op, workspace.data(), input.data(), operator_output.data()));
+  ASSERT_EQ(xnn_status_success, xnn_setup_mean_nd_f32(op, input.data(), operator_output.data()));
 
   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr));
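
Note: the sketch below is not part of the diff. It is an illustrative recap of the f32 mean-nd operator call sequence after this change, mirroring the calls exercised by the updated tests (reshape no longer reports a workspace size/alignment, and setup no longer takes a workspace pointer). The helper name run_mean_nd_f32_example, the shapes, and the reduction axes are made up for the example, the create call is assumed to take only a flags word and an out-pointer, and error handling is reduced to early returns.

/* Illustrative sketch: create -> reshape -> setup -> run for the f32 ND mean
 * operator with the workspace parameters removed. Example values only. */
#include <stddef.h>
#include <xnnpack.h>

enum xnn_status run_mean_nd_f32_example(const float* input, float* output) {
  enum xnn_status status = xnn_initialize(/*allocator=*/NULL);
  if (status != xnn_status_success) return status;

  xnn_operator_t mean_op = NULL;
  status = xnn_create_mean_nd_f32(/*flags=*/0, &mean_op);
  if (status != xnn_status_success) return status;

  const size_t input_shape[4] = {2, 3, 4, 5};  // example input shape
  const size_t reduction_axes[2] = {1, 2};     // example axes to reduce over

  // Reshape no longer returns a workspace size or alignment.
  status = xnn_reshape_mean_nd_f32(mean_op,
                                   /*num_reduction_axes=*/2, reduction_axes,
                                   /*num_input_dims=*/4, input_shape,
                                   /*threadpool=*/NULL);
  if (status != xnn_status_success) return status;

  // Setup no longer takes a workspace pointer.
  status = xnn_setup_mean_nd_f32(mean_op, input, output);
  if (status != xnn_status_success) return status;

  status = xnn_run_operator(mean_op, /*threadpool=*/NULL);
  if (status != xnn_status_success) return status;

  return xnn_delete_operator(mean_op);
}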