Skip to content

Commit

Permalink
cpu, gpu: fix layer normalization bwd pass
Browse files (browse the repository at this point in the history)
  • Loading branch information
irinasok authored and Fomenko, Evarist M committed Jan 2, 2020
1 parent cb2cc7a commit c176ceb
Show file tree
Hide file tree
Showing 11 changed files with 345 additions and 197 deletions.
11 changes: 2 additions & 9 deletions src/cpu/jit_uni_layer_normalization.cpp
Expand Up @@ -139,17 +139,10 @@ void jit_uni_layer_normalization_bwd_t::execute_backward(
dim_t N_s = 0, N_e = 0;
balance211(N, nthr, ithr, N_s, N_e);

float *my_diff_gamma = reduce + C * ithr;
float *my_diff_beta = reduce + C * nthr + C * ithr;
for (dim_t c = 0; c < C; c++) {
my_diff_gamma[c] = diff_scaleshift[c];
my_diff_beta[c] = diff_scaleshift[C + c];
}

for (dim_t n = N_s; n < N_e; n++) {
(*diff_data_kernel_)(&src[n * C_padded], &diff_dst[n * C_padded],
&diff_src[n * C_padded], my_diff_gamma, my_diff_beta,
scaleshift, &mean[n], &variance[n]);
&diff_src[n * C_padded], scaleshift, &mean[n],
&variance[n]);
}
});
}
Expand Down
100 changes: 71 additions & 29 deletions src/cpu/jit_uni_layer_normalization_kernels.hpp
Expand Up @@ -449,31 +449,37 @@ class diff_data_kernel_t : jit_generator {
}
~diff_data_kernel_t() {}
void operator()(const float *src, const float *diff_dst, float *diff_src,
float *diff_gamma, const float *diff_beta, const float *ss,
const float *mean, const float *var) {
const float *ss, const float *mean, const float *var) {
if (ker_) {
ker_args args;
args.src = src;
args.diff_dst = diff_dst;
args.diff_src = diff_src;
args.diff_gamma = diff_gamma;
args.diff_beta = diff_beta;
args.ss = ss;
args.mean = mean;
float inv_sqrtvar = 1. / sqrtf(*var + eps_);
float inv_sqrtvar = 1.f / sqrtf(*var + eps_);
args.inv_sqrtvar = &inv_sqrtvar;
ker_(&args);
} else {
float inv_sqrtvar = 1. / sqrtf(*var + eps_);
float inv_sqrtvar = 1.f / sqrtf(*var + eps_);
float dd_gamma = 0, dd_gamma_x = 0;
if (calculate_diff_stats_) {
PRAGMA_OMP_SIMD(reduction(+ : dd_gamma, dd_gamma_x))
for (dim_t c = 0; c < C_; c++) {
float gamma = use_scaleshift_ ? ss[c] : 1;
dd_gamma += diff_dst[c] * gamma;
dd_gamma_x += diff_dst[c] * gamma * (src[c] - *mean);
}
dd_gamma_x *= inv_sqrtvar;
}
PRAGMA_OMP_SIMD()
for (dim_t c = 0; c < C_; c++) {
float gamma = use_scaleshift_ ? ss[c] : 1;
float v_diff_src = diff_dst[c];
float v_diff_src = diff_dst[c] * gamma;
if (calculate_diff_stats_)
v_diff_src -= diff_beta[c] / C_
+ (src[c] - *mean) * diff_gamma[c] * inv_sqrtvar
/ C_;
v_diff_src *= gamma * inv_sqrtvar;
v_diff_src -= dd_gamma / C_
+ (src[c] - *mean) * dd_gamma_x * inv_sqrtvar / C_;
v_diff_src *= inv_sqrtvar;
diff_src[c] = v_diff_src;
}
}
Expand All @@ -490,8 +496,6 @@ class diff_data_kernel_t : jit_generator {
const float *src;
const float *diff_dst;
float *diff_src;
const float *diff_gamma;
const float *diff_beta;
const float *ss;
const float *mean;
const float *inv_sqrtvar;
Expand Down Expand Up @@ -526,8 +530,6 @@ class diff_data_kernel_t : jit_generator {
mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
mov(reg_diff_dst, ptr[reg_param + PARAM_OFF(diff_dst)]);
mov(reg_diff_src, ptr[reg_param + PARAM_OFF(diff_src)]);
mov(reg_diff_gamma, ptr[reg_param + PARAM_OFF(diff_gamma)]);
mov(reg_diff_beta, ptr[reg_param + PARAM_OFF(diff_beta)]);
mov(reg_gamma, ptr[reg_param + PARAM_OFF(ss)]);

if (calculate_diff_stats_) {
Expand All @@ -546,31 +548,71 @@ class diff_data_kernel_t : jit_generator {
uni_vbroadcastss(ymm_C, xmm_tmp);

const int C_vecs = C_ / simd_w_;
auto op = [=](int nelems, size_t offt) {

auto compute_dd_gammas = [=](int nelems, size_t offt) {
Ymm ymm_ddst = ymm_dsrc;
load(ymm_ddst, reg_diff_dst, nelems, offt);
if (use_scaleshift_) {
load(ymm_gamma, reg_gamma, nelems, offt);
vmulps(ymm_ddst, ymm_ddst, ymm_gamma);
}
load(ymm_src, reg_src, nelems, offt);
vaddps(ymm_dd_gamma, ymm_dd_gamma, ymm_ddst);
vsubps(ymm_src, ymm_src, ymm_mean);
vfmadd231ps(ymm_dd_gamma_x, ymm_ddst, ymm_src);
};

auto reduce = [=](Ymm ymm_vec) {
vextractf128(xmm_tmp, ymm_vec, 1);
Xmm xmm_vec = Xmm(ymm_vec.getIdx());
vaddps(xmm_vec, xmm_tmp, xmm_vec);
vhaddps(xmm_vec, xmm_vec, xmm_vec);
vhaddps(xmm_vec, xmm_vec, xmm_vec);
};

auto compute_diff_src = [=](int nelems, size_t offt) {
load(ymm_dsrc, reg_diff_dst, nelems, offt);
if (use_scaleshift_) load(ymm_gamma, reg_gamma, nelems, offt);
if (calculate_diff_stats_) {
load(ymm_dbeta, reg_diff_beta, nelems, offt);
load(ymm_dgamma, reg_diff_gamma, nelems, offt);
load(ymm_src, reg_src, nelems, offt);
if (use_scaleshift_) {
load(ymm_gamma, reg_gamma, nelems, offt);
vmulps(ymm_dsrc, ymm_dsrc, ymm_gamma);
}
if (calculate_diff_stats_) {
load(ymm_src, reg_src, nelems, offt);
vsubps(ymm_src, ymm_src, ymm_mean);
vmulps(ymm_src, ymm_src, ymm_inv_sqrtvar);
vfmadd213ps(ymm_src, ymm_dgamma, ymm_dbeta);
vfmadd213ps(ymm_src, ymm_dd_gamma_x, ymm_dd_gamma);
vdivps(ymm_src, ymm_src, ymm_C);
vsubps(ymm_dsrc, ymm_dsrc, ymm_src);
}
if (use_scaleshift_) vmulps(ymm_dsrc, ymm_dsrc, ymm_gamma);
vmulps(ymm_dsrc, ymm_dsrc, ymm_inv_sqrtvar);
store(ymm_dsrc, reg_diff_src, nelems, offt);
};

if (calculate_diff_stats_) {
vpxor(ymm_dd_gamma, ymm_dd_gamma, ymm_dd_gamma);
vpxor(ymm_dd_gamma_x, ymm_dd_gamma_x, ymm_dd_gamma_x);

for (int i = 0; i < C_vecs; i++)
compute_dd_gammas(simd_w_, i * simd_w_ * sizeof(float));

reduce(ymm_dd_gamma);
reduce(ymm_dd_gamma_x);

for (int i = utils::rnd_dn(C_, simd_w_); i < C_; i++)
compute_dd_gammas(1, i * sizeof(float));

vmulps(ymm_dd_gamma_x, ymm_dd_gamma_x, ymm_inv_sqrtvar);
Xmm xmm_dd_gamma = Xmm(ymm_dd_gamma.getIdx());
vbroadcastss(ymm_dd_gamma, xmm_dd_gamma);
Xmm xmm_dd_gamma_x = Xmm(ymm_dd_gamma_x.getIdx());
vbroadcastss(ymm_dd_gamma_x, xmm_dd_gamma_x);
}

for (int i = 0; i < C_vecs; i++)
op(simd_w_, i * simd_w_ * sizeof(float));
compute_diff_src(simd_w_, i * simd_w_ * sizeof(float));

for (int i = utils::rnd_dn(C_, simd_w_); i < C_; i++)
op(1, i * sizeof(float));
compute_diff_src(1, i * sizeof(float));

postamble();

Expand All @@ -583,17 +625,17 @@ class diff_data_kernel_t : jit_generator {
Xbyak::Reg64 reg_diff_dst = rbx;
Xbyak::Reg64 reg_gamma = r11;
Xbyak::Reg64 reg_tmp = r10;
Xbyak::Reg64 reg_diff_gamma = r9;
Xbyak::Reg64 reg_diff_beta = r8;
Xbyak::Reg64 reg_dd_gamma = r9;
Xbyak::Reg64 reg_dd_gamma_x = r8;

Xbyak::Xmm xmm_tmp = Xbyak::Xmm(7);

Xbyak::Ymm ymm_C = Xbyak::Ymm(8);
Xbyak::Ymm ymm_gamma = Xbyak::Ymm(9);
Xbyak::Ymm ymm_inv_sqrtvar = Xbyak::Ymm(10);
Xbyak::Ymm ymm_dsrc = Xbyak::Ymm(11);
Xbyak::Ymm ymm_dgamma = Xbyak::Ymm(12);
Xbyak::Ymm ymm_dbeta = Xbyak::Ymm(13);
Xbyak::Ymm ymm_dd_gamma_x = Xbyak::Ymm(12);
Xbyak::Ymm ymm_dd_gamma = Xbyak::Ymm(13);
Xbyak::Ymm ymm_src = Xbyak::Ymm(14);
Xbyak::Ymm ymm_mean = Xbyak::Ymm(15);
};
Expand Down
71 changes: 45 additions & 26 deletions src/cpu/ref_layer_normalization.cpp
Expand Up @@ -160,41 +160,60 @@ void ref_layer_normalization_bwd_t<d_type>::execute_backward(
const bool use_scaleshift = pd()->use_scaleshift();
const bool calculate_diff_stats = !pd()->use_global_stats();

parallel_nd(C, [&](dim_t c) {
float gamma = use_scaleshift ? scaleshift[scaleshift_d.off(0, c)] : 1;
float diff_gamma = float(0);
float diff_beta = float(0);

for (dim_t n = 0; n < N; ++n) {
const size_t src_off = src_d.off_l(n * C + c),
diff_dst_off = diff_dst_d.off_l(n * C + c),
s_off = stat_d.off_l(n);
float inv_sqrt_variance
= static_cast<float>(1.0f / sqrtf(variance[s_off] + eps));
data_t dd = maybe_up_convert(diff_dst[diff_dst_off]);
diff_gamma += (maybe_up_convert(src[src_off]) - mean[s_off]) * dd
* inv_sqrt_variance;
diff_beta += dd;
}
if (diff_scaleshift) {
parallel_nd(C, [&](dim_t c) {
float diff_gamma = float(0);
float diff_beta = float(0);

for (dim_t n = 0; n < N; ++n) {
const size_t src_off = src_d.off_l(n * C + c),
diff_dst_off = diff_dst_d.off_l(n * C + c),
s_off = stat_d.off_l(n);
float inv_sqrt_variance = static_cast<float>(
1.0f / sqrtf(variance[s_off] + eps));
data_t dd = maybe_up_convert(diff_dst[diff_dst_off]);
diff_gamma += (maybe_up_convert(src[src_off]) - mean[s_off])
* dd * inv_sqrt_variance;
diff_beta += dd;
}

if (diff_scaleshift) {
diff_scaleshift[diff_scaleshift_d.off(0, c)] = diff_gamma;
diff_scaleshift[diff_scaleshift_d.off(1, c)] = diff_beta;
});
}

parallel_nd(N, [&](dim_t n) {
const size_t s_off = stat_d.off_l(n);
float inv_sqrt_variance
= static_cast<float>(1.0f / sqrtf(variance[s_off] + eps));
float dd_gamma = float(0), dd_gamma_x = float(0);
if (calculate_diff_stats) {
for (dim_t c = 0; c < C; ++c) {
float gamma = use_scaleshift
? scaleshift[scaleshift_d.off(0, c)]
: 1;
const size_t src_off = src_d.off_l(n * C + c),
diff_dst_off = diff_dst_d.off_l(n * C + c);
data_t dd = maybe_up_convert(diff_dst[diff_dst_off]);
dd_gamma += dd * gamma;
dd_gamma_x += dd * gamma
* (maybe_up_convert(src[src_off]) - mean[s_off]);
}
dd_gamma_x *= inv_sqrt_variance;
}

for (dim_t n = 0; n < N; ++n) {
for (dim_t c = 0; c < C; ++c) {
float gamma
= use_scaleshift ? scaleshift[scaleshift_d.off(0, c)] : 1;
const size_t src_off = src_d.off_l(n * C + c),
diff_src_off = diff_src_d.off_l(n * C + c),
diff_dst_off = diff_dst_d.off_l(n * C + c),
s_off = stat_d.off_l(n);
float inv_sqrt_variance
= static_cast<float>(1.0f / sqrtf(variance[s_off] + eps));
float v_diff_src = maybe_up_convert(diff_dst[diff_dst_off]);
diff_dst_off = diff_dst_d.off_l(n * C + c);
float v_diff_src = maybe_up_convert(diff_dst[diff_dst_off]) * gamma;
if (calculate_diff_stats)
v_diff_src -= diff_beta / C
v_diff_src -= dd_gamma / C
+ (maybe_up_convert(src[src_off]) - mean[s_off])
* diff_gamma * inv_sqrt_variance / C;
v_diff_src *= gamma * inv_sqrt_variance;
* dd_gamma_x * inv_sqrt_variance / C;
v_diff_src *= inv_sqrt_variance;
diff_src[diff_src_off] = v_diff_src;
}
});
Expand Down
1 change: 1 addition & 0 deletions src/ocl/jit_primitive_conf.hpp
Expand Up @@ -311,6 +311,7 @@ struct jit_bnorm_conf_t {
struct jit_lnorm_conf_t {
data_type_t data_type;

bool is_fwd;
int ndims;
int norm_axis;

Expand Down
20 changes: 9 additions & 11 deletions src/ocl/jit_ref_layer_normalization_kernel.hpp
Expand Up @@ -44,17 +44,13 @@ struct jit_ref_layer_normalization_kernel_t {
jln.dst_md_info = jit_memory_desc_info_t::create(dst_mdw);
jln.stat_md_info = jit_memory_desc_info_t::create(stat_mdw);

if (pd->is_fwd()) {
auto &dims = src_mdw.dims();
jln.gws_d[0] = dims[0];
jln.gws_d[1] = ndims > 2 ? dims[1] : 1;
jln.gws_d[2]
= ndims > 3 ? utils::array_product(&dims[2], ndims - 3) : 1;
} else {
jln.gws_d[0] = pd->norm_axis();
jln.gws_d[1] = 1;
jln.gws_d[2] = 1;
}
jln.is_fwd = pd->is_fwd();

auto &dims = src_mdw.dims();
jln.gws_d[0] = dims[0];
jln.gws_d[1] = ndims > 2 ? dims[1] : 1;
jln.gws_d[2]
= ndims > 3 ? utils::array_product(&dims[2], ndims - 3) : 1;

jln.use_scaleshift = pd->use_scaleshift();
jln.calculate_stats = !pd->stats_are_src();
Expand All @@ -73,6 +69,8 @@ struct jit_ref_layer_normalization_kernel_t {
kernel_ctx.define_int("USE_SCALESHIFT", jln.use_scaleshift);
kernel_ctx.define_int("CALCULATE_STATS", jln.calculate_stats);
kernel_ctx.define_int("SAVE_STATS", jln.save_stats);
kernel_ctx.define_int("IS_FWD", jln.is_fwd);
kernel_ctx.define_int("IS_BWD", !jln.is_fwd);

def_memory_desc_info(kernel_ctx, jln.src_md_info, "SRC");
def_memory_desc_info(kernel_ctx, jln.dst_md_info, "DST");
Expand Down

0 comments on commit c176ceb

Please sign in to comment.