
ONNX Inferencing Error #7

Closed
mohamadmansourX opened this issue Sep 28, 2022 · 3 comments

Comments

@mohamadmansourX
Contributor

Hello,
Thank you for this great project for properly converting SparseInst to ONNX/TensorRT.

I started training through your updated implementation, using a customized version of configs/sparse_inst_r50_giam_aug.yaml with 10 classes.

I converted the model to ONNX successfully:

python3 convert_onnx.py --config-file configs/sparse_inst_r50_giam_aug.yaml --output on2/myonnx2.onnx --image assets/figures/t1.jpg --opts MODEL.WEIGHTS output/checkpoints/model_0001999.pth

Then, when trying to run inference using

python3 eval_tensorrt_onnx.py  -c 0.2 --width_resized 640 --height_resized 640 --input assets/figures/* --use_onnx --onnx_engine on2/myonnx2.onnx --output_onnx on2/ --save_image

I received the error below:

onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from on2/myonnx2.onnx failed:Type Error: Type parameter (T) of Optype (Add) bound to different types (tensor(float) and tensor(double) in node (Add_398).
@leandro-svg
Owner

leandro-svg commented Sep 28, 2022

Dear @mohamadmansourX, thank you for saying that 😉
Would you mind sharing your ONNX file with me, so that I can check it out in Netron with the verbose output?

In my ONNX model, the last Add node is Add_319, which is linked to the "rescoring_mask" definition in sparseinst.py, called at line 229. If that is the case, you should be able to change line 21, which defines mask_pred_ as a float that could possibly be converted to double.
Let's see where your Add_398 comes from ...

@mohamadmansourX
Contributor Author

Hello, thank you for your reply.
Yes, as you mentioned, after debugging with Netron, the issue is in this line:

return scores * ((masks * mask_pred_).sum([1, 2]) / (mask_pred_.sum([1, 2]) + 1e-6))

Adding the terms mask_pred_.sum([1, 2]) and 1e-6 is what causes the issue.
I tried changing 1e-6 to torch.tensor(1e-6).float(), but for some reason it is still being interpreted as tensor(double)!

@mohamadmansourX
Contributor Author

mohamadmansourX commented Oct 4, 2022

Done!!
I solved it by converting the sum to double and then back to float, rather than touching 1e-6. I'm not sure why 1e-6 was mapped to double no matter what I did.
My solution:

@torch.jit.script
def rescoring_mask(scores, mask_pred, masks):
    mask_pred_ = mask_pred.float()
    # Do the epsilon addition in double, then cast the result back to
    # float, so every Add in the exported graph sees matching dtypes.
    factrr = mask_pred_.sum([1, 2]).double()
    factrr = factrr + 1e-6
    return scores * ((masks * mask_pred_).sum([1, 2]) / factrr.float())
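The workaround can be sanity-checked in eager mode with dummy tensors (the shapes below are illustrative only, not the real SparseInst tensors) to confirm the returned scores stay float32:

```python
import torch

@torch.jit.script
def rescoring_mask(scores, mask_pred, masks):
    mask_pred_ = mask_pred.float()
    # Epsilon added in double, result cast back to float (the fix above).
    factrr = mask_pred_.sum([1, 2]).double()
    factrr = factrr + 1e-6
    return scores * ((masks * mask_pred_).sum([1, 2]) / factrr.float())

# Dummy inputs: 2 instances with 8x8 masks (shapes are made up here).
scores = torch.rand(2)
mask_pred = torch.rand(2, 8, 8) > 0.5   # boolean mask predictions
masks = torch.rand(2, 8, 8)

out = rescoring_mask(scores, mask_pred, masks)
print(out.dtype)  # torch.float32
```

Since the final division happens on factrr.float(), no float/double mixing survives into the traced graph.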

mohamadmansourX added a commit to mohamadmansourX/SparseInst_TensorRT that referenced this issue Oct 4, 2022
In accordance with my [issue](leandro-svg#7), this might be a one-line solution suggestion!