@@ -466,34 +466,13 @@ struct CustomGraph {
466466 if (prev.node_ptr ->OutputDefs ()[0 ]->Type () != dq_node_ref.OutputDefs ()[0 ]->Type ()) {
467467 NodeArg& output = original_graph.GetOrCreateNodeArg (prev.node_name + " _cast_0" , dq_node_ref.OutputDefs ()[0 ]->TypeAsProto ());
468468 std::string cast_node_name = prev.node_ptr ->OutputDefs ()[0 ]->Name () + " _cast" ;
469- InlinedVector<NodeArg*> input_args = {( NodeArg*) (prev.node_ptr ->OutputDefs ()[0 ])};
469+ InlinedVector<NodeArg*> input_args = {const_cast < NodeArg*> (prev.node_ptr ->OutputDefs ()[0 ])};
470470 InlinedVector<NodeArg*> output_args = {&output};
471- std::unordered_map<std::string, int > type_str_to_tensor_data_type_;
472- type_str_to_tensor_data_type_[" tensor(float)" ] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
473- type_str_to_tensor_data_type_[" tensor(float16)" ] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
474- type_str_to_tensor_data_type_[" tensor(bfloat16)" ] = ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16;
475- type_str_to_tensor_data_type_[" tensor(double)" ] = ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
476- type_str_to_tensor_data_type_[" tensor(int8)" ] = ONNX_NAMESPACE::TensorProto_DataType_INT8;
477- type_str_to_tensor_data_type_[" tensor(int16)" ] = ONNX_NAMESPACE::TensorProto_DataType_INT16;
478- type_str_to_tensor_data_type_[" tensor(int32)" ] = ONNX_NAMESPACE::TensorProto_DataType_INT32;
479- type_str_to_tensor_data_type_[" tensor(int64)" ] = ONNX_NAMESPACE::TensorProto_DataType_INT64;
480- type_str_to_tensor_data_type_[" tensor(uint8)" ] = ONNX_NAMESPACE::TensorProto_DataType_UINT8;
481- type_str_to_tensor_data_type_[" tensor(uint16)" ] = ONNX_NAMESPACE::TensorProto_DataType_UINT16;
482- type_str_to_tensor_data_type_[" tensor(uint32)" ] = ONNX_NAMESPACE::TensorProto_DataType_UINT32;
483- type_str_to_tensor_data_type_[" tensor(uint64)" ] = ONNX_NAMESPACE::TensorProto_DataType_UINT64;
484- type_str_to_tensor_data_type_[" tensor(complex64)" ] = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64;
485- type_str_to_tensor_data_type_[" tensor(complex128)" ] = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128;
486- type_str_to_tensor_data_type_[" tensor(string)" ] = ONNX_NAMESPACE::TensorProto_DataType_STRING;
487- type_str_to_tensor_data_type_[" tensor(bool)" ] = ONNX_NAMESPACE::TensorProto_DataType_BOOL;
488- type_str_to_tensor_data_type_[" tensor(float8e4m3fn)" ] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN;
489- type_str_to_tensor_data_type_[" tensor(float8e4m3fnuz)" ] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ;
490- type_str_to_tensor_data_type_[" tensor(float8e5m2)" ] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
491- type_str_to_tensor_data_type_[" tensor(float8e5m2fnuz)" ] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ;
492- type_str_to_tensor_data_type_[" tensor(uint4)" ] = ONNX_NAMESPACE::TensorProto_DataType_UINT4;
493- type_str_to_tensor_data_type_[" tensor(int4)" ] = ONNX_NAMESPACE::TensorProto_DataType_INT4;
494- type_str_to_tensor_data_type_[" tensor(float4e2m1)" ] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1;
495471 Node& cast_node = original_graph.AddNode (cast_node_name, " Cast" , " " , input_args, output_args, nullptr , " " );
496- auto type_cast = type_str_to_tensor_data_type_[*dq_node_ref.OutputDefs ()[0 ]->Type ()];
472+ auto type_str = dq_node_ref.OutputDefs ()[0 ]->Type ();
473+ auto type_cast = type_str->find (" tensor(float)" ) != std::string::npos ? onnx::TensorProto_DataType_FLOAT : onnx::TensorProto_DataType_FLOAT16;
474+ ORT_ENFORCE ((type_cast == onnx::TensorProto_DataType_FLOAT) || (type_str->find (" tensor(float16)" ) != std::string::npos),
475+ " QDQ type misalignment, expected float32 or float16 output" );
497476 cast_node.AddAttribute (" to" , static_cast <int64_t >(type_cast));
498477 original_graph.AddEdge (prev.node_ptr ->Index (),
499478 cast_node.Index (),
0 commit comments