How to save inference onnx model? #14185

Closed
ArtyZe opened this issue Jan 9, 2023 · 10 comments
Labels
ep:CUDA (issues related to the CUDA execution provider), training (issues related to ONNX Runtime training; typically submitted using template)

Comments

@ArtyZe

ArtyZe commented Jan 9, 2023

Describe the issue

Now I can build my own training session from a torch net, but when I save the onnx model after training, BatchNormalization is in training mode and cannot be fused into Conv. What should I do to save an inference model?
Current format:

[screenshot of the exported model, in training mode]

Expected format:

[screenshot of the expected inference-mode model]

To reproduce

[screenshot of the reproduction code]

Urgency

No response

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

1.8.1

PyTorch Version

3.7

Execution Provider

CUDA

Execution Provider Library Version

No response

@ArtyZe added the training label Jan 9, 2023
@github-actions bot added the ep:CUDA label Jan 9, 2023
@skottmckay
Contributor

Did you call eval() on the pytorch model to set it to inferencing mode before exporting?

https://pytorch.org/tutorials/beginner/saving_loading_models.html?highlight=eval

> Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference.
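
For context, a minimal sketch of what that looks like (the mobilenet model, dummy input, and file name here are just illustrative stand-ins):

import torch
import torchvision

# Any trained torch model; mobilenet_v2 is only a placeholder here.
model = torchvision.models.mobilenet_v2()

# Switch Dropout/BatchNorm to inference behavior before exporting.
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "model_inference.onnx",
    input_names=["input"],
    output_names=["output"],
    # Explicitly export in eval mode so BatchNormalization is emitted
    # with training_mode=0 and can later be fused into Conv.
    training=torch.onnx.TrainingMode.EVAL,
)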

@ArtyZe
Author

ArtyZe commented Jan 16, 2023

> Did you call eval() on the pytorch model to set it to inferencing mode before exporting? [...]

Thanks for your reply. I set the torch model to inference mode before eval_step:

[screenshot of the eval-mode setup code]

but I still got the same training-mode onnx model:

[screenshot of the exported model]

Is there something I missed?

@skottmckay
Contributor

I'm not familiar with ORTTrainer. Apparently it's deprecated, though.

@baijumeswani is there documentation on how to a) convert to using ORTModule and b) export the inferencing model from that?

@baijumeswani
Contributor

@skottmckay thanks for tagging me.

@ArtyZe I would recommend trying out ORTModule. It should be very simple and straightforward to use.

Here is a simple example of how you can use it:

import torch
from onnxruntime.training import ORTModule

# User-defined original pytorch module
class MyModule(torch.nn.Module):
    ...

# The original pytorch Module, instantiated
pt_model = MyModule(...)

# Module that uses the onnxruntime execution engine.
# It shares its model parameters with the original pytorch model,
# so training ort_model is like training pt_model, just leveraging
# onnxruntime acceleration.
ort_model = ORTModule(pt_model)

# Run training and eval as you would with pt_model (using the same api).
...

# After training, the original pt_model will have the updated weights.
# You can simply use torch.onnx.export to export the trained model for
# inference with onnxruntime. Use pt_model here, since we want to export
# the original pytorch model.
torch.onnx.export(pt_model, training=torch.onnx.TrainingMode.EVAL, ...)

# Use the exported model for inferencing.

Hope my above snippet is helpful. Please also check out this example on how to use ORTModule: https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py
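
To make the flow above concrete, here is a minimal end-to-end sketch (the tiny network, random data, and output file name are illustrative placeholders):

import torch
from onnxruntime.training import ORTModule

# Placeholder network; substitute your real model.
pt_model = torch.nn.Sequential(
    torch.nn.Linear(10, 32),
    torch.nn.BatchNorm1d(32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 2),
)
ort_model = ORTModule(pt_model)

optimizer = torch.optim.SGD(ort_model.parameters(), lr=1e-2)
loss_fn = torch.nn.CrossEntropyLoss()

ort_model.train()
for _ in range(10):
    x = torch.randn(8, 10)              # random stand-in data
    y = torch.randint(0, 2, (8,))
    optimizer.zero_grad()
    loss = loss_fn(ort_model(x), y)
    loss.backward()                     # runs through onnxruntime
    optimizer.step()

# Export the original pt_model (it shares the trained weights) in eval
# mode so BatchNormalization carries training_mode=0.
pt_model.eval()
torch.onnx.export(
    pt_model,
    torch.randn(1, 10),
    "trained_inference.onnx",
    training=torch.onnx.TrainingMode.EVAL,
)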

@ArtyZe
Author

ArtyZe commented Jan 17, 2023

> @ArtyZe I would recommend trying out ORTModule. It should be very simple and straightforward to use. [...]

Great! Thanks for your reply @skottmckay @baijumeswani.
Now I can train mobilenet with ORTModule successfully. Unfortunately, I hit a gradient inference error for torchvision.efficientnet.efficientnet_b2.

error log:

    loss.backward()
  File "/workspace/ygao/anaconda3/envs/art37/lib/python3.7/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/workspace/ygao/anaconda3/envs/art37/lib/python3.7/site-packages/torch/autograd/__init__.py", line 175, in backward
    allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
  File "/workspace/ygao/anaconda3/envs/art37/lib/python3.7/site-packages/torch/autograd/function.py", line 253, in apply
    return user_fn(self, *args)
  File "/workspace/ygao/anaconda3/envs/art37/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 158, in backward
    self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
  File "/workspace/ygao/anaconda3/envs/art37/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_execution_agent.py", line 163, in run_backward
    self._training_agent.run_backward(feeds, fetches, state)
RuntimeError: Error in backward pass execution: Non-zero status code returned while running SigmoidGrad node. Name:'Sigmoid_751_Grad/SigmoidGrad_0' Status Message: Sigmoid_751_Grad/SigmoidGrad_0: mismatching input shapes: {1,2112,7,7} != {1,2112,1,1}

The failing node (gradient inference from Mul fails):

[screenshot of the Mul node in the graph]

@ArtyZe ArtyZe closed this as completed Jan 30, 2023
@ArtyZe ArtyZe reopened this Jan 31, 2023
@ArtyZe
Author

ArtyZe commented Jan 31, 2023

@baijumeswani hello, could you please give me an example of how to train an onnx model?

@baijumeswani
Contributor

@ArtyZe do you mean an example where the starting point is an onnx model as opposed to a pytorch model?
Currently, ORTModule only supports scenarios where the user begins with a PyTorch model. Could you share what you're trying to accomplish?

@ArtyZe
Author

ArtyZe commented Feb 1, 2023

> @ArtyZe do you mean an example where the starting point is an onnx model as opposed to a pytorch model? [...]

@baijumeswani Yes, I have an untrained .onnx model converted from, say, tf or caffe, or an untrained .onnx model converted from PyTorch (but with some specific optimizations applied, like node split and node fusion), and now I need to retrain it in ort, which ORTTrainer supported before (but ORTTrainer does not support saving an inference onnx model :( ).

@baijumeswani
Contributor

oh. ORTTrainer does support this kind of scenario. However, ORTTrainer was recently deprecated, since the emergence of ORTModule made it the less desirable trainer.

Unfortunately, ORTModule requires the starting point to be a pytorch model. It handles all onnx/onnxruntime related work behind the scenes and gives the appearance of being a torch.nn.Module. It cannot work with an onnx model as a starting point.

I am not sure how advanced your use case is, but we have python and C++ training apis that can work with small models (https://github.com/microsoft/onnxruntime/tree/main/orttraining/orttraining/python/training/api). This api supports an onnx model as a starting point. However, it may not be robust enough to support very large model training. You can give it a try if you like.
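
Roughly, training from an onnx starting point with those apis looks like the sketch below (the parameter names, file paths, and data_loader are hypothetical placeholders; check the linked repository for the exact signatures):

import onnx
from onnxruntime.training import artifacts
from onnxruntime.training.api import CheckpointState, Module, Optimizer

# 1. Generate training artifacts from an existing (untrained) onnx model.
base_model = onnx.load("untrained_model.onnx")
artifacts.generate_artifacts(
    base_model,
    requires_grad=["fc.weight", "fc.bias"],  # hypothetical parameter names
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="artifacts",
)

# 2. Train with the generated artifacts.
state = CheckpointState.load_checkpoint("artifacts/checkpoint")
module = Module("artifacts/training_model.onnx", state,
                "artifacts/eval_model.onnx")
optimizer = Optimizer("artifacts/optimizer_model.onnx", module)

module.train()
for inputs, labels in data_loader:  # user-supplied numpy batches
    loss = module(inputs, labels)
    optimizer.step()
    module.lazy_reset_grad()

# 3. Save a plain inference model once training is done.
module.export_model_for_inferencing(
    "inference_model.onnx", graph_output_names=["output"]
)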

Apologies if this was not helpful enough.

@ArtyZe
Author

ArtyZe commented Feb 8, 2023

> oh. ORTTrainer does support this kind of scenario. [...]

That's really a lot of incomplete modules... But what I'd suggest is: if ort wants to make orttraining more popular, combining ORTTrainer's graph support (both onnx and PyTorch models) with ORTModule's training api (PyTorch's optimizer, loss func, train(), eval()) would be the best solution. Accessibility is more important than training speed for orttraining right now 💯. I can't even find a working orttraining module for a task that isn't complex.

@baijumeswani closed this as not planned (won't fix, can't repro, duplicate, stale) Jan 26, 2024