
Problems caused by GetInputNameAllocated after upgrading from 1.12 to 1.13 #14157

Closed
UNeedCryDear opened this issue Jan 6, 2023 · 16 comments
Labels
api (issues related to all other APIs: C, C++, Python, etc.) · ep:CUDA (issues related to the CUDA execution provider) · platform:windows (issues related to the Windows platform)

Comments


UNeedCryDear commented Jan 6, 2023

Describe the issue

When upgrading from 1.12.x to 1.13.x, GetInputName and GetOutputName need to be replaced with GetInputNameAllocated and GetOutputNameAllocated, and I encountered a very strange bug here.

ONNX model exported from yolov5-seg.pt:
https://drive.google.com/file/d/1tV2moQxNfLzNf6yXm5Zev5CVj2o9zuaz/view?usp=share_link

Run the following code: everything is OK for TestONNX(), but when running TestONNX2(), the input and output node names become garbage after session->Run().
[screenshots]

The two functions differ only in how they obtain the node names: one uses a for loop and the other does not.
[screenshot]

An error is reported even if nothing outside the braces where the node names are obtained is modified:
[screenshot]

So, if I have multiple inputs and outputs and have to use a for loop, how can I solve this problem?

To reproduce

int TestONNX(std::string modelPath, bool useCuda=true) {

	Ort::Env ortEnv = Ort::Env(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR, "Yolov5-Seg");
	Ort::SessionOptions sessionOptions = Ort::SessionOptions();
	std::vector<std::string> available_providers = GetAvailableProviders();
	auto cuda_available = std::find(available_providers.begin(), available_providers.end(), "CUDAExecutionProvider");
	OrtCUDAProviderOptions cudaOption;
	if (useCuda && (cuda_available == available_providers.end()))
	{
		std::cout << "Your ORT build without GPU. Change to CPU." << std::endl;
	  std::cout << " Infer model on CPU "<< std::endl;
	}
	else if (useCuda && (cuda_available != available_providers.end()))
	{
		std::cout << "* Infer model on GPU! " << std::endl;
		OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0);
	}
	else
	{
		std::cout << " Infer model on CPU! " << std::endl;
	}
	//
	sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

#ifdef _WIN32
	std::wstring model_path(modelPath.begin(), modelPath.end());
	Ort::Session* session = new Ort::Session(ortEnv, model_path.c_str(), sessionOptions);
#else
	Ort::Session* session = new Ort::Session(_OrtEnv, modelPath.c_str(), sessionOptions);
#endif

	Ort::AllocatorWithDefaultOptions allocator;
	//init input

	std::vector<const char*> inputNodeNames; //
	std::vector<const char*> outputNodeNames;//
	std::vector<int64_t> inputTensorShape; //
	std::vector<int64_t> outputTensorShape;
	std::vector<int64_t> outputMaskTensorShape;

	auto inputNodesNum = session->GetInputCount();

	auto temp_input_name0 = session->GetInputNameAllocated(0, allocator);
	inputNodeNames.push_back(temp_input_name0.get());


	Ort::TypeInfo inputTypeInfo = session->GetInputTypeInfo(0);
	auto input_tensor_info = inputTypeInfo.GetTensorTypeAndShapeInfo();
	//inputNodeDataType = input_tensor_info.GetElementType();
	inputTensorShape = input_tensor_info.GetShape();
	//init output
	auto outputNodesNum = session->GetOutputCount();

	auto temp_output_name0 = session->GetOutputNameAllocated(0, allocator);
	auto temp_output_name1 = session->GetOutputNameAllocated(1, allocator);
	Ort::TypeInfo type_info_output0(nullptr);
	Ort::TypeInfo type_info_output1(nullptr);
	type_info_output0 = session->GetOutputTypeInfo(0);  //output0
	type_info_output1 = session->GetOutputTypeInfo(1);  //output1
	outputNodeNames.push_back(temp_output_name0.get());
	outputNodeNames.push_back(temp_output_name1.get());


	auto tensor_info_output0 = type_info_output0.GetTensorTypeAndShapeInfo();
	//outputNodeDataType = tensor_info_output0.GetElementType();
	outputTensorShape = tensor_info_output0.GetShape();
	auto tensor_info_output1 = type_info_output1.GetTensorTypeAndShapeInfo();
	//_outputMaskNodeDataType = tensor_info_output1.GetElementType(); //the same as output0
	outputMaskTensorShape = tensor_info_output1.GetShape();

	auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtDeviceAllocator, OrtMemType::OrtMemTypeCPUOutput);
	for (int i = 0; i < 3; i++) {
		std::cout << "Start warming up" << endl;
		size_t input_tensor_length = 640 * 640 * 3;
		float* temp = new float[input_tensor_length];
		std::vector<Ort::Value> input_tensors;
		std::vector<Ort::Value> output_tensors;
		std::cout << "################### befor run:##############" << endl;
		std::cout << "input node name:" << inputNodeNames[0] << endl;
		std::cout << "output0 node name:" << outputNodeNames[0] << endl;
		std::cout << "output1 node name:" << outputNodeNames[1] << endl;

		input_tensors.push_back(Ort::Value::CreateTensor<float>(
			memoryInfo, temp, input_tensor_length, inputTensorShape.data(),
			inputTensorShape.size()));
		output_tensors = session->Run(Ort::RunOptions{ nullptr },
			inputNodeNames.data(),
			input_tensors.data(),
			inputNodeNames.size(),
			outputNodeNames.data(),
			outputNodeNames.size());
		std::cout << "################### after run:##############" << endl;
		std::cout << "input node name:" << inputNodeNames[0] << endl;
		std::cout << "output0 node name:" << outputNodeNames[0] << endl;
		std::cout << "output1 node name:" << outputNodeNames[1] << endl;
	}

	std::cout << "*********************************** test onnx ok  ***************************************" << endl;

	return 0;
}

int TestONNX2(std::string modelPath, bool useCuda = true) {

	Ort::Env ortEnv = Ort::Env(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR, "Yolov5-Seg2");
	Ort::SessionOptions sessionOptions = Ort::SessionOptions();
	std::vector<std::string> available_providers = GetAvailableProviders();
	auto cuda_available = std::find(available_providers.begin(), available_providers.end(), "CUDAExecutionProvider");
	OrtCUDAProviderOptions cudaOption;
	if (useCuda && (cuda_available == available_providers.end()))
	{
		std::cout << "Your ORT build without GPU. Change to CPU." << std::endl;
		std::cout << " Infer model on CPU " << std::endl;
	}
	else if (useCuda && (cuda_available != available_providers.end()))
	{
		std::cout << "* Infer model on GPU! " << std::endl;
		OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0);
	}
	else
	{
		std::cout << " Infer model on CPU! " << std::endl;
	}
	//
	sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

#ifdef _WIN32
	std::wstring model_path(modelPath.begin(), modelPath.end());
	Ort::Session* session = new Ort::Session(ortEnv, model_path.c_str(), sessionOptions);
#else
	Ort::Session* session = new Ort::Session(_OrtEnv, modelPath.c_str(), sessionOptions);
#endif

	Ort::AllocatorWithDefaultOptions allocator;
	//init input

	std::vector<const char*> inputNodeNames; //
	std::vector<const char*> outputNodeNames;//
	std::vector<int64_t> inputTensorShape; //
	std::vector<int64_t> outputTensorShape;
	std::vector<int64_t> outputMaskTensorShape;

	auto inputNodesNum = session->GetInputCount();
	for (int i = 0; i < inputNodesNum; i++) {
		auto temp_input_name = session->GetInputNameAllocated(i, allocator);
		inputNodeNames.push_back(temp_input_name.get());
	}

	Ort::TypeInfo inputTypeInfo = session->GetInputTypeInfo(0);
	auto input_tensor_info = inputTypeInfo.GetTensorTypeAndShapeInfo();
	//inputNodeDataType = input_tensor_info.GetElementType();
	inputTensorShape = input_tensor_info.GetShape();
	//init output
	auto outputNodesNum = session->GetOutputCount();

	for (int i = 0; i < outputNodesNum; i++) {
		auto temp_output_name = session->GetOutputNameAllocated(i, allocator);
		Ort::TypeInfo type_info_output(nullptr);
		type_info_output = session->GetOutputTypeInfo(i);  //output0
		outputNodeNames.push_back(temp_output_name.get());

		auto tensor_info_output = type_info_output.GetTensorTypeAndShapeInfo();
		if (i == 0) {
			outputTensorShape = tensor_info_output.GetShape();
		}
		else {
			outputMaskTensorShape = tensor_info_output.GetShape();
		}

	}
	auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtDeviceAllocator, OrtMemType::OrtMemTypeCPUOutput);
	for (int i = 0; i < 3; i++) {
		std::cout << "Start warming up" << endl;
		size_t input_tensor_length = 640 * 640 * 3;
		float* temp = new float[input_tensor_length];
		std::vector<Ort::Value> input_tensors;
		std::vector<Ort::Value> output_tensors;
		std::cout << "################### befor run:##############" << endl;
		std::cout << "input node name:" << inputNodeNames[0] << endl;
		std::cout << "output0 node name:" << outputNodeNames[0] << endl;
		std::cout << "output1 node name:" << outputNodeNames[1] << endl;

		input_tensors.push_back(Ort::Value::CreateTensor<float>(
			memoryInfo, temp, input_tensor_length, inputTensorShape.data(),
			inputTensorShape.size()));
		output_tensors = session->Run(Ort::RunOptions{ nullptr },
			inputNodeNames.data(),
			input_tensors.data(),
			inputNodeNames.size(),
			outputNodeNames.data(),
			outputNodeNames.size());
		std::cout << "################### after run:##############" << endl;
		std::cout << "input node name:" << inputNodeNames[0] << endl;
		std::cout << "output0 node name:" << outputNodeNames[0] << endl;
		std::cout << "output1 node name:" << outputNodeNames[1] << endl;
	}

	return 0;
}
int main()
{
	string model_path = "./yolov5s-seg.onnx";
	TestONNX(model_path, true);
	TestONNX2(model_path, true);
	return 0;

}

Urgency

No response

Platform

Windows

OS Version

WIN10 22H2

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.13.1 from https://github.com/microsoft/onnxruntime/releases/download/v1.13.1/onnxruntime-win-x64-gpu-1.13.1.zip

ONNX Runtime API

C++

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 11.4

edgchen1 (Contributor) commented Jan 6, 2023

Session::GetInputNameAllocated() returns a std::unique_ptr.

using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;

AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const;

In this code, a temporary std::unique_ptr is created in the loop but a pointer to its memory is stored in inputNodeNames. After that loop iteration the memory will no longer be valid.

	for (int i = 0; i < inputNodesNum; i++) {
		auto temp_input_name = session->GetInputNameAllocated(i, allocator);
		inputNodeNames.push_back(temp_input_name.get());
	}

You could store the result of GetInputNameAllocated() (e.g., in a std::vector<AllocatedStringPtr>) and ensure that the AllocatedStringPtr is in scope while you need to access its data to avoid this issue.

yuslepukhin (Member) commented Jan 6, 2023

Yes, temp_input_name is destroyed on every iteration, and that deallocates the name. The code stores a pointer to freed memory that is then reused. The API was changed because GetInput/OutputName() leaked the raw pointer; it was never deallocated.

The code is also leaking the floating-point input buffers, since CreateTensor does not take ownership of the input buffer.

Ort::Session is also allocated on the heap for some reason, unlike the other objects.
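
A minimal sketch of both points, reusing the variable names from the repro above (illustrative only, not a drop-in patch):

	// Construct the session as an automatic (stack) object instead of new-ing it.
	Ort::Session session(ortEnv, model_path.c_str(), sessionOptions);

	// Let a std::vector own the input buffer. CreateTensor only references the
	// data, so the buffer must stay alive until Run() returns, but it is freed
	// automatically when the vector goes out of scope.
	std::vector<float> inputData(640 * 640 * 3, 0.0f);
	Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
		memoryInfo, inputData.data(), inputData.size(),
		inputTensorShape.data(), inputTensorShape.size());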

UNeedCryDear (Author) commented

@edgchen1

You could store the result of GetInputNameAllocated() (e.g., in a std::vector<AllocatedStringPtr>) and ensure that the AllocatedStringPtr is in scope while you need to access its data to avoid this issue.

I've tried std::vector<AllocatedStringPtr>, but it's the same problem: it stops working after session->Run() has been called just once.
Because there is only one model and its input and output names are fixed, I don't want to fetch the input and output names on every forward pass; I want to fetch them once when loading the model and reuse them later.

string in_name="images";
string out_name0="output0";
string out_name1="output1";
inputNodeNames.push_back(in_name.c_str());
outputNodeNames.push_back(out_name0.c_str());
outputNodeNames.push_back(out_name1.c_str());

This way I can get the names correctly for now, but once the model is changed the names may change, which will cause other problems.
Is there any way to achieve this effect?

UNeedCryDear (Author) commented

@yuslepukhin

Yes, temp_input_name is destroyed on every iteration, and that deallocates the name. The code stores a pointer to freed memory that is then reused. The API was changed because GetInput/OutputName() leaked the raw pointer; it was never deallocated.

Since I need to run inference on a large amount of data, I don't want to fetch the input and output names again before every run. Is there any way to achieve this?

The code is also leaking the floating-point input buffers, since CreateTensor does not take ownership of the input buffer.

Because it was just dummy data, I forgot to delete the temp pointer in each loop. Is it OK to modify it like this?
cv::Mat::zeros(): simulates reading a new picture from disk.

...
auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtDeviceAllocator, OrtMemType::OrtMemTypeCPUOutput);
size_t input_tensor_length = 640 * 640 * 3;
	cv::Size input_size(640, 640);
	for (int i = 0; i < 10; i++) {
		cv::Mat img = cv::Mat::zeros(input_size, CV_8UC3);//Simulate reading a  new picture from the disk; img=imread("new_img_path");
		Mat blob;
		cv::dnn::blobFromImage(img, blob, 1 / 255.0, input_size, Scalar(0, 0, 0), true, false);
		std::vector<Ort::Value> input_tensors;
		std::vector<Ort::Value> output_tensors;
		std::cout << "################### befor run:##############" << endl;
		std::cout << "input node name:" << inputNodeNames[0] << endl;
		std::cout << "output0 node name:" << outputNodeNames[0] << endl;
		std::cout << "output1 node name:" << outputNodeNames[1] << endl;
		input_tensors.push_back(Ort::Value::CreateTensor<float>(
			memoryInfo, (float*)blob.data, input_tensor_length, inputTensorShape.data(),
			inputTensorShape.size()));
		output_tensors = session->Run(Ort::RunOptions{ nullptr },
			inputNodeNames.data(),
			input_tensors.data(),
			inputNodeNames.size(),
			outputNodeNames.data(),
			outputNodeNames.size());
		std::cout << "################### after run:##############" << endl;
		std::cout << "input node name:" << inputNodeNames[0] << endl;
		std::cout << "output0 node name:" << outputNodeNames[0] << endl;
		std::cout << "output1 node name:" << outputNodeNames[1] << endl;
	}

edgchen1 (Contributor) commented Jan 6, 2023

I've tried std::vector<AllocatedStringPtr>, but it's the same problem: it stops working after session->Run() has been called just once.

To clarify, in this approach there is an additional vector.

	std::vector<const char*> inputNodeNames; //
	std::vector<AllocatedStringPtr> inputNodeNameAllocatedStrings; // <-- newly added
	...

	auto inputNodesNum = session->GetInputCount();
	for (int i = 0; i < inputNodesNum; i++) {
		auto input_name = session->GetInputNameAllocated(i, allocator);
		inputNodeNameAllocatedStrings.push_back(std::move(input_name));
		inputNodeNames.push_back(inputNodeNameAllocatedStrings.back().get());
	}

So the memory pointed to by inputNodeNames[i] is owned by inputNodeNameAllocatedStrings[i].

Because there is only one model and its input and output names are fixed, I don't want to fetch the input and output names on every forward pass; I want to fetch them once when loading the model and reuse them later.

You can do this once after loading the model, before the call(s) to Session::Run().

UNeedCryDear (Author) commented

@edgchen1 Thanks for your help, it works.

chengdashia commented

I've tried std::vector<AllocatedStringPtr>, but it's the same problem: it stops working after session->Run() has been called just once.

To clarify, in this approach there is an additional vector.

	std::vector<const char*> inputNodeNames; //
	std::vector<AllocatedStringPtr> inputNodeNameAllocatedStrings; // <-- newly added
	...

	auto inputNodesNum = session->GetInputCount();
	for (int i = 0; i < inputNodesNum; i++) {
		auto input_name = session->GetInputNameAllocated(i, allocator);
		inputNodeNameAllocatedStrings.push_back(std::move(input_name));
		inputNodeNames.push_back(inputNodeNameAllocatedStrings.back().get());
	}

So the memory pointed to by inputNodeNames[i] is owned by inputNodeNameAllocatedStrings[i].

Because there is only one model and its input and output names are fixed, I don't want to fetch the input and output names on every forward pass; I want to fetch them once when loading the model and reuse them later.

You can do this once after loading the model, before the call(s) to Session::Run().

I tried your method and I still get the same error.

[screenshots]

UNeedCryDear (Author) commented Nov 7, 2023

@chengdashia
Change the unique_ptr to a shared_ptr via std::move():

std::shared_ptr<char> inputName;
std::vector<char*> inputNodeNames;
inputName = std::move(ort_session->GetInputNameAllocated(0, allocator));
inputNodeNames.push_back(inputName.get());
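
The same idea extends to multiple inputs and outputs; a rough sketch (the owner vector must stay alive, e.g. as a class member, for as long as the raw pointers are used):

std::vector<std::shared_ptr<char>> nameOwners;             // keeps every name alive
std::vector<const char*> inputNodeNames, outputNodeNames;  // raw pointers passed to Run()
for (size_t i = 0; i < ort_session->GetInputCount(); ++i) {
	// unique_ptr -> shared_ptr conversion keeps the allocator-aware deleter
	nameOwners.emplace_back(ort_session->GetInputNameAllocated(i, allocator));
	inputNodeNames.push_back(nameOwners.back().get());
}
for (size_t i = 0; i < ort_session->GetOutputCount(); ++i) {
	nameOwners.emplace_back(ort_session->GetOutputNameAllocated(i, allocator));
	outputNodeNames.push_back(nameOwners.back().get());
}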

chengdashia commented

@chengdashia Change the unique_ptr to a shared_ptr via std::move():

std::shared_ptr<char> inputName;
std::vector<char*> inputNodeNames;
inputName = std::move(ort_session->GetInputNameAllocated(0, allocator));
inputNodeNames.push_back(inputName.get());

I'm sorry to bother you again, but I tried the code you provided and I'm still getting the same error. Which version of onnxruntime are you using?
[screenshots]

UNeedCryDear (Author) commented Nov 9, 2023

@chengdashia
ORT_VERSION>=1.13.0

input_names and output_names should not be temporary variables inside the for loop, which is what you are doing.
You should treat them as class member variables, or declare them outside the for loop, so that they still hold valid values when you call ort_session->Run().
https://github.com/UNeedCryDear/yolov8-opencv-onnxruntime-cpp/blob/main/yolov8_seg_onnx.h
https://github.com/UNeedCryDear/yolov8-opencv-onnxruntime-cpp/blob/main/yolov8_seg_onnx.cpp
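
For illustration only, a simplified sketch of that layout (assuming the usual onnxruntime_cxx_api.h, <vector> and <memory> includes; this is a hypothetical wrapper, not the actual code in those files):

class YoloSegOnnx {
public:
	bool ReadModel(const wchar_t* modelPath) {
		_session = std::make_unique<Ort::Session>(_env, modelPath, _options);
		Ort::AllocatorWithDefaultOptions allocator;
		for (size_t i = 0; i < _session->GetInputCount(); ++i) {
			_inputNameOwners.push_back(_session->GetInputNameAllocated(i, allocator));
			_inputNames.push_back(_inputNameOwners.back().get());
		}
		for (size_t i = 0; i < _session->GetOutputCount(); ++i) {
			_outputNameOwners.push_back(_session->GetOutputNameAllocated(i, allocator));
			_outputNames.push_back(_outputNameOwners.back().get());
		}
		return true;
	}
private:
	Ort::Env _env{ ORT_LOGGING_LEVEL_ERROR, "Yolov5-Seg" };
	Ort::SessionOptions _options;
	std::unique_ptr<Ort::Session> _session;
	// The owners are class members, so the raw pointers below stay valid
	// across every call to _session->Run().
	std::vector<Ort::AllocatedStringPtr> _inputNameOwners, _outputNameOwners;
	std::vector<const char*> _inputNames, _outputNames;
};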

chengdashia commented Nov 9, 2023

@chengdashia ORT_VERSION>=1.13.0

input_names and output_names should not be temporary variables inside the for loop, which is what you are doing. You should treat them as class member variables, or declare them outside the for loop, so that they still hold valid values when you call ort_session->Run(). https://github.com/UNeedCryDear/yolov8-opencv-onnxruntime-cpp/blob/main/yolov8_seg_onnx.h https://github.com/UNeedCryDear/yolov8-opencv-onnxruntime-cpp/blob/main/yolov8_seg_onnx.cpp

Yes, as you said, I treat them as member variables.
The person I am referring to wrote a C++ onnxruntime deployment of YOLOv5; the Python version works: https://github.com/hpc203/yolov5-v6.1-opencv-onnxrun
[screenshot]

UNeedCryDear (Author) commented

Yeah, I know about that repository, and there are other YOLOv5 repositories available, so I did not release an ORT version of YOLOv5; I released yolov5-seg instead.

chengdashia commented

Yeah, I know about that repository, and there are other YOLOv5 repositories available, so I did not release an ORT version of YOLOv5; I released yolov5-seg instead.

Sorry, I didn't quite understand what you meant. Do you mean I should look at your repository to find the answer? I trained a PCB defect detection model using YOLOv5.
https://github.com/UNeedCryDear/yolov5-seg-opencv-onnxruntime-cpp

UNeedCryDear (Author) commented Nov 9, 2023

You misunderstood. I mean that I did not release YOLOv5 for onnxruntime because many people have already released it.
Do you mean that you still get errors when using hpc203's repository?
Of course, you can also infer YOLOv5 by modifying the yolov5-seg code; the difference between them is small and easy to adapt. If you find the changes too difficult, I can release the ORT code for YOLOv5 later.

chengdashia commented

You misunderstood. I mean that I did not release YOLOv5 for onnxruntime because many people have already released it. Do you mean that you still get errors when using hpc203's repository? Of course, you can also infer YOLOv5 by modifying the yolov5-seg code; the difference between them is small and easy to adapt. If you find the changes too difficult, I can release the ORT code for YOLOv5 later.

I originally trained with yolov5-7.0, using the PCB defect dataset published by Peking University. Then I started deploying in C++ with OpenCV, but the frame rate was very slow, and that's when I came to onnxruntime. I have been using the hpc203 repository code: his C++ OpenCV deployment works fine, and his Python onnxruntime deployment works fine. However, when using his C++ onnxruntime deployment, this problem appeared. I suspected a version problem, so I tried 1.7, 1.13.1, 1.14, 1.15, 1.16, and 1.16.1, but it never worked out. Finally I found this issue. I hope you can help me. Thank you very much.

UNeedCryDear (Author) commented

@chengdashia
The code has been updated; check https://github.com/UNeedCryDear/yolov5-seg-opencv-onnxruntime-cpp for the yolov5 ONNX version.
