diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index ef65424df87b9..f07ee80811001 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -337,6 +337,18 @@ static bool CheckDQRuleSet(const NodeUnit& node_unit, } } +static bool CheckQFeedsIntoQuantizedOutput(const NodeUnit& node_unit, + const std::unordered_map graph_op_data_type) { + auto op_of_quantized_layer = node_unit.Outputs(); + for (auto itr : op_of_quantized_layer) { + auto it = graph_op_data_type.find(itr.node_arg.Name()); + if (it != graph_op_data_type.end() && it->second == "tensor(uint8)") { + return true; + } + } + return false; +} + static bool CheckQRuleSet(const NodeUnit& node_unit, const Node* q_node, const onnxruntime::GraphViewer& src_graph, @@ -347,6 +359,12 @@ static bool CheckQRuleSet(const NodeUnit& node_unit, const auto& target_node = node_unit.GetNode(); auto op_type = node_unit.OpType(); + auto op = src_graph.GetOutputs(); + std::unordered_map graph_op_data_type; + for (auto& ops : op) { + graph_op_data_type[src_graph.GetNodeArg(ops->Name())->Name()] = ops->Type()->data(); + } + // If UInt16 Q, don't keep it if (GetQDQDataType(q_node) == DT_UINT16 || GetQDQDataType(q_node) == DT_INT16) { reason = SkipReason::Int16QDQ; @@ -359,6 +377,8 @@ static bool CheckQRuleSet(const NodeUnit& node_unit, } else if (op_type == "Add") { // Add keeps all Qs return true; + } else if (CheckQFeedsIntoQuantizedOutput(node_unit, graph_op_data_type)) { + return true; } else { // Keep Q of an unsupported Op only if the target that succeeds it is a supported Op in this list return IsNextTargetNodeOfQValid(q_node, &target_node, src_graph, {"Conv", "Add", "MatMul"}, false);