Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions include/ideep/operators/matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,11 +335,11 @@ struct matmul_forward : public dnnl::matmul,
if (bias.is_empty()) {
do_prepare_dynamic_quant</*with_bias=*/false>(
param, src, weights, bias, dst, weights_scales,
sum_coeff, attr, data_type::f32, aengine);
sum_coeff, attr, data_type::f32, alowp_kind, aengine);
} else {
do_prepare_dynamic_quant</*with_bias=*/true>(
param, src, weights, bias, dst, weights_scales,
sum_coeff, attr, data_type::f32, aengine);
sum_coeff, attr, data_type::f32, alowp_kind, aengine);
}
} else {
if (bias.is_empty()) {
Expand Down Expand Up @@ -379,7 +379,7 @@ struct matmul_forward : public dnnl::matmul,
if (is_dynamic) {
do_prepare_dynamic_quant</*with_bias=*/false>(
param, src, weights, dummy_bias, dst, weights_scales,
sum_coeff, attr, data_type::f32, aengine);
sum_coeff, attr, data_type::f32, alowp_kind, aengine);
} else {
do_prepare_static_quant</*with_bias=*/false>(
param, src, weights, dummy_bias, dst,
Expand Down Expand Up @@ -606,11 +606,11 @@ struct matmul_forward : public dnnl::matmul,
if (bias.is_empty()) {
do_prepare_dynamic_quant</*with_bias=*/false>(
param, src, weights, bias, dst, weights_scales,
sum_coeff, attr, data_type::f32, aengine);
sum_coeff, attr, data_type::f32, alowp_kind, aengine);
} else {
do_prepare_dynamic_quant</*with_bias=*/true>(
param, src, weights, bias, dst, weights_scales,
sum_coeff, attr, data_type::f32, aengine);
sum_coeff, attr, data_type::f32, alowp_kind, aengine);
}
} else {
if (bias.is_empty()) {
Expand Down Expand Up @@ -710,7 +710,7 @@ struct matmul_forward : public dnnl::matmul,
// prepare
do_prepare_dynamic_quant<with_bias>(param, src, weights,
bias, dst, weights_scales, sum_coeff,
attr, data_type::f32, aengine);
attr, data_type::f32, alowp_kind, aengine);
// compute
if (bias.is_empty()) {
do_compute_dynamic_quant</*with_bias=*/false, reorder_weight>(
Expand Down Expand Up @@ -1126,6 +1126,7 @@ struct matmul_forward : public dnnl::matmul,
const float sum_coeff = 1.0f, // for post-op sum
const attr_t& attr = attr_t(),
const data_type dst_type = data_type::f32,
const lowp_kind alowp_kind = u8s8,
const engine& aengine = engine::cpu_engine()) {
/* This function does the following things:
* - Determine expected descs of src/weight/dst
Expand Down Expand Up @@ -1153,7 +1154,7 @@ struct matmul_forward : public dnnl::matmul,
auto& weights_scales_in =
weights.has_scale() ? weights.get_scale() : weights_scales;

auto src_data_type = data_type::u8;
auto src_data_type = alowp_kind == s8s8 ? data_type::s8 : data_type::u8;
std::vector<int64_t> src_strides = (ndims == 3) ?
std::vector<int64_t>({src_dims[1] * src_dims[2], src_dims[1], 1}) :
std::vector<int64_t>({src_dims[1], 1});
Expand Down Expand Up @@ -1373,11 +1374,11 @@ struct matmul_forward : public dnnl::matmul,
auto& expected_src = reorder_src ?
src.reorder_if_differ_in(expected_src_desc) :
src;

auto& expected_other = reorder_src ?
other.reorder_if_differ_in(expected_dst_desc) :
other;

auto& expected_weights = reorder_weight ?
weights.reorder_if_differ_in(expected_wei_desc) :
weights;
Expand Down