From d2039879cb9f9f7283746002f4d6c500ca5d7d10 Mon Sep 17 00:00:00 2001 From: Shelley Goel Date: Mon, 11 Mar 2019 10:54:37 -0700 Subject: [PATCH] src: cpu: fwd bnorm: use more accurate division w/ scaleshift --- src/cpu/jit_uni_batch_normalization.cpp | 22 +++++++++++++--------- src/cpu/ncsp_batch_normalization.cpp | 9 ++++----- src/cpu/nspc_batch_normalization.cpp | 15 +++++++-------- src/cpu/ref_batch_normalization.cpp | 13 +++++++------ tests/benchdnn/bnorm/ref_bnorm.cpp | 6 +++--- 5 files changed, 34 insertions(+), 31 deletions(-) diff --git a/src/cpu/jit_uni_batch_normalization.cpp b/src/cpu/jit_uni_batch_normalization.cpp index f895d931b57..72fe3a81097 100644 --- a/src/cpu/jit_uni_batch_normalization.cpp +++ b/src/cpu/jit_uni_batch_normalization.cpp @@ -652,19 +652,22 @@ struct jit_bnorm_t: public jit_generator { uni_vaddps(vsqrtvar, vsqrtvar, veps); uni_vsqrtps(vsqrtvar, vsqrtvar); - if (isa == sse42) { - movups(vbuf, vone); - divps(vbuf, vsqrtvar); - movups(vsqrtvar, vbuf); - } else { - vdivps(vsqrtvar, vone, vsqrtvar); - } - if (bdesc_->use_scaleshift()) { uni_vmovups_maybe_tail(vgamma, gamma_ptr()); uni_vmovups_maybe_tail(vbeta, beta_ptr()); } + Vmm vscale = bdesc_->use_scaleshift() ? vgamma : vone; + Vmm vdiv = bdesc_->use_scaleshift() ? 
vgamma : vsqrtvar; + + if (isa == sse42) { + movups(vbuf, vscale); + divps(vbuf, vsqrtvar); + movups(vdiv, vbuf); + } else { + vdivps(vdiv, vscale, vsqrtvar); + } + auto compute = [=](bool output_is_aligned) { spat_loop(spat_size, unroll_blocks, unroll_regs, [](size_t base_reg) {UNUSED(base_reg);}, @@ -678,9 +681,10 @@ struct jit_bnorm_t: public jit_generator { mic_prefetcht1(ptr[reg_src + reg_soff + offt + t1_pf_offt]); uni_vsubps(v, v, vmean); - uni_vmulps(v, v, vsqrtvar); if (bdesc_->use_scaleshift()) { uni_vfmadd213ps(v, vgamma, vbeta); + } else { + uni_vmulps(v, v, vsqrtvar); } if (with_relu_inf_only) { uni_vmaxps(v, v, vzero); diff --git a/src/cpu/ncsp_batch_normalization.cpp b/src/cpu/ncsp_batch_normalization.cpp index e4c2ab87e29..c0e93fefe4d 100644 --- a/src/cpu/ncsp_batch_normalization.cpp +++ b/src/cpu/ncsp_batch_normalization.cpp @@ -191,10 +191,10 @@ void ncsp_batch_normalization_fwd_t::execute_forward( for (dim_t c = C_blk_s; c < C_blk_e; c++) { size_t off = c + C_off; - data_t sm = use_scaleshift ? scaleshift[off] : 1; - data_t sv = use_scaleshift ? scaleshift[C + off] : 0; data_t sqrt_variance - = static_cast<data_t>(1.0f / sqrtf(variance[off] + eps)); + = static_cast<data_t>(sqrtf(variance[off] + eps)); + data_t sm = (use_scaleshift ? scaleshift[off] : 1.0f) / sqrt_variance; + data_t sv = use_scaleshift ? 
scaleshift[C + off] : 0; for (dim_t n = N_s; n < N_e; ++n) #if SAFE_TO_USE_OMP_SIMD PRAGMA_OMP_SIMD() @@ -202,8 +202,7 @@ void ncsp_batch_normalization_fwd_t::execute_forward( for (dim_t sp = S_s; sp < S_e; ++sp) { size_t d_off = off * SP + n * C * SP + sp; data_t bn_res - = sm * (src[d_off] - mean[off]) * sqrt_variance - + sv; + = sm * (src[d_off] - mean[off]) + sv; if (fuse_bn_relu) { if (bn_res <= 0) { bn_res = 0; diff --git a/src/cpu/nspc_batch_normalization.cpp b/src/cpu/nspc_batch_normalization.cpp index 5428adc8606..e20333e66fe 100644 --- a/src/cpu/nspc_batch_normalization.cpp +++ b/src/cpu/nspc_batch_normalization.cpp @@ -151,23 +151,22 @@ void nspc_batch_normalization_fwd_t::execute_forward( #endif for (dim_t c = 0; c < C; c++) { data_t sqrt_variance = static_cast<data_t>( - 1.0f / sqrtf(variance_loc[c] + eps)); - data_t sm = use_scaleshift ? scaleshift[c] : 1; + sqrtf(variance_loc[c] + eps)); + data_t sm = (use_scaleshift ? scaleshift[c] : 1.0f) / sqrt_variance; data_t sv = use_scaleshift ? scaleshift[C + c] : 0; - data_t bn_res - = sm * (src[(size_t)n * SP * C + sp * C + c] - - mean_loc[c]) * sqrt_variance + sv; + size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t bn_res = sm * (src[d_off] - mean_loc[c]) + sv; if (fuse_bn_relu) { if (bn_res <= 0) { bn_res = 0; if (is_training) - ws[(size_t)n * SP * C + sp * C + c] = 0; + ws[d_off] = 0; } else { if (is_training) - ws[(size_t)n * SP * C + sp * C + c] = 1; + ws[d_off] = 1; } } - dst[(size_t)n * SP * C + sp * C + c] = maybe_post_op(bn_res); + dst[d_off] = maybe_post_op(bn_res); } } } diff --git a/src/cpu/ref_batch_normalization.cpp b/src/cpu/ref_batch_normalization.cpp index a7db61d5f6c..d79b1a034ba 100644 --- a/src/cpu/ref_batch_normalization.cpp +++ b/src/cpu/ref_batch_normalization.cpp @@ -88,8 +88,6 @@ void ref_batch_normalization_fwd_t::execute_forward( float v_mean = calculate_stats ? 0 : mean[c]; float v_variance = calculate_stats ? 0 : variance[c]; - float sm = use_scaleshift ? 
scaleshift[scaleshift_d.off(0, c)] : 1; - float sv = use_scaleshift ? scaleshift[scaleshift_d.off(1, c)] : 0; if (calculate_stats) { for (dim_t n = 0; n < N; ++n) for (dim_t d = 0; d < D; ++d) @@ -108,15 +106,18 @@ void ref_batch_normalization_fwd_t::execute_forward( v_variance /= W*H*N*D; } - float sqrt_variance = 1.0f / sqrtf(v_variance + eps); + float sqrt_variance = sqrtf(v_variance + eps); + float sm = (use_scaleshift + ? scaleshift[scaleshift_d.off(0, c)] + : 1.0f) / sqrt_variance; + float sv = use_scaleshift ? scaleshift[scaleshift_d.off(1, c)] : 0; for (dim_t n = 0; n < N; ++n) for (dim_t d = 0; d < D; ++d) for (dim_t h = 0; h < H; ++h) for (dim_t w = 0; w < W; ++w) { - auto d_off = data_offset(data_d, n, c, d, h, w); - float bn_res = sm * ((float)src[d_off] - v_mean) * - sqrt_variance + sv; + auto d_off = data_offset(data_d,n,c,d,h,w); + float bn_res = sm * ((float)src[d_off] - v_mean) + sv; if (fuse_bn_relu) { if (bn_res <= 0) { bn_res = 0; diff --git a/tests/benchdnn/bnorm/ref_bnorm.cpp b/tests/benchdnn/bnorm/ref_bnorm.cpp index 387fb0a19eb..20287634bff 100644 --- a/tests/benchdnn/bnorm/ref_bnorm.cpp +++ b/tests/benchdnn/bnorm/ref_bnorm.cpp @@ -43,9 +43,9 @@ void compute_ref_fwd(const prb_t *p, const dnn_mem_t &src, dnn_mem_t &mean, mkldnn::impl::parallel_nd(p->ic, [&](int64_t c) { float smean = ((float *)mean)[c]; float svar = ((float *)var)[c]; - float rcp_denom = (float)(1.0f / (sqrtf(svar + p->eps))); + float sqrt_var = sqrtf(svar + p->eps); - float gamma = p->flags & USE_SCALESHIFT ? ((float *)ss)[c] : 1; + float gamma = (p->flags & USE_SCALESHIFT ? ((float *)ss)[c] : 1.0f) / sqrt_var; float beta = p->flags & USE_SCALESHIFT ? 
((float *)ss)[p->ic + c] : 0; for (int64_t mb = 0; mb < p->mb; ++mb) @@ -53,7 +53,7 @@ void compute_ref_fwd(const prb_t *p, const dnn_mem_t &src, dnn_mem_t &mean, for (int64_t h = 0; h < p->ih; ++h) for (int64_t w = 0; w < p->iw; ++w) { auto off = data_off(p, mb, c, d, h, w); - float res = gamma * (((float *)src)[off] - smean) * rcp_denom + beta; + float res = gamma * (((float *)src)[off] - smean) + beta; float &D = ((float *)dst)[off]; if ((p->flags & FUSE_BN_RELU) && res < 0) res = 0; maybe_post_ops(res, D);