Skip to content

Commit

Permalink
cpu: gemm conv: disable unsupported binary po bcasts
Browse files Browse the repository at this point in the history
  • Loading branch information
tczeszun authored and tprimak committed Oct 26, 2022
1 parent 8b08a07 commit 9cf9c18
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/cpu/gemm_convolution.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2021 Intel Corporation
* Copyright 2016-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
#ifndef CPU_GEMM_CONVOLUTION_HPP
#define CPU_GEMM_CONVOLUTION_HPP

#include "common/broadcast_strategy.hpp"
#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
Expand Down Expand Up @@ -71,7 +72,16 @@ struct gemm_convolution_fwd_t : public primitive_t {
for (int idx = 0; idx < po.len(); idx++) {
bool ok = utils::one_of(true, is_sum(idx), is_binary(idx),
is_eltwise(idx))
&& IMPLICATION(is_sum(idx), idx == 0);
&& IMPLICATION(is_sum(idx), idx == 0)
&& IMPLICATION(is_binary(idx),
dnnl::impl::get_rhs_arg_broadcasting_strategy(
po.entry_[idx].binary.src1_desc,
dst_md_,
{broadcasting_strategy_t::scalar,
broadcasting_strategy_t::
per_oc})
!= broadcasting_strategy_t::
unsupported);
if (!ok) return false;
}

Expand Down

0 comments on commit 9cf9c18

Please sign in to comment.