Skip to content

Commit

Permalink
x64: matmul, softmax: move format setting before post_ops_ok check
Browse files Browse the repository at this point in the history
  • Loading branch information
tczeszun authored and tprimak committed Apr 15, 2024
1 parent 756d3cf commit d39e1b7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/cpu/x64/jit_uni_softmax.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2023 Intel Corporation
* Copyright 2019-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -135,13 +135,13 @@ struct jit_uni_softmax_fwd_t : public primitive_t {
| skip_mask_t::post_ops),
VERBOSE_UNSUPPORTED_ATTR);
VDISPATCH_SOFTMAX(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG);
VDISPATCH_SOFTMAX(post_ops_ok(), VERBOSE_UNSUPPORTED_POSTOP);

VDISPATCH_SOFTMAX(set_default_formats() == status::success,
VERBOSE_UNSUPPORTED_TAG);
VDISPATCH_SOFTMAX(
attr_.set_default_formats(dst_md(0)) == status::success,
VERBOSE_UNSUPPORTED_TAG);
VDISPATCH_SOFTMAX(post_ops_ok(), VERBOSE_UNSUPPORTED_POSTOP);
VDISPATCH_SOFTMAX(
memory_desc_wrapper(src_md()).similar_to(
memory_desc_wrapper(dst_md()), true, false, 0),
Expand Down
8 changes: 3 additions & 5 deletions src/cpu/x64/matmul/brgemm_matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,9 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bgmmc.is_runtime_N = is_runtime_value(bgmmc.N);
bgmmc.is_runtime_K = is_runtime_value(bgmmc.K);

VCHECK_BG(bm_conf_utils.set_or_check_tags(src_md, dst_md, bias_md),
VERBOSE_UNSUPPORTED_TAG);
VCHECK_BG(attr.set_default_formats(&dst_md), VERBOSE_UNSUPPORTED_TAG);
VCONDCHECK_BG(post_ops_ok(bgmmc, attr, dst_d), VERBOSE_UNSUPPORTED_POSTOP);

// runtime values for M/N dimensions are only supported
Expand Down Expand Up @@ -1176,15 +1179,10 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
VCONDCHECK_BG(bgmmc.required_k_granularity > 0, VERBOSE_BLOCKING_FAIL);
bgmmc.wei_k_blk = data_type_vnni_simd_elems<avx512_core>(bgmmc.wei_dt);

VCHECK_BG(bm_conf_utils.set_or_check_tags(src_md, dst_md, bias_md),
VERBOSE_UNSUPPORTED_TAG);
VCHECK_BG(bm_conf_utils.set_or_check_B_tag(weights_md),
VERBOSE_UNSUPPORTED_TAG);

bgmmc.req_wei_vnni_downconvert = bm_conf_utils.wei_down_convert_to_vnni();

VCHECK_BG(attr.set_default_formats(&dst_md), VERBOSE_UNSUPPORTED_TAG);

bgmmc.wei_n_blk = get_default_n_block(bgmmc.wei_tag);

bgmmc.blocked_B = bm_conf_utils.get_blocked_B();
Expand Down

0 comments on commit d39e1b7

Please sign in to comment.