Skip to content

Commit

Permalink
cpu: update pool implementations with nthr_ member
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin committed Dec 20, 2021
1 parent 57b1e7a commit 72b54de
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 27 deletions.
6 changes: 4 additions & 2 deletions src/cpu/nchw_pooling.cpp
Expand Up @@ -608,8 +608,10 @@ status_t nchw_pooling_bwd_t<data_type::bf16>::execute_backward(

dim_t c_blk = pd()->channel_block_size_;
dim_t c_blk_tail = C % c_blk;
const int nthr = pd()->nthr_;

if (alg == alg_kind::pooling_max) {
parallel_nd_ext(0, MB, utils::div_up(C, c_blk),
parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk),
[&](int ithr, int, dim_t mb, dim_t cb) {
bool is_last_c_block
= c_blk_tail > 0 && (cb + 1) * c_blk > C;
Expand Down Expand Up @@ -647,7 +649,7 @@ status_t nchw_pooling_bwd_t<data_type::bf16>::execute_backward(
diff_src_fp32, src_sp_size * curr_c_block);
});
} else {
parallel_nd_ext(0, MB, utils::div_up(C, c_blk),
parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk),
[&](int ithr, int, dim_t mb, dim_t cb) {
bool is_last_c_block
= c_blk_tail > 0 && (cb + 1) * c_blk > C;
Expand Down
10 changes: 5 additions & 5 deletions src/cpu/nchw_pooling.hpp
Expand Up @@ -139,27 +139,28 @@ struct nchw_pooling_bwd_t : public primitive_t {
ws_md_ = *hint_fwd_pd_->workspace_md();
}

nthr_ = dnnl_get_max_threads();
calculate_channel_block_size();
init_scratchpad();

return status::success;
}

dim_t channel_block_size_;
int nthr_; // To not exceed the limit in execute used for set up.

private:
void init_scratchpad() {
using namespace memory_tracking::names;
if (diff_dst_md()->data_type == data_type::bf16) {
size_t dst_sz_ = OD() * OH() * OW();
size_t src_sz_ = ID() * IH() * IW();
size_t nthrs = dnnl_get_max_threads();
auto scratchpad = scratchpad_registry().registrar();

scratchpad.template book<float>(key_pool_src_bf16cvt,
src_sz_ * nthrs * channel_block_size_);
src_sz_ * nthr_ * channel_block_size_);
scratchpad.template book<float>(key_pool_dst_bf16cvt,
dst_sz_ * nthrs * channel_block_size_);
dst_sz_ * nthr_ * channel_block_size_);
}
}

Expand All @@ -169,8 +170,7 @@ struct nchw_pooling_bwd_t : public primitive_t {
// spatial
dim_t dst_sz_ = OD() * OH() * OW();
dim_t src_sz_ = ID() * IH() * IW();
dim_t nthrs = dnnl_get_max_threads();
dim_t C_per_thr = nstl::min(MB() * IC() / nthrs, IC());
dim_t C_per_thr = nstl::min(MB() * IC() / nthr_, IC());
const dim_t max_block_size
= platform::get_per_core_cache_size(1) / 2;
dim_t data_size_per_ch = (dst_sz_ + src_sz_) * 6; // f32 + bf16
Expand Down
6 changes: 4 additions & 2 deletions src/cpu/nhwc_pooling.cpp
Expand Up @@ -372,8 +372,9 @@ status_t nhwc_pooling_fwd_t<data_type::bf16>::execute_forward(
return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow;
};
const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty());
const int nthr = pd()->nthr_;

parallel_nd_ext(0, MB, OD, OH, OW,
parallel_nd_ext(nthr, MB, OD, OH, OW,
[&](int ithr, int, dim_t mb, dim_t od, dim_t oh, dim_t ow) {
const size_t dst_offset_init = strided_offset(mb, dst_n_stride,
od, dst_d_stride, oh, dst_h_stride, ow, dst_w_stride);
Expand Down Expand Up @@ -672,8 +673,9 @@ status_t nhwc_pooling_bwd_t<data_type::bf16>::execute_backward(
auto apply_offset = [=](dim_t index, dim_t offset) {
return (index > offset) ? index - offset : 0;
};
const int nthr = pd()->nthr_;

parallel_nd_ext(0, MB, ID, IH, IW,
parallel_nd_ext(nthr, MB, ID, IH, IW,
[&](int ithr, int, dim_t mb, dim_t id, dim_t ih, dim_t iw) {
size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
id, diff_src_d_stride, ih, diff_src_h_stride, iw,
Expand Down
10 changes: 8 additions & 2 deletions src/cpu/nhwc_pooling.hpp
Expand Up @@ -73,16 +73,19 @@ struct nhwc_pooling_fwd_t : public primitive_t {
init_default_ws();
}

nthr_ = dnnl_get_max_threads();
init_scratchpad();

return status::success;
}

int nthr_; // To not exceed the limit in execute used for set up.

private:
void init_scratchpad() {
using namespace memory_tracking::names;
if (src_md()->data_type == data_type::bf16) {
const size_t bf16cvt_sz_ = IC() * dnnl_get_max_threads();
const size_t bf16cvt_sz_ = IC() * nthr_;
auto scratchpad = scratchpad_registry().registrar();
scratchpad.template book<float>(
key_pool_src_bf16cvt, bf16cvt_sz_);
Expand Down Expand Up @@ -148,16 +151,19 @@ struct nhwc_pooling_bwd_t : public primitive_t {
if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
}

nthr_ = dnnl_get_max_threads();
init_scratchpad();

return status::success;
}

int nthr_; // To not exceed the limit in execute used for set up.

private:
void init_scratchpad() {
using namespace memory_tracking::names;
if (diff_src_md()->data_type == data_type::bf16) {
size_t bf16cvt_sz_ = IC() * dnnl_get_max_threads();
size_t bf16cvt_sz_ = IC() * nthr_;
auto scratchpad = scratchpad_registry().registrar();
scratchpad.template book<float>(
key_pool_src_bf16cvt, bf16cvt_sz_);
Expand Down
1 change: 1 addition & 0 deletions src/cpu/x64/jit_primitive_conf.hpp
Expand Up @@ -687,6 +687,7 @@ struct jit_pool_conf_t {
bool with_postops;
bool with_eltwise;
bool with_binary;
int nthr;
};

struct jit_pool_call_s {
Expand Down
6 changes: 3 additions & 3 deletions src/cpu/x64/jit_uni_pool_kernel.cpp
Expand Up @@ -78,8 +78,7 @@ jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(

template <cpu_isa_t isa>
status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd,
int nthreads) {
memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd) {

const auto &pd = *ppd->desc();
const memory_desc_wrapper src_d(
Expand All @@ -89,6 +88,7 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,

const int ndims = src_d.ndims();

jpp.nthr = dnnl_get_max_threads();
jpp.is_training = pd.prop_kind == prop_kind::forward_training;
jpp.is_backward = pd.prop_kind == prop_kind::backward_data;

Expand Down Expand Up @@ -250,7 +250,7 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
? (ndims == 5 && jpp.simple_alg ? jpp.od : 1)
: (ndims == 5 ? jpp.od : jpp.oh);
work *= jpp.mb * nb2_c;
auto eff = (float)work / utils::rnd_up(work, nthreads);
auto eff = (float)work / utils::rnd_up(work, jpp.nthr);
if (eff > best_eff) {

best_eff = eff;
Expand Down
3 changes: 1 addition & 2 deletions src/cpu/x64/jit_uni_pool_kernel.hpp
Expand Up @@ -46,8 +46,7 @@ struct jit_uni_pool_kernel : public jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel)

static status_t init_conf(jit_pool_conf_t &jbp,
memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd,
int nthreads);
memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd);

private:
using Xmm = Xbyak::Xmm;
Expand Down
22 changes: 15 additions & 7 deletions src/cpu/x64/jit_uni_pooling.cpp
Expand Up @@ -616,6 +616,8 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src,
(*kernel_)(&arg);
};

const int nthr = jpp.nthr;

if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
parallel_nd(jpp.mb, jpp.oh, nb2_c, [&](dim_t n, dim_t oh, dim_t b2_c) {
Expand All @@ -626,7 +628,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src,
} else {
if (trans_src || trans_dst) {
// ncsp format
parallel_nd_ext(0, jpp.mb, jpp.nb_c,
parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
[&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) {
if (trans_src)
transpose_facade.execute_transpose_input(
Expand All @@ -639,7 +641,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src,
});
} else {
// nChw16c, nChw8c format
parallel(0, [&](dim_t ithr, dim_t nthr) {
parallel(nthr, [&](dim_t ithr, dim_t nthr) {
dim_t work_amount = jpp.mb * jpp.nb_c * jpp.oh;
if (ithr >= work_amount) return;

Expand Down Expand Up @@ -742,6 +744,8 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward_3d(const data_t *src,
(*kernel_)(&arg);
};

const int nthr = jpp.nthr;

if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
parallel_nd(jpp.mb, jpp.od, nb2_c, [&](dim_t n, dim_t od, dim_t b2_c) {
Expand All @@ -761,7 +765,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward_3d(const data_t *src,
});
} else {
if (trans_src || trans_dst) {
parallel_nd_ext(0, jpp.mb, jpp.nb_c,
parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
[&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) {
if (trans_src)
transpose_facade.execute_transpose_input(
Expand Down Expand Up @@ -954,7 +958,9 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward(
transpose_facade.execute_transpose_output(ithr, n, b_c);
};

parallel(0, [&](int ithr, int nthr) {
const int nthr = jpp.nthr;

parallel(nthr, [&](int ithr, int nthr) {
const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
const std::size_t work_amount
= static_cast<std::size_t>(jpp.mb) * nb2_c;
Expand Down Expand Up @@ -1104,6 +1110,8 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
}
};

const int nthr = jpp.nthr;

if (jpp.simple_alg) {
if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
const dim_t nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
Expand All @@ -1117,7 +1125,7 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
} else {
assert(jpp.ur_bc == 1);
if (trans_src || trans_dst) {
parallel_nd_ext(0, jpp.mb, jpp.nb_c,
parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
[&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) {
if (trans_src)
transpose_facade.execute_transpose_input(
Expand Down Expand Up @@ -1150,7 +1158,7 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
if (!trans_src) {
const size_t chunk_size
= (size_t)jpp.id * jpp.ih * jpp.iw * jpp.c_block;
parallel_nd_ext(0, jpp.mb, jpp.nb_c,
parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
[&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) {
const size_t offset
= ((size_t)n * jpp.nb_c + b_c) * chunk_size;
Expand All @@ -1163,7 +1171,7 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(

const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
if (trans_src || trans_dst) {
parallel_nd_ext(0, jpp.mb, nb2_c,
parallel_nd_ext(nthr, jpp.mb, nb2_c,
[&](dim_t ithr, dim_t nthr, dim_t n, dim_t b2_c) {
const dim_t b_c = b2_c * jpp.ur_bc;

Expand Down
9 changes: 5 additions & 4 deletions src/cpu/x64/jit_uni_pooling.hpp
Expand Up @@ -66,8 +66,7 @@ struct jit_uni_pooling_fwd_t : public primitive_t {
init_default_ws();

auto scratchpad = scratchpad_registry().registrar();
CHECK(jit_uni_pool_kernel<isa>::init_conf(
jpp_, scratchpad, this, dnnl_get_max_threads()));
CHECK(jit_uni_pool_kernel<isa>::init_conf(jpp_, scratchpad, this));

return status::success;
}
Expand Down Expand Up @@ -132,9 +131,11 @@ struct jit_uni_pooling_bwd_t : public primitive_t {
init_default_ws();
if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
}

auto scratchpad = scratchpad_registry().registrar();
return jit_uni_pool_kernel<isa>::init_conf(
jpp_, scratchpad, this, dnnl_get_max_threads());
CHECK(jit_uni_pool_kernel<isa>::init_conf(jpp_, scratchpad, this));

return status::success;
}

jit_pool_conf_t jpp_;
Expand Down

0 comments on commit 72b54de

Please sign in to comment.