diff --git a/include/ideep/operators/matmul.hpp b/include/ideep/operators/matmul.hpp index 36b15ede..a1bb4e7e 100644 --- a/include/ideep/operators/matmul.hpp +++ b/include/ideep/operators/matmul.hpp @@ -335,11 +335,11 @@ struct matmul_forward : public dnnl::matmul, if (bias.is_empty()) { do_prepare_dynamic_quant( param, src, weights, bias, dst, weights_scales, - sum_coeff, attr, data_type::f32, aengine); + sum_coeff, attr, data_type::f32, alowp_kind, aengine); } else { do_prepare_dynamic_quant( param, src, weights, bias, dst, weights_scales, - sum_coeff, attr, data_type::f32, aengine); + sum_coeff, attr, data_type::f32, alowp_kind, aengine); } } else { if (bias.is_empty()) { @@ -379,7 +379,7 @@ struct matmul_forward : public dnnl::matmul, if (is_dynamic) { do_prepare_dynamic_quant( param, src, weights, dummy_bias, dst, weights_scales, - sum_coeff, attr, data_type::f32, aengine); + sum_coeff, attr, data_type::f32, alowp_kind, aengine); } else { do_prepare_static_quant( param, src, weights, dummy_bias, dst, @@ -606,11 +606,11 @@ struct matmul_forward : public dnnl::matmul, if (bias.is_empty()) { do_prepare_dynamic_quant( param, src, weights, bias, dst, weights_scales, - sum_coeff, attr, data_type::f32, aengine); + sum_coeff, attr, data_type::f32, alowp_kind, aengine); } else { do_prepare_dynamic_quant( param, src, weights, bias, dst, weights_scales, - sum_coeff, attr, data_type::f32, aengine); + sum_coeff, attr, data_type::f32, alowp_kind, aengine); } } else { if (bias.is_empty()) { @@ -710,7 +710,7 @@ struct matmul_forward : public dnnl::matmul, // prepare do_prepare_dynamic_quant(param, src, weights, bias, dst, weights_scales, sum_coeff, - attr, data_type::f32, aengine); + attr, data_type::f32, alowp_kind, aengine); // compute if (bias.is_empty()) { do_compute_dynamic_quant( @@ -1126,6 +1126,7 @@ struct matmul_forward : public dnnl::matmul, const float sum_coeff = 1.0f, // for post-op sum const attr_t& attr = attr_t(), const data_type dst_type = data_type::f32, + const lowp_kind alowp_kind = u8s8, const engine& aengine = engine::cpu_engine()) { /* This function does the following things: * - Determine expected descs of src/weight/dst @@ -1153,7 +1154,7 @@ struct matmul_forward : public dnnl::matmul, auto& weights_scales_in = weights.has_scale() ? weights.get_scale() : weights_scales; - auto src_data_type = data_type::u8; + auto src_data_type = alowp_kind == s8s8 ? data_type::s8 : data_type::u8; std::vector src_strides = (ndims == 3) ? std::vector({src_dims[1] * src_dims[2], src_dims[1], 1}) : std::vector({src_dims[1], 1}); @@ -1373,11 +1374,11 @@ struct matmul_forward : public dnnl::matmul, auto& expected_src = reorder_src ? src.reorder_if_differ_in(expected_src_desc) : src; - + auto& expected_other = reorder_src ? other.reorder_if_differ_in(expected_dst_desc) : other; - + auto& expected_weights = reorder_weight ? weights.reorder_if_differ_in(expected_wei_desc) : weights;