Pytorch2onnx #3075
Conversation
This function is in tools/onnx_util/symbolic.py. It is for users on PyTorch 1.3: PyTorch 1.3's ONNX exporter has some bugs and, in addition, does not implement the TopK op, so you can regard register_extra_symbolics as a monkey patch.
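A minimal sketch of what such a patch can look like, assuming PyTorch 1.3's internal `torch.onnx.symbolic_registry.register_op` API (that module only exists in older PyTorch). The attribute-style TopK below matches opset 9; newer opsets pass `k` as a tensor input, so treat this as illustrative only:

```python
def topk_symbolic(g, self, k, dim, largest=True, sorted=True, out=None):
    # Map aten::topk to a standard ONNX TopK node (values, indices).
    # Opset-9 style: k is an attribute; newer opsets take it as an input.
    return g.op('TopK', self, k_i=k, axis_i=dim, outputs=2)


def register_extra_symbolics(opset=9):
    # Lazy import: torch.onnx.symbolic_registry only exists in older PyTorch.
    from torch.onnx.symbolic_registry import register_op
    register_op('topk', topk_symbolic, '', opset)
```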
As ONNX cannot handle dynamic input shapes, I believe it is much more convenient for you to run the code locally, so that you can set whatever input shape you want.
OK, just for this emergency case :)
I just tried to convert an ONNX model. There is some numerical difference between PyTorch and ONNX; maybe this is due to the dummy input. Anyway, here is the link:
OK, good luck with your meeting tomorrow!
You need to implement it yourself. It is not supported yet.
First, I used a pretrained PyTorch model for the conversion; however, the pth file is not fully compatible with the master code, which means some ops do not have corresponding pretrained weights, so the original PyTorch result is also somewhat low. Besides, although I am not sure, I believe there is some numerical difference between PyTorch and ONNX Runtime; I will figure it out next week. Finally, why did you say Resize should be converted to Upsample? Resize is a standard ONNX op which ONNX Runtime can execute directly.
I have only tested the ONNX model on CPU using ONNX Runtime. Maybe it does not work on GPU. In fact, we have been working on another component that supports converting from ONNX to TensorRT, but it has not been published yet. I'm afraid you will have to implement the corresponding part yourself.
Please see tools/pytorch2onnx.py; the ONNX Runtime code is executed when using --verify.
PyTorch 1.3, Python 3.7.5. Please make sure you use the correct branch.
Sorry, I do not know how to directly modify an ONNX model. The ONNX model is just an intermediate file used to convert PyTorch models to other backend engines.
Maybe in the future :). We implement TensorRT plugins for customized ops.
After testing, your ONNX model is wrong; you can simplify the input and output.
Would you please describe why it's incorrect?
At topk and upsample there is no corresponding operation in the exported graph, and its output is wrong.
I checked that ONNX supports these operations, but the ONNX model you exported does not contain them. You should check whether your PyTorch 1.3 supports exporting these operations.
@@ -93,6 +94,10 @@ def simple_test(self, img, img_metas, rescale=False):
    outs = self.bbox_head(x)
    bbox_list = self.bbox_head.get_bboxes(
        *outs, img_metas, rescale=rescale)
+   # return in advance when export to ONNX
Suggested change:
- # return in advance when export to ONNX
+ # skip post-processing when exporting to ONNX
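As a sketch (not the exact PR code), the early-return branch being discussed looks roughly like this, using PyTorch's standard `torch.onnx.is_in_onnx_export()` check:

```python
def simple_test(self, img, img_metas, rescale=False):
    # Sketch: during ONNX export, return the raw head outputs and keep
    # get_bboxes / NMS decoding outside the traced graph.
    import torch
    x = self.extract_feat(img)
    outs = self.bbox_head(x)
    if torch.onnx.is_in_onnx_export():
        return outs
    bbox_list = self.bbox_head.get_bboxes(*outs, img_metas, rescale=rescale)
    return bbox_list
```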
@@ -97,8 +97,9 @@ def bbox2result(bboxes, labels, num_classes):
    if bboxes.shape[0] == 0:
        return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)]
    else:
-       bboxes = bboxes.cpu().numpy()
-       labels = labels.cpu().numpy()
+       if isinstance(bboxes, torch.Tensor):
Is this necessary? If so, we need to update the docstring.
After executing bbox2result, MMDet returns an np.ndarray, which is not supported by ONNX (ONNX cannot trace NumPy ops, only tensor ops), so we can only convert the part before bbox2result into ONNX. If we want to compare the results between PyTorch and ONNX, we have to use bbox2result to convert the output of ONNX as well. In that case, the input of bbox2result is an np.ndarray (ONNX Runtime's output type).
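Concretely, the idea can be sketched with a version of the helper that accepts NumPy arrays directly (a sketch; the real mmdet function converts torch tensors via `.cpu().numpy()` as shown in the diff above):

```python
import numpy as np


def bbox2result(bboxes, labels, num_classes):
    """Split (n, 5) detections into a per-class list.

    Accepts np.ndarray directly, so ONNX Runtime outputs can be
    post-processed the same way as PyTorch outputs.
    """
    if bboxes.shape[0] == 0:
        return [np.zeros((0, 5), dtype=np.float32) for _ in range(num_classes)]
    if not isinstance(bboxes, np.ndarray):
        # torch.Tensor path: move to CPU and convert to NumPy
        bboxes = bboxes.cpu().numpy()
        labels = labels.cpu().numpy()
    return [bboxes[labels == i, :] for i in range(num_classes)]
```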
We may keep this part unchanged and instead add [bboxes[labels == i, :] for i in range(num_classes)]
in pytorch2onnx().
We may also need to update the doc here.
I have submitted a new version according to the review comments, which mainly modifies ga_rpn_head's code for calling nms.
tools/pytorch2onnx.py
Outdated
parser.add_argument(
    '--out', type=str, required=True, help='output ONNX filename')
parser.add_argument(
    '--verify', action='store_true', help='verify the onnx model')
Suggested change:
- '--verify', action='store_true', help='verify the onnx model')
+ '--verify', action='store_true', help='verify the onnx model output against pytorch output')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
For object detection, we may make checkpoint a required argument. Without a checkpoint, some branches may not be covered.
tools/pytorch2onnx.py
Outdated
one_img = mmcv.imread(input_img, 'color')
one_img = mmcv.imresize(one_img, input_shape[2:]).transpose(2, 0, 1)
# normalize the input images
one_img = torch.from_numpy((one_img - 128) / 256).unsqueeze(0).float()
Why is the normalization fixed to 128?
I suggest making image norm a user input. The default could be imagenet mean/std.
I suggest making image norm a user input. The default could be imagenet mean/std.
After some testing, I decided to remove the normalization part. Even without this step, we still get correct RetinaNet results with the default picture. Besides, MMDet will raise an error if we do not execute NMS during ONNX tracing.
Removing img_norm may give incorrect results for other images.
I suggest making image norm a user input. The default could be ImageNet mean/std.
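A sketch of the reviewer's suggestion: make normalization configurable, defaulting to ImageNet mean/std instead of the hard-coded (x - 128) / 256. The values below are the usual ImageNet statistics on a 0-255 scale; the function name is illustrative, not from the PR:

```python
import numpy as np

# ImageNet statistics on a 0-255 RGB scale (common mmdet defaults).
IMAGENET_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
IMAGENET_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)


def normalize(img, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    # img: HxWxC array in RGB order; returns a normalized float32 copy.
    return (np.asarray(img, dtype=np.float32) - mean) / std
```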
Sorry, I don't know.
All right, I will give you a full command list; please give me some time.
Hi, I have tried to use the following commands to convert RetinaNet; I'm quite sure these commands work:
Yeah, because the ONNX symbolics in PyTorch 1.3 and PyTorch 1.5 are different. However, I did not encounter the bug you reported when I used PyTorch 1.5.
I think it's because you used the incorrect branch. You should pull my PR and check out to it.
This is the correct branch. I have no idea about your bug, as I cannot reproduce it.
mmdet/models/necks/fpn.py
Outdated
@@ -182,6 +182,9 @@ def forward(self, inputs):
        **self.upsample_cfg)
    else:
        prev_shape = laterals[i - 1].shape[2:]
+       # convert prev_shape from torch.Size to tuple
+       # so that we can convert F.interpolate into ONNX
+       prev_shape = tuple(int(e) for e in prev_shape)
prev_shape = tuple(laterals[i - 1].shape[2:])
tools/pytorch2onnx.py
Outdated
import torch
from mmcv.ops import RoIAlign, RoIPool
from mmcv.onnx.symbolic import register_extra_symbolics
We can raise an error message if the mmcv version is too low.
I met a problem when converting RetinaNet to ONNX; could you help me?
My environment is PyTorch 1.5, mmdet 2.3.0, mmcv 1.0.5.
Hi.
Thanks for your rapid reply.
/localdev/anaconda3/envs/mmdet01/lib/python3.7/site-packages/torch/onnx/symbolic_registry.py", line 91, in get_registered_op
I have met this problem; it is caused by the PyTorch version. Updating PyTorch could solve it. My version is PyTorch 1.6.
@tianwen0110 OK, thanks!
Support converting RetinaNet from PyTorch to ONNX.
We can verify that the computation results match between PyTorch and ONNX.
We do several things in this PR:
[1] Replace some PyTorch ops that are not supported by ONNX
[2] Replace some dynamic shapes with static shapes, as ONNX only supports constant shapes in some cases
[3] Fix some bugs in PyTorch 1.3's ONNX conversion that may cause numerical errors when running with ONNX Runtime
[4] Update the tools/pytorch2onnx.py file with our new API
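Based on the arguments discussed in this thread (config, checkpoint, --out, --verify), a typical invocation would look roughly like this; the config and checkpoint paths are placeholders, not from the PR:

```shell
# Convert a RetinaNet checkpoint to ONNX and compare its output
# against the PyTorch model (paths are illustrative).
python tools/pytorch2onnx.py \
    configs/retinanet/retinanet_r50_fpn_1x.py \
    checkpoints/retinanet_r50_fpn_1x.pth \
    --out retinanet.onnx \
    --verify
```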