Skip to content

Commit

Permalink
cpu: x64: brgemm ip fwd: adjust parallel ic reduction heuristic for f32
Browse files Browse the repository at this point in the history
  • Loading branch information
bhavani-subramanian authored and tprimak committed Aug 6, 2021
1 parent f5f25f4 commit 3e379b8
Showing 1 changed file with 25 additions and 21 deletions.
46 changes: 25 additions & 21 deletions src/cpu/x64/jit_brgemm_inner_product_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,6 @@ status_t init_ip_conf_fwd(jit_brgemm_primitive_conf_t &jbgp,
const bool is_int8 = one_of(jbgp.src_dt, u8, s8) && jbgp.wei_dt == s8;
const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt);

// NOTE: comment about is_gigantic_shape is in get_os_block()
// The idea here is to use just mb and oc parallelism (i.e. w/o ic reduction)
// for "gigantic shapes" as there is a lot of parallelism without enabling
// ic parallelism.
const bool is_gigantic_shape
= jbgp.ic >= 9216 && jbgp.oc >= 4096 && jbgp.os >= 512;
const bool use_ic_reduction
= is_f32 && jbgp.ic > 1024 && !is_gigantic_shape;

const auto &p = attr.post_ops_;
jbgp.with_sum = p.find(primitive_kind::sum) != -1;
const int eltwise_ind = p.find(primitive_kind::eltwise);
Expand Down Expand Up @@ -314,18 +305,35 @@ status_t init_ip_conf_fwd(jit_brgemm_primitive_conf_t &jbgp,
jbgp.nb_os_blocking = saturate(1, nstl::min(8, jbgp.nb_os),
nstl::min(nstl::max(jbgp.oc / jbgp.os / 2, 1),
div_up(jbgp.nb_os * jbgp.nb_oc, 2 * jbgp.nthr)));

// For os > 256, compute all os blocks as a single chunk when performing
// IC reduction. Note that this condition is empirical
if (use_ic_reduction && jbgp.nb_os_blocking > 1)
jbgp.nb_os_blocking = jbgp.nb_os;
}

// NOTE: comment about is_gigantic_shape is in get_os_block()
const bool is_gigantic_shape = jbgp.oc >= 4096 && jbgp.os >= 512;
const int num_work_to_parallel = div_up(jbgp.nb_oc, jbgp.nb_oc_blocking)
* div_up(jbgp.nb_os, jbgp.nb_os_blocking);

// TODO: although the below heuristic produces good performance for fp32,
// num_work_to_parallel needs to compared with nthr (instead of nb_ic)
// and os_block needs some further tuning.

// Use parallel IC reduction for f32 if we have:
// * very large input channels
// * work amount in mb and oc dimensions is small compared to nb_ic
// * number of threads > 1
// * not a "gigantic shape" since it already has a lot of parallelism
// in mb and oc dimensions w/o enabling IC parallelism
const bool use_parallel_ic_reduction = is_f32 && jbgp.ic > 1024
&& num_work_to_parallel < jbgp.nb_ic && jbgp.nthr > 1
&& !is_gigantic_shape;

// For os > 256, compute all os blocks as a single chunk when performing
// IC reduction. Note that this condition is empirical
if (use_parallel_ic_reduction && jbgp.os > 256 && jbgp.nb_os_blocking > 1)
jbgp.nb_os_blocking = jbgp.nb_os;

jbgp.nb_ic_blocking = 1;
jbgp.nthr_ic_b = 1;
const int max_nb_ic_blocking = nstl::min(64, jbgp.nb_ic);
const int total_work = div_up(jbgp.nb_oc, jbgp.nb_oc_blocking)
* div_up(jbgp.nb_os, jbgp.nb_os_blocking);
if (IMPLICATION(!is_int8, jbgp.ic <= max_nb_ic_blocking * jbgp.ic_block)
&& everyone_is(1, jbgp.kw, jbgp.kh, jbgp.kd)
&& !jbgp.use_buffer_a) {
Expand All @@ -337,11 +345,7 @@ status_t init_ip_conf_fwd(jit_brgemm_primitive_conf_t &jbgp,
: rnd_dn(jbgp.ic, jbgp.ic_block);
jbgp.nb_ic_blocking = jbgp.nb_ic;
jbgp.gemm_batch_size = 1;
} else if (!jbgp.use_buffer_a
&& (use_ic_reduction || (is_f32 && total_work < 2 * jbgp.nthr))) {
// Use IC reduction if we have
// * very large input channels (OR)
// * work amount per thread is small
} else if (!jbgp.use_buffer_a && use_parallel_ic_reduction) {
const int min_chunk_sz = 16;
const int num_min_chunk_sz = div_up(jbgp.nb_ic, min_chunk_sz);
float reduce_work = 0.5f * num_min_chunk_sz * jbgp.nb_os
Expand Down

0 comments on commit 3e379b8

Please sign in to comment.