Unable to run evaluation on ONNX model due to unregistered MMCVDeformConv2d
#439
There should be a custom-op library. You can check its path with:
from mmdeploy.backend.onnxruntime import get_ops_path
get_ops_path()
If there is no such library, you might have missed a step when installing the repo, such as:
wget https://github.com/microsoft/onnxruntime/releases/download/v1.8.1/onnxruntime-linux-x64-1.8.1.tgz
tar -zxvf onnxruntime-linux-x64-1.8.1.tgz
cd onnxruntime-linux-x64-1.8.1
export ONNXRUNTIME_DIR=$(pwd)
export LD_LIBRARY_PATH=$ONNXRUNTIME_DIR/lib:$LD_LIBRARY_PATH
or cmake ...... MMDEPLOY_TARGET_BACKENDS="ort" .....
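A minimal Python sketch of that check, assuming `get_ops_path()` returns an empty string when no custom-op library is found. `ops_library_found` and `make_session` are hypothetical helper names; `SessionOptions.register_custom_ops_library` is the ONNX Runtime call that loads a custom-op shared library into a session.

```python
import os


def ops_library_found(path):
    """Return True if `path` points to an existing file (e.g. the custom-op .so)."""
    return bool(path) and os.path.isfile(path)


def make_session(onnx_file, ops_path):
    """Create an ONNX Runtime session with the custom-op library registered.

    onnxruntime is imported here so the check above works without it installed.
    """
    import onnxruntime as ort

    opts = ort.SessionOptions()
    opts.register_custom_ops_library(ops_path)
    return ort.InferenceSession(
        onnx_file, sess_options=opts, providers=["CPUExecutionProvider"])


if __name__ == "__main__":
    try:
        # Assumption: get_ops_path() returns '' when the library is missing.
        from mmdeploy.backend.onnxruntime import get_ops_path
        ops_path = get_ops_path()
    except ImportError:
        ops_path = ""
    if ops_library_found(ops_path):
        print("custom ops library:", ops_path)
    else:
        print("custom ops library not found; rebuild with the ONNX Runtime backend")
```

Note that finding the library only means the build step succeeded; an operator that has no ONNX Runtime implementation will still fail to load.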
Hey @grimoire, thanks for your quick reply. I do find the library. Any clue about the following warning when testing the ONNX model?
In fact, I just tried building with the official Dockerfile here. I ran conversion and evaluation inside the container and got the same error message:
Oh, sorry, we do not have an ONNX Runtime implementation of MMCVDeformConv2d.
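Since the problem is a missing operator implementation, it can help to list which non-standard operators an exported model actually contains (for example, `MMCVDeformConv2d` versus `MMCVModulatedDeformConv2d`). A minimal sketch: `custom_ops` is a hypothetical helper, and with the `onnx` package installed you would pass it `onnx.load(path).graph.node`. It only inspects top-level nodes, not subgraphs inside `If`/`Loop`.

```python
from collections import namedtuple

# Standard ONNX operator domains; any other domain is a custom op that the
# runtime has to get from a plugin library.
DEFAULT_DOMAINS = {"", "ai.onnx", "ai.onnx.ml"}


def custom_ops(nodes):
    """Return sorted (domain, op_type) pairs for nodes outside the default domains.

    `nodes` is any iterable of objects with `domain` and `op_type` attributes,
    such as `onnx.load(path).graph.node`.
    """
    return sorted({(n.domain, n.op_type) for n in nodes
                   if n.domain not in DEFAULT_DOMAINS})


if __name__ == "__main__":
    # Stand-in nodes so the sketch runs without the onnx package;
    # the "custom" domain name is illustrative only.
    Node = namedtuple("Node", "domain op_type")
    demo = [Node("", "Conv"), Node("custom", "MMCVDeformConv2d"), Node("", "Relu")]
    print(custom_ops(demo))
```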
Thanks for the clarification, @grimoire. I can try the implementation in MMCV as you suggested. The TensorRT backend would also work for us; in fact, it was the first backend I tried. When I tried it, I hit the following error during onnx2tensorrt, which is why I switched to the ONNX Runtime backend.
Do you know what the root cause might be? Here's the full log:
Could you provide more detail about this conversion?
@grimoire Sure. I trained a RepPoints model with a ResNet50 backbone using MMDetection, and I was trying to convert it to a TensorRT model on an NVIDIA V100 GPU for deployment. I did the conversion by following the example here:
and encountered this error message:
My hardware and environment info is in the issue description. Is there any other information you need?
I have just added support for RepPoints in #457.
I just tested your PR on my machine and it worked perfectly :) The conversion succeeded and the final metrics look good! Thanks, @grimoire! One quick question: is there any suggestion or guidance on how to pick the optimal
The input shape should be within the range given in the config (min <= [your tensor size] <= max). It is always a good idea to use a static config, since static shapes can be optimized better.
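A minimal sketch of that range check. The `input_shapes` dict below is hypothetical, written in the `min_shape`/`opt_shape`/`max_shape` layout that MMDeploy's TensorRT deploy configs use under `backend_config.model_inputs` (values illustrative); `shape_in_range` is a hypothetical helper applying the per-dimension min <= dim <= max rule of a TensorRT optimization profile.

```python
# Hypothetical dynamic-shape entry in the layout of MMDeploy TensorRT
# deploy configs (values illustrative).
input_shapes = dict(
    input=dict(
        min_shape=[1, 3, 320, 320],
        opt_shape=[1, 3, 800, 1344],
        max_shape=[1, 3, 1344, 1344]))


def shape_in_range(shape, min_shape, max_shape):
    """True if shapes have the same rank and min <= dim <= max for every dim."""
    if not (len(shape) == len(min_shape) == len(max_shape)):
        return False
    return all(lo <= d <= hi for d, lo, hi in zip(shape, min_shape, max_shape))


if __name__ == "__main__":
    spec = input_shapes["input"]
    for shape in ([1, 3, 800, 1344], [1, 3, 1400, 800]):
        ok = shape_in_range(shape, spec["min_shape"], spec["max_shape"])
        print(shape, "ok" if ok else "outside the profile range")
```

With a static config, `min_shape`, `opt_shape` and `max_shape` are identical, so exactly one input shape passes; that single shape is what the engine is tuned for.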
@Ivan-Zhou Hi,
I ran into the same issue. Any solution? Thanks.
I faced the same issue. Could anyone please share the solution?
My goal is to convert a PyTorch model trained with MMDetection to ONNX. I first followed the guide to build the toolchains, dependencies, and MMDeploy in my Docker environment. After that, I successfully converted the PyTorch model to an ONNX model through deploy/tools/test.py. Then, when I ran evaluation following this page, it threw an error that the operation MMCVDeformConv2d is not registered. According to here, my understanding is that MMCVModulatedDeformConv2d is supported as a custom ONNX operation in MMDeploy, so I'm posting the issue here to check whether my setup deviates from the instructions. Here's the script I ran for the torch2onnx conversion:
Here's the script I ran to evaluate the ONNX model:
Here's my environment:
Log of converting to ONNX:
Log of evaluating the ONNX model: