Skip to content

Commit

Permalink
gpu: ocl: enable deconv binary dims, types
Browse files Browse the repository at this point in the history
  • Loading branch information
kealan-barbieri authored and karturov committed Nov 18, 2022
1 parent 26f97dc commit dd54d39
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/gpu/ocl/ref_convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ struct ref_convolution_bwd_data_t : public gpu_primitive_t {
&& !memory_desc_ndims_ok(diff_src_md(), diff_dst_md())
&& this->set_default_formats()
&& attr()->has_default_values(attr_skip_mask)
&& post_ops_with_binary_ok(attr(), dst_md()->data_type)
&& post_ops_with_binary_ok(
attr(), dst_md()->data_type, ndims())
&& zero_points_ok(attr())
&& IMPLICATION(!attr()->output_scales_.has_default_values(),
utils::one_of(diff_dst_md()->data_type, u8, s8)
Expand Down
7 changes: 3 additions & 4 deletions src/gpu/ocl/ref_deconvolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ struct ref_deconvolution_fwd_t : public gpu_primitive_t {
&& desc()->alg_kind == alg_kind::deconvolution_direct
&& attr()->has_default_values(attr_skip_mask)
&& post_ops_with_binary_ok(
attr(), desc()->dst_desc.data_type)
attr(), desc()->dst_desc.data_type, ndims())
&& (utils::everyone_is(data_type::f32,
desc()->src_desc.data_type,
desc()->weights_desc.data_type,
Expand All @@ -146,9 +146,8 @@ struct ref_deconvolution_fwd_t : public gpu_primitive_t {
|| (desc()->weights_desc.data_type == data_type::s8
&& utils::one_of(desc()->src_desc.data_type,
data_type::u8, data_type::s8)
&& utils::one_of(desc()->dst_desc.data_type,
data_type::u8, data_type::s8,
data_type::s32, data_type::f32)));
&& desc()->dst_desc.data_type
!= data_type::f64));
if (ok) {
CHECK(init_convolution(engine));
if (weights_md_.format_kind == format_kind::any)
Expand Down

0 comments on commit dd54d39

Please sign in to comment.