Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ bool IsQDQPairSupported(
Initializer dq_scale(*dq_scale_tensor_proto, model_path);

if (q_zp.data_type() != dq_zp.data_type() ||
q_scale.data_type() != q_scale.data_type() ||
q_scale.data_type() != dq_scale.data_type() ||
!SpanEq(q_zp.DataAsByteSpan(), dq_zp.DataAsByteSpan())) {
return false;
}
Expand Down
24 changes: 12 additions & 12 deletions onnxruntime/test/optimizer/graph_transform_test_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,17 +304,17 @@ class ModelTestBuilder {
return AddNode("Conv", input_args, {output_arg});
}

template <typename T>
typename std::enable_if<IsTypeQuantLinearCompatible<T>::value, Node&>::type
template <typename ZpType, typename ScaleType = float>
typename std::enable_if<IsTypeQuantLinearCompatible<ZpType>::value, Node&>::type
AddQuantizeLinearNode(NodeArg* input_arg,
float input_scale,
T input_zero_point,
ScaleType input_scale,
ZpType input_zero_point,
NodeArg* output_arg,
bool use_ms_domain = false) {
std::vector<NodeArg*> input_args;
input_args.push_back(input_arg);
input_args.push_back(MakeScalarInitializer<float>(input_scale));
input_args.push_back(MakeScalarInitializer<T>(input_zero_point));
input_args.push_back(MakeScalarInitializer<ScaleType>(input_scale));
input_args.push_back(MakeScalarInitializer<ZpType>(input_zero_point));

std::string domain = use_ms_domain ? kMSDomain : "";
return AddNode("QuantizeLinear", input_args, {output_arg}, domain);
Expand Down Expand Up @@ -382,17 +382,17 @@ class ModelTestBuilder {
NodeArg* output_arg,
bool use_ms_domain = false);

template <typename T>
typename std::enable_if<IsTypeDequantLinearCompatible<T>::value, Node&>::type
template <typename ZpType, typename ScaleType = float>
typename std::enable_if<IsTypeDequantLinearCompatible<ZpType>::value, Node&>::type
AddDequantizeLinearNode(NodeArg* input_arg,
float input_scale,
T input_zero_point,
ScaleType input_scale,
ZpType input_zero_point,
NodeArg* output_arg,
bool use_ms_domain = false) {
std::vector<NodeArg*> input_args;
input_args.push_back(input_arg);
input_args.push_back(MakeScalarInitializer<float>(input_scale));
input_args.push_back(MakeScalarInitializer<T>(input_zero_point));
input_args.push_back(MakeScalarInitializer<ScaleType>(input_scale));
input_args.push_back(MakeScalarInitializer<ZpType>(input_zero_point));

std::string domain = use_ms_domain ? kMSDomain : "";
return AddNode("DequantizeLinear", input_args, {output_arg}, domain);
Expand Down
46 changes: 46 additions & 0 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,52 @@ TEST(QDQTransformerTests, DoubleQDQ) {
bad_float_point, good_float_point_2, true); // Use com.microsoft QDQ ops
}

// Verifies the fix for a bug in the IsQDQPairSupported utility function, which is used
// by various optimizers such as DoubleQDQPairsRemover. The bug caused an exception when
// IsQDQPairSupported() was called with a Q -> DQ sequence that uses different scale types.
TEST(QDQTransformerTests, DoubleQDQPairsRemover_Bug_RejectDifferentScaleTypes) {
  // Builds a model whose Q/DQ nodes use mismatched scale element types:
  // input_fp32 -> Q(scale_fp32) -> DQ(scale_fp16) -> Mul(fp16) -> Q(scale_fp16) -> DQ(scale_fp32) -> output_fp32
  auto model_fn = [](ModelTestBuilder& builder) {
    auto* input_arg = builder.MakeInput<float>({1, 1, 1, 3}, {1.0f, 2.0f, 3.0f});
    NodeArg* q1_out = builder.MakeIntermediate();
    NodeArg* dq1_out = builder.MakeIntermediate();
    NodeArg* mul_out = builder.MakeIntermediate();
    NodeArg* q2_out = builder.MakeIntermediate();
    NodeArg* dq2_out = builder.MakeOutput();
    NodeArg* mul_const = builder.MakeScalarInitializer(MLFloat16(10.0f));

    const float scale_f32 = 1.0f;
    const MLFloat16 scale_f16 = MLFloat16(scale_f32);
    const uint8_t zero_point = 127;

    builder.AddQuantizeLinearNode<uint8_t, float>(input_arg, scale_f32, zero_point, q1_out);
    builder.AddDequantizeLinearNode<uint8_t, MLFloat16>(q1_out, scale_f16, zero_point, dq1_out);
    builder.AddNode("Mul", {dq1_out, mul_const}, {mul_out});
    builder.AddQuantizeLinearNode<uint8_t, MLFloat16>(mul_out, scale_f16, zero_point, q2_out);
    builder.AddDequantizeLinearNode<uint8_t, float>(q2_out, scale_f32, zero_point, dq2_out);
  };

  // Previously, the mismatched scale data types caused an exception inside
  // IsQDQPairSupported. With the fix, such a sequence is merely rejected, so the
  // DoubleQDQPairsRemover optimizer must leave the graph unchanged: both Q/DQ
  // pairs and the Mul node survive.
  auto check_graph_fn = [](InferenceSessionWrapper& session) {
    auto op_counts = CountOpsInGraph(session.GetGraph());
    EXPECT_EQ(op_counts["QuantizeLinear"], 2);
    EXPECT_EQ(op_counts["Mul"], 1);
    EXPECT_EQ(op_counts["DequantizeLinear"], 2);
  };

  TransformerTester(
      model_fn,
      check_graph_fn,
      TransformerLevel::Default,
      TransformerLevel::Level1,
      21,
      /*per_sample_tolerance*/ 0.0,
      /*relative_per_sample_tolerance*/ 0.0,
      std::make_unique<DoubleQDQPairsRemover>());
}

template <typename QuantType>
static void RunDoubleQDQWithoutLastNodeBeingOutput(int output_index, int expected_Q_count, int expected_DQ_count,
bool use_contrib_qdq = false, int opset = 12) {
Expand Down