Skip to content

Commit 69f09bb

Browse files
Limit output types to f32/f16, add const_cast
1 parent de6e250 commit 69f09bb

File tree

1 file changed

+5
-26
lines changed

1 file changed

+5
-26
lines changed

onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -466,34 +466,13 @@ struct CustomGraph {
466466
if (prev.node_ptr->OutputDefs()[0]->Type() != dq_node_ref.OutputDefs()[0]->Type()) {
467467
NodeArg& output = original_graph.GetOrCreateNodeArg(prev.node_name + "_cast_0", dq_node_ref.OutputDefs()[0]->TypeAsProto());
468468
std::string cast_node_name = prev.node_ptr->OutputDefs()[0]->Name() + "_cast";
469-
InlinedVector<NodeArg*> input_args = {(NodeArg*)(prev.node_ptr->OutputDefs()[0])};
469+
InlinedVector<NodeArg*> input_args = {const_cast<NodeArg*>(prev.node_ptr->OutputDefs()[0])};
470470
InlinedVector<NodeArg*> output_args = {&output};
471-
std::unordered_map<std::string, int> type_str_to_tensor_data_type_;
472-
type_str_to_tensor_data_type_["tensor(float)"] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
473-
type_str_to_tensor_data_type_["tensor(float16)"] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
474-
type_str_to_tensor_data_type_["tensor(bfloat16)"] = ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16;
475-
type_str_to_tensor_data_type_["tensor(double)"] = ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
476-
type_str_to_tensor_data_type_["tensor(int8)"] = ONNX_NAMESPACE::TensorProto_DataType_INT8;
477-
type_str_to_tensor_data_type_["tensor(int16)"] = ONNX_NAMESPACE::TensorProto_DataType_INT16;
478-
type_str_to_tensor_data_type_["tensor(int32)"] = ONNX_NAMESPACE::TensorProto_DataType_INT32;
479-
type_str_to_tensor_data_type_["tensor(int64)"] = ONNX_NAMESPACE::TensorProto_DataType_INT64;
480-
type_str_to_tensor_data_type_["tensor(uint8)"] = ONNX_NAMESPACE::TensorProto_DataType_UINT8;
481-
type_str_to_tensor_data_type_["tensor(uint16)"] = ONNX_NAMESPACE::TensorProto_DataType_UINT16;
482-
type_str_to_tensor_data_type_["tensor(uint32)"] = ONNX_NAMESPACE::TensorProto_DataType_UINT32;
483-
type_str_to_tensor_data_type_["tensor(uint64)"] = ONNX_NAMESPACE::TensorProto_DataType_UINT64;
484-
type_str_to_tensor_data_type_["tensor(complex64)"] = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64;
485-
type_str_to_tensor_data_type_["tensor(complex128)"] = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128;
486-
type_str_to_tensor_data_type_["tensor(string)"] = ONNX_NAMESPACE::TensorProto_DataType_STRING;
487-
type_str_to_tensor_data_type_["tensor(bool)"] = ONNX_NAMESPACE::TensorProto_DataType_BOOL;
488-
type_str_to_tensor_data_type_["tensor(float8e4m3fn)"] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN;
489-
type_str_to_tensor_data_type_["tensor(float8e4m3fnuz)"] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ;
490-
type_str_to_tensor_data_type_["tensor(float8e5m2)"] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
491-
type_str_to_tensor_data_type_["tensor(float8e5m2fnuz)"] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ;
492-
type_str_to_tensor_data_type_["tensor(uint4)"] = ONNX_NAMESPACE::TensorProto_DataType_UINT4;
493-
type_str_to_tensor_data_type_["tensor(int4)"] = ONNX_NAMESPACE::TensorProto_DataType_INT4;
494-
type_str_to_tensor_data_type_["tensor(float4e2m1)"] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1;
495471
Node& cast_node = original_graph.AddNode(cast_node_name, "Cast", "", input_args, output_args, nullptr, "");
496-
auto type_cast = type_str_to_tensor_data_type_[*dq_node_ref.OutputDefs()[0]->Type()];
472+
auto type_str = dq_node_ref.OutputDefs()[0]->Type();
473+
auto type_cast = type_str->find("tensor(float)") != std::string::npos ? onnx::TensorProto_DataType_FLOAT : onnx::TensorProto_DataType_FLOAT16;
474+
ORT_ENFORCE((type_cast == onnx::TensorProto_DataType_FLOAT) || (type_str->find("tensor(float16)") != std::string::npos),
475+
"QDQ type misalignment, expected float32 or float16 output");
497476
cast_node.AddAttribute("to", static_cast<int64_t>(type_cast));
498477
original_graph.AddEdge(prev.node_ptr->Index(),
499478
cast_node.Index(),

0 commit comments

Comments
 (0)