// Patch summary (preamble text; everything before the "diff --git" header is
// not part of any hunk).
//
// Target file: src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp
//
// What the hunks below do:
//   * Remove the static data member brg_blocking_t::last_oc_block_size, both
//     its in-class declaration and its out-of-class definition.
//   * Replace every read of last_oc_block_size with vnni_block:
//       - brg_blocking_t::select_oc_block()    (padded_oc computation)
//       - brg_blocking_t::estimate_brgemm_ur() (padded_oc computation)
//       - brg_blocking_t::get_brgemm_ur()      (brg_strides.stride_b)
//       - init_conf()                          (oc_padded_block)
//   * In init_jcp(), store the value in jcp.vnni_block instead of the static:
//     1 when jcp.wei_dt == f16 and isa == avx512_core_fp16, otherwise
//     data_type_vnni_granularity(jcp.wei_dt). The expression itself is
//     unchanged; only the destination differs.
//
// NOTE(review): brg_blocking_t derives from jit_brgemm_conv_conf_t and copies
// the whole jcp in get_from_jcp(), so vnni_block is presumably a field of
// jit_brgemm_conv_conf_t. Its declaration is not part of this patch - confirm
// it exists and is not overwritten elsewhere between init_jcp() and its uses.
//
// NOTE(review): the context comment in the get_brgemm_ur() hunk still reads
// "weights are padded by ic_block and last_oc_block". After this change the
// padding granularity comes from vnni_block, so that comment looks stale;
// consider updating it in a follow-up (it is a context line here and cannot
// be edited without changing the hunk).
//
// NOTE(review): as captured here the hunks are collapsed onto two physical
// lines; original line breaks must be restored before this can be applied.
diff --git a/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp b/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp index a924980a2e7..0e71e9a347c 100644 --- a/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp +++ b/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp @@ -455,7 +455,6 @@ struct brg_blocking_t : public jit_brgemm_conv_conf_t { static constexpr int bench_iterations = 1; int sp, sp_block, nb_sp; - static int last_oc_block_size; void get_from_jcp(const jit_brgemm_conv_conf_t &jcp) { *this = jcp; } void save_to_jcp(jit_brgemm_conv_conf_t &jcp) const { jcp = *this; } @@ -531,7 +530,6 @@ struct brg_blocking_t : public jit_brgemm_conv_conf_t { unsigned brg_blocking_t::L1; unsigned brg_blocking_t::L2; -int brg_blocking_t::last_oc_block_size; float brg_blocking_t::io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk, bool is_broadcast, bool is_shared) const { @@ -556,7 +554,7 @@ float brg_blocking_t::io_k(const loop_t loop, const array_in_loop_t arr, } void brg_blocking_t::select_oc_block() { - const auto padded_oc = last_oc_block_size * (is_oc_padded ? acc_simd_w : 1); + const auto padded_oc = vnni_block * (is_oc_padded ? acc_simd_w : 1); oc_block = (exec_type == exec_trans ? rnd_up(oc, padded_oc) : oc); nb_oc = utils::div_up(oc, oc_block); } @@ -570,7 +568,7 @@ status_t brg_blocking_t::estimate_brgemm_ur() { // Configure matrix sizes // for amx if oc_block != oc then we use exec_trans so K is oc_block - const auto padded_oc = last_oc_block_size * (is_oc_padded ? acc_simd_w : 1); + const auto padded_oc = vnni_block * (is_oc_padded ? acc_simd_w : 1); ocp = rnd_up(oc, padded_oc); @@ -641,7 +639,7 @@ status_t brg_blocking_t::get_brgemm_ur( brg_strides.stride_a = ngroups * oc_without_padding * (dilate_w + 1) * src_dsz; //weights are padded by ic_block and last_oc_block - brg_strides.stride_b = rnd_up(oc, last_oc_block_size) + brg_strides.stride_b = rnd_up(oc, vnni_block) * rnd_up(ic, ic_block) * wei_dsz; const auto strides_ptr = (brg_type == brgemm_strd) ? 
&brg_strides @@ -1501,8 +1499,7 @@ status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, VDISPATCH_CONV_IC(!jcp.is_bf32, VERBOSE_UNSUPPORTED_DT); - brg_blocking_t::last_oc_block_size - = (jcp.wei_dt == f16 && isa == avx512_core_fp16) + jcp.vnni_block = (jcp.wei_dt == f16 && isa == avx512_core_fp16) ? 1 : data_type_vnni_granularity(jcp.wei_dt); @@ -1921,8 +1918,7 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, jcp.copy_block_only = true; - const auto oc_padded_block - = jcp.acc_simd_w * brg_blocking_t::last_oc_block_size; + const auto oc_padded_block = jcp.acc_simd_w * jcp.vnni_block; jcp.is_oc_padded = one_of(jcp.wei_dt, bf16, f16, s8) && jcp.oc > oc_padded_block && is_amx(isa);