Skip to content

Commit

Permalink
x64: brgemm unrolled kernel: update output prefetching
Browse files Browse the repository at this point in the history
  • Loading branch information
ankalinin committed Dec 27, 2023
1 parent b5f916e commit fa43640
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 8 deletions.
3 changes: 3 additions & 0 deletions src/cpu/x64/brgemm/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,9 @@ status_t brgemm_desc_set_attr(brgemm_t *brg, const brgemm_attr_t &brgattr) {
if (brgattr.hint_innermost_loop != brgemm_innermost_undef)
brg->innermost_loop = brgattr.hint_innermost_loop;

if (brgattr.hint_prefetching == brgemm_kernel_prefetching_t::brgemm_prf0
&& brg->prfC.dist0 < 0)
brg->prfC.dist0 = 0;
if (brgattr.hint_prefetching == brgemm_kernel_prefetching_t::brgemm_prf1
&& brg->prfC.dist1 < 0)
brg->prfC.dist1 = 0;
Expand Down
1 change: 0 additions & 1 deletion src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,6 @@ void jit_brgemm_amx_uker_base_t::uni_prefetch(
if (for_write) {
switch (pft) {
case brgemm_prf0: prefetchw(addr); break;
case brgemm_prf1: prefetchwt1(addr); break;
default: break;
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1892,7 +1892,7 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
if (is_amx(isa) && (/* heuristic */ jcp.kw_sets == 1 && jcp.iw < 256)) {
jcp.use_M_mask = 0;

jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1;
jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0;

// assuming 2x2 decomposition in amx brgemm kernel
// and overlap of input by kw
Expand Down
8 changes: 4 additions & 4 deletions src/cpu/x64/jit_brgemm_conv_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2032,7 +2032,7 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, bool use_inversion,
jcp.use_M_mask = jcp.is_os_blocking ? 2 : 0;
jcp.use_uker = true;
jcp.use_interleave_stores = true;
jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1;
jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0;
// assuming 2x2 decomposition in amx brgemm kernel
// and overlap of input by kw
const auto bd_blocking = 2 * jcp.amx_h;
Expand Down Expand Up @@ -2067,7 +2067,7 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, bool use_inversion,
if (is_amx(isa) && jcp.ow < (8 * 1024)) {
jcp.use_uker = true;
jcp.use_interleave_stores = true;
jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1;
jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0;
}

try_exec_type_res = try_exec_type();
Expand Down Expand Up @@ -2339,7 +2339,7 @@ status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
return status::unimplemented;

if (jcp.use_uker)
jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1;
jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0;
if (!jcp.wei_plain)
CHECK(pick_tags(jcp, src_md, weights_md, dst_md, bias_md));
CHECK(attr.set_default_formats(&dst_md));
Expand Down Expand Up @@ -3073,7 +3073,7 @@ status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp,
jcp.od_block = utils::saturate(1, jcp.od, od_block_limit);

jcp.use_interleave_stores = false;
jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1;
jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0;
jcp.amx_tile_load_xx = false;

if (one_of(jcp.harness, harness_2d_reduction, harness_3d_reduction)) {
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_brgemm_inner_product_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1417,7 +1417,7 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa,
jbgp.use_uker = true;
jbgp.use_interleave_stores = jbgp.use_uker;
if (jbgp.use_uker)
jbgp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1;
jbgp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0;
CHECK(set_or_check_tags());
CHECK(attr.set_default_formats(&dst_md));

Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/matmul/brgemm_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
brgattr.hint_expected_B_size = vN * vK * bs;
brgattr.hint_expected_C_size = vM * vN * bs;
brgattr.hint_innermost_loop = brgemm_innermost_undef;
brgattr.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1;
brgattr.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0;
}

CHECK(brgemm_desc_set_attr(&brg, brgattr));
Expand Down

0 comments on commit fa43640

Please sign in to comment.