How to save inference onnx model? #14185
Comments
Did you call model.eval()? See https://pytorch.org/tutorials/beginner/saving_loading_models.html?highlight=eval
Thanks for your reply, I set the torch model to inference mode before eval_step:
I'm not familiar with ORTTrainer. Apparently it's deprecated though. @baijumeswani is there documentation on how to a) convert to using ORTModule and b) export the inference model from that?
@skottmckay thanks for tagging me. @ArtyZe I would recommend trying out ORTModule. Here is a simple example of how you can use it:

```python
import torch
from onnxruntime.training import ORTModule

# User defined original pytorch module
class MyModule(torch.nn.Module):
    ...

# The original pytorch Module instantiated
pt_model = MyModule(...)

# Module that uses the onnxruntime execution engine.
# It shares the model parameters with the original pytorch model,
# so training the ort_model is like training the pt_model, just
# leveraging the onnxruntime acceleration.
ort_model = ORTModule(pt_model)

# run training and eval as you would with the pt_model (using the same api)
...

# after training, the original pt_model has the updated weights; simply use
# torch.onnx.export to export the trained model for inference with
# onnxruntime. Use pt_model here since we want to export the original
# pytorch model.
torch.onnx.export(pt_model, training=torch.onnx.TrainingMode.EVAL, ...)

# use the exported model for inferencing
```

Hope the above snippet is helpful. Please also check out the linked example on how to use it.
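For reference, a more complete, runnable version of the sketch above (not part of the original comment; it assumes an onnxruntime-training install that provides ORTModule, and the ConvBNNet module, shapes, training loop, and file name are made up for illustration):

```python
import torch
from onnxruntime.training import ORTModule

# Toy module with a Conv + BatchNorm pair, matching the issue's scenario
class ConvBNNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return torch.relu(self.bn(self.conv(x)))

pt_model = ConvBNNet()
ort_model = ORTModule(pt_model)  # shares parameters with pt_model

optimizer = torch.optim.SGD(pt_model.parameters(), lr=0.01)
ort_model.train()
for _ in range(10):  # toy training loop on random data
    x = torch.randn(4, 3, 32, 32)
    loss = ort_model(x).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Switch to eval and export the original pytorch model; in EVAL mode the
# exporter emits inference-form BatchNormalization, which onnxruntime can
# later fuse with Conv.
pt_model.eval()
torch.onnx.export(
    pt_model,
    torch.randn(1, 3, 32, 32),
    "model_inference.onnx",
    training=torch.onnx.TrainingMode.EVAL,
    input_names=["input"],
    output_names=["output"],
)
```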
Great! Thanks for your reply @skottmckay @baijumeswani. error log: [attached screenshot not extracted]
@baijumeswani Hello, could you please give me an example of how to train an onnx model?
@ArtyZe do you mean an example where the starting point is an onnx model as opposed to a pytorch model?
@baijumeswani Yes. I have an untrained .onnx model converted from, say, tf or caffe, or an untrained .onnx model converted from PyTorch (but with some specific optimizations applied, like node split and node fusion), and now I need to retrain it in ort. ORTTrainer supported this before (but ORTTrainer does not support saving an inference onnx model :( ).
oh. Unfortunately, I am not sure how advanced your use case is. But we have python and C++ training APIs that can work with small models (https://github.com/microsoft/onnxruntime/tree/main/orttraining/orttraining/python/training/api). This API supports an onnx model as a starting point. However, it may not be robust enough to support very large model training. You can give it a try if you like. Apologies if this was not helpful enough.
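Not from the original thread, but for concreteness, here is a minimal sketch of that python training API as it appears in newer onnxruntime-training releases (roughly 1.15+, not the 1.8.1 this issue mentions). The model file, the choice to mark all initializers trainable, the input shapes, and the "output" graph output name are illustrative assumptions. Notably, Module.export_model_for_inferencing saves an inference-ready onnx model, which is what this issue asks for:

```python
import numpy as np
import onnx
from onnxruntime.training import artifacts
from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Starting point: an (untrained) onnx model, e.g. converted from tf/caffe/pytorch
onnx_model = onnx.load("model.onnx")

# Generate training artifacts: training/eval/optimizer graphs plus a checkpoint.
# For illustration, all initializers are marked trainable.
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=[p.name for p in onnx_model.graph.initializer],
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="artifacts",
)

state = CheckpointState.load_checkpoint("artifacts/checkpoint")
module = Module("artifacts/training_model.onnx", state, "artifacts/eval_model.onnx")
optimizer = Optimizer("artifacts/optimizer_model.onnx", module)

module.train()
for inputs, labels in [(np.random.randn(4, 784).astype(np.float32),
                        np.random.randint(0, 10, 4).astype(np.int64))]:
    loss = module(inputs, labels)  # forward + backward on the training graph
    optimizer.step()
    module.lazy_reset_grad()

# Save an inference-ready onnx model (graph output names listed explicitly;
# "output" here is a hypothetical name)
module.export_model_for_inferencing("inference_model.onnx", ["output"])
```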
That's really a lot of incomplete modules... But what I suggest is: if ort wants to make orttraining more popular, combining ORTTrainer's graph support (onnx and Pytorch models) with ORTModule's training api (Pytorch's optimizer, loss func, train(), eval()) is the best solution. Accessibility is more important than training speed for orttraining right now 💯. I can't even find a working orttraining module for a task that isn't complex.
Describe the issue
Now I can build my own training session from a torch net, but when I save the onnx model after training, BatchNormalization is in training mode and cannot be fused into Conv. What should I do to save an inference model?

current format: [screenshot: BatchNormalization in training mode]
expect format: [screenshot: BatchNormalization fused into Conv]
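As an aside (not from the thread): once the model has been exported in EVAL mode, onnxruntime itself can produce the expected fused format by running its basic graph optimizations, which include Conv + BatchNormalization fusion, and writing the optimized graph back to disk. A minimal sketch; the file names are illustrative:

```python
import onnxruntime as ort

so = ort.SessionOptions()
# BASIC level includes constant folding and Conv/BN fusion on inference graphs
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
# Writing the optimized graph back to disk yields the fused model
so.optimized_model_filepath = "model_optimized.onnx"

# Creating the session triggers optimization and saves the optimized model
ort.InferenceSession("model_inference.onnx", so, providers=["CPUExecutionProvider"])
```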
To reproduce
Urgency
No response
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
1.8.1
PyTorch Version
3.7
Execution Provider
CUDA
Execution Provider Library Version
No response