Skip to content

Commit

Permalink
x64: brgemm bwd_d strided: remove static last_oc_block_size
Browse files Browse the repository at this point in the history
  • Loading branch information
ankalinin committed Apr 9, 2024
1 parent e4737d9 commit 0184044
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,6 @@ struct brg_blocking_t : public jit_brgemm_conv_conf_t {
static constexpr int bench_iterations = 1;

int sp, sp_block, nb_sp;
static int last_oc_block_size;

void get_from_jcp(const jit_brgemm_conv_conf_t &jcp) { *this = jcp; }
void save_to_jcp(jit_brgemm_conv_conf_t &jcp) const { jcp = *this; }
Expand Down Expand Up @@ -531,7 +530,6 @@ struct brg_blocking_t : public jit_brgemm_conv_conf_t {

unsigned brg_blocking_t::L1;
unsigned brg_blocking_t::L2;
int brg_blocking_t::last_oc_block_size;

float brg_blocking_t::io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk,
bool is_broadcast, bool is_shared) const {
Expand All @@ -556,7 +554,7 @@ float brg_blocking_t::io_k(const loop_t loop, const array_in_loop_t arr,
}

// Selects the output-channel block size (oc_block) and the number of
// output-channel blocks (nb_oc) for this blocking configuration.
//
// The padding granularity is taken from the per-configuration vnni_block
// field rather than from a class-wide static, so each configuration carries
// its own value.
void brg_blocking_t::select_oc_block() {
    // Granularity oc is rounded up to: vnni_block, additionally scaled by
    // acc_simd_w when is_oc_padded is set.
    const auto padded_oc = vnni_block * (is_oc_padded ? acc_simd_w : 1);
    // Only the exec_trans path rounds oc up to the padded granularity;
    // every other exec_type uses oc itself as the block size.
    oc_block = (exec_type == exec_trans ? rnd_up(oc, padded_oc) : oc);
    nb_oc = utils::div_up(oc, oc_block);
}
Expand All @@ -570,7 +568,7 @@ status_t brg_blocking_t::estimate_brgemm_ur() {

// Configure matrix sizes
// for amx if oc_block != oc then we use exec_trans so K is oc_block
const auto padded_oc = last_oc_block_size * (is_oc_padded ? acc_simd_w : 1);
const auto padded_oc = vnni_block * (is_oc_padded ? acc_simd_w : 1);

ocp = rnd_up(oc, padded_oc);

Expand Down Expand Up @@ -641,7 +639,7 @@ status_t brg_blocking_t::get_brgemm_ur(
brg_strides.stride_a = ngroups * oc_without_padding
* (dilate_w + 1) * src_dsz;
// weights are padded by ic_block and vnni_block
brg_strides.stride_b = rnd_up(oc, last_oc_block_size)
brg_strides.stride_b = rnd_up(oc, vnni_block)
* rnd_up(ic, ic_block) * wei_dsz;
const auto strides_ptr = (brg_type == brgemm_strd)
? &brg_strides
Expand Down Expand Up @@ -1501,8 +1499,7 @@ status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,

VDISPATCH_CONV_IC(!jcp.is_bf32, VERBOSE_UNSUPPORTED_DT);

brg_blocking_t::last_oc_block_size
= (jcp.wei_dt == f16 && isa == avx512_core_fp16)
jcp.vnni_block = (jcp.wei_dt == f16 && isa == avx512_core_fp16)
? 1
: data_type_vnni_granularity(jcp.wei_dt);

Expand Down Expand Up @@ -1921,8 +1918,7 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,

jcp.copy_block_only = true;

const auto oc_padded_block
= jcp.acc_simd_w * brg_blocking_t::last_oc_block_size;
const auto oc_padded_block = jcp.acc_simd_w * jcp.vnni_block;
jcp.is_oc_padded = one_of(jcp.wei_dt, bf16, f16, s8)
&& jcp.oc > oc_padded_block && is_amx(isa);

Expand Down

0 comments on commit 0184044

Please sign in to comment.