src: cpu: fwd bnorm: use more accurate division w/ scaleshift
shelleygoel committed Mar 15, 2019
1 parent a8565bf commit d203987
Showing 5 changed files with 34 additions and 31 deletions.
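
The change, applied to each forward implementation below, computes the per-channel scale as gamma / sqrtf(variance + eps) with one division, instead of multiplying by a separately rounded reciprocal 1.0f / sqrtf(variance + eps) and then by gamma. A minimal standalone sketch of the two formulations (hypothetical values, not taken from the repository), assuming single-precision arithmetic:

    #include <math.h>
    #include <stdio.h>

    int main() {
        // Hypothetical per-channel values, for illustration only.
        float x = 2.5f, mean = 1.0f, var = 0.3f, eps = 1e-5f;
        float gamma = 0.7f, beta = 0.1f;

        // Old formulation: the reciprocal is rounded first, gamma applied after.
        float rcp_denom = 1.0f / sqrtf(var + eps);
        float res_old = gamma * (x - mean) * rcp_denom + beta;

        // New formulation: gamma is folded into a single division.
        float sm = gamma / sqrtf(var + eps);
        float res_new = sm * (x - mean) + beta;

        // The results may differ in the last bits; the new form skips one
        // rounding step and one per-element multiply.
        printf("old=%.9g new=%.9g\n", res_old, res_new);
        return 0;
    }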
22 changes: 13 additions & 9 deletions src/cpu/jit_uni_batch_normalization.cpp
@@ -652,19 +652,22 @@ struct jit_bnorm_t: public jit_generator {
                 uni_vaddps(vsqrtvar, vsqrtvar, veps);
                 uni_vsqrtps(vsqrtvar, vsqrtvar);
 
-                if (isa == sse42) {
-                    movups(vbuf, vone);
-                    divps(vbuf, vsqrtvar);
-                    movups(vsqrtvar, vbuf);
-                } else {
-                    vdivps(vsqrtvar, vone, vsqrtvar);
-                }
-
                 if (bdesc_->use_scaleshift()) {
                     uni_vmovups_maybe_tail(vgamma, gamma_ptr());
                     uni_vmovups_maybe_tail(vbeta, beta_ptr());
                 }
 
+                Vmm vscale = bdesc_->use_scaleshift() ? vgamma : vone;
+                Vmm vdiv = bdesc_->use_scaleshift() ? vgamma : vsqrtvar;
+
+                if (isa == sse42) {
+                    movups(vbuf, vscale);
+                    divps(vbuf, vsqrtvar);
+                    movups(vdiv, vbuf);
+                } else {
+                    vdivps(vdiv, vscale, vsqrtvar);
+                }
+
                 auto compute = [=](bool output_is_aligned) {
                     spat_loop(spat_size, unroll_blocks, unroll_regs,
                             [](size_t base_reg) {UNUSED(base_reg);},
@@ -678,9 +681,10 @@ struct jit_bnorm_t: public jit_generator {
                             mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                     + t1_pf_offt]);
                             uni_vsubps(v, v, vmean);
-                            uni_vmulps(v, v, vsqrtvar);
                             if (bdesc_->use_scaleshift()) {
                                 uni_vfmadd213ps(v, vgamma, vbeta);
+                            } else {
+                                uni_vmulps(v, v, vsqrtvar);
                             }
                             if (with_relu_inf_only) {
                                 uni_vmaxps(v, v, vzero);
9 changes: 4 additions & 5 deletions src/cpu/ncsp_batch_normalization.cpp
@@ -191,19 +191,18 @@ void ncsp_batch_normalization_fwd_t::execute_forward(
 
                 for (dim_t c = C_blk_s; c < C_blk_e; c++) {
                     size_t off = c + C_off;
-                    data_t sm = use_scaleshift ? scaleshift[off] : 1;
-                    data_t sv = use_scaleshift ? scaleshift[C + off] : 0;
                     data_t sqrt_variance
-                            = static_cast<data_t>(1.0f / sqrtf(variance[off] + eps));
+                            = static_cast<data_t>(sqrtf(variance[off] + eps));
+                    data_t sm = (use_scaleshift ? scaleshift[off] : 1.0f) / sqrt_variance;
+                    data_t sv = use_scaleshift ? scaleshift[C + off] : 0;
                     for (dim_t n = N_s; n < N_e; ++n)
 #if SAFE_TO_USE_OMP_SIMD
                     PRAGMA_OMP_SIMD()
 #endif
                     for (dim_t sp = S_s; sp < S_e; ++sp) {
                         size_t d_off = off * SP + n * C * SP + sp;
                         data_t bn_res
-                                = sm * (src[d_off] - mean[off]) * sqrt_variance
-                                + sv;
+                                = sm * (src[d_off] - mean[off]) + sv;
                         if (fuse_bn_relu) {
                             if (bn_res <= 0) {
                                 bn_res = 0;
15 changes: 7 additions & 8 deletions src/cpu/nspc_batch_normalization.cpp
@@ -151,23 +151,22 @@ void nspc_batch_normalization_fwd_t::execute_forward(
 #endif
             for (dim_t c = 0; c < C; c++) {
                 data_t sqrt_variance = static_cast<data_t>(
-                        1.0f / sqrtf(variance_loc[c] + eps));
-                data_t sm = use_scaleshift ? scaleshift[c] : 1;
+                        sqrtf(variance_loc[c] + eps));
+                data_t sm = (use_scaleshift ? scaleshift[c] : 1.0f) / sqrt_variance;
                 data_t sv = use_scaleshift ? scaleshift[C + c] : 0;
-                data_t bn_res
-                        = sm * (src[(size_t)n * SP * C + sp * C + c]
-                        - mean_loc[c]) * sqrt_variance + sv;
+                size_t d_off = (size_t)n * SP * C + sp * C + c;
+                data_t bn_res = sm * (src[d_off] - mean_loc[c]) + sv;
                 if (fuse_bn_relu) {
                     if (bn_res <= 0) {
                         bn_res = 0;
                         if (is_training)
-                            ws[(size_t)n * SP * C + sp * C + c] = 0;
+                            ws[d_off] = 0;
                     } else {
                         if (is_training)
-                            ws[(size_t)n * SP * C + sp * C + c] = 1;
+                            ws[d_off] = 1;
                     }
                 }
-                dst[(size_t)n * SP * C + sp * C + c] = maybe_post_op(bn_res);
+                dst[d_off] = maybe_post_op(bn_res);
             }
         }
     }
13 changes: 7 additions & 6 deletions src/cpu/ref_batch_normalization.cpp
@@ -88,8 +88,6 @@ void ref_batch_normalization_fwd_t<data_type>::execute_forward(
         float v_mean = calculate_stats ? 0 : mean[c];
         float v_variance = calculate_stats ? 0 : variance[c];
 
-        float sm = use_scaleshift ? scaleshift[scaleshift_d.off(0, c)] : 1;
-        float sv = use_scaleshift ? scaleshift[scaleshift_d.off(1, c)] : 0;
         if (calculate_stats) {
             for (dim_t n = 0; n < N; ++n)
                 for (dim_t d = 0; d < D; ++d)
@@ -108,15 +106,18 @@ void ref_batch_normalization_fwd_t<data_type>::execute_forward(
             v_variance /= W*H*N*D;
         }
 
-        float sqrt_variance = 1.0f / sqrtf(v_variance + eps);
+        float sqrt_variance = sqrtf(v_variance + eps);
+        float sm = (use_scaleshift
+                ? scaleshift[scaleshift_d.off(0, c)]
+                : 1.0f) / sqrt_variance;
+        float sv = use_scaleshift ? scaleshift[scaleshift_d.off(1, c)] : 0;
 
         for (dim_t n = 0; n < N; ++n)
         for (dim_t d = 0; d < D; ++d)
         for (dim_t h = 0; h < H; ++h)
         for (dim_t w = 0; w < W; ++w) {
-            auto d_off = data_offset(data_d, n, c, d, h, w);
-            float bn_res = sm * ((float)src[d_off] - v_mean) *
-                sqrt_variance + sv;
+            auto d_off = data_offset(data_d,n,c,d,h,w);
+            float bn_res = sm * ((float)src[d_off] - v_mean) + sv;
             if (fuse_bn_relu) {
                 if (bn_res <= 0) {
                     bn_res = 0;
6 changes: 3 additions & 3 deletions tests/benchdnn/bnorm/ref_bnorm.cpp
@@ -43,17 +43,17 @@ void compute_ref_fwd(const prb_t *p, const dnn_mem_t &src, dnn_mem_t &mean,
     mkldnn::impl::parallel_nd(p->ic, [&](int64_t c) {
         float smean = ((float *)mean)[c];
         float svar = ((float *)var)[c];
-        float rcp_denom = (float)(1.0f / (sqrtf(svar + p->eps)));
+        float sqrt_var = sqrtf(svar + p->eps);
 
-        float gamma = p->flags & USE_SCALESHIFT ? ((float *)ss)[c] : 1;
+        float gamma = (p->flags & USE_SCALESHIFT ? ((float *)ss)[c] : 1.0f) / sqrt_var;
         float beta = p->flags & USE_SCALESHIFT ? ((float *)ss)[p->ic + c] : 0;
 
         for (int64_t mb = 0; mb < p->mb; ++mb)
         for (int64_t d = 0; d < p->id; ++d)
         for (int64_t h = 0; h < p->ih; ++h)
         for (int64_t w = 0; w < p->iw; ++w) {
             auto off = data_off(p, mb, c, d, h, w);
-            float res = gamma * (((float *)src)[off] - smean) * rcp_denom + beta;
+            float res = gamma * (((float *)src)[off] - smean) + beta;
             float &D = ((float *)dst)[off];
             if ((p->flags & FUSE_BN_RELU) && res < 0) res = 0;
             maybe_post_ops(res, D);
