ORT Training: Training mnist using provided sample? #3706
The recommended way to use the ORT training feature is through the ORT Python front end. You can use ORTTrainer to train either a PyTorch or an ONNX model. There is an MNIST example that trains a PyTorch model.
Model conversion and augmentation with gradient ops are handled by ORTTrainer and the ORT backend. If you already have an ONNX model, there is an example for that too (which also uses ORTTrainer); however, in this case the ONNX model needs to output the loss as well. Please note that this approach is currently used only for testing purposes. We will see whether there is a strong need to really support this use case.
In both cases, we'd like to make the training script as natural to PyTorch training as possible. Please give it a try and let us know.
Thank you for the response!
Our use case (training on embedded/edge devices) prevents us from using the Python API; instead we're limited to native binaries and therefore need to call into the C/C++ APIs directly. It seems like this should be possible given the existence of the training example in
Does this example require "the ONNX model.. to output loss as well" as you mentioned? If so, how would one go about adding the necessary gradient ops to construct such a model? I noticed there's also:
Is this meant to be used with
Adding on, I tried running the
on the Conv/Relu/Maxpool MNIST model builder provided at
This seems to throw an exception:
at the line
which calls into
For reference, the node type and input/output args that need grad are:
This seems to be because of the gradient builder in onnxruntime/orttraining/orttraining/core/graph/gradient_builder.cc (lines 492 to 498 at 9a4d1c7),
even though we only have 1 input/output def. Given that
I attached the MNIST test data and models; please run with the provided command line. As you can see, to work with the C++ API, you need to follow the mnist example to build the graph with a loss output so that the backprop graph can be constructed. mnist.zip Thanks
Is your feature request related to a problem? Please describe.
I see that, as part of the recently added ORT training support, there are some samples that use it to train e.g. an MNIST network:
onnxruntime/orttraining/orttraining/models/mnist/main.cc
Line 96 in 78fde2c
These take as input an onnx model "to be trained." Is there additional documentation on the format of the model to input? I'm assuming it needs to somehow be augmented with the gradient ops – is there a canonical way to export a training model from PyTorch in onnx format, or is there a provided tool to convert a standard inference onnx model to a training one suitable for use with this runner?
Describe the solution you'd like
Documentation on the type of input model that the ort trainer examples expect.