
ORT Training: Training mnist using provided sample? #3706

Closed

pranav-prakash opened this issue Apr 26, 2020 · 4 comments
Labels
training issues related to ONNX Runtime training; typically submitted using template

Comments

@pranav-prakash
Contributor

pranav-prakash commented Apr 26, 2020

Is your feature request related to a problem? Please describe.
I see that, as part of the recently added ORT training support, there are some samples that use it to train e.g. an MNIST network:

void setup_training_params(TrainingRunner::Parameters& params) {

These take as input an ONNX model "to be trained." Is there additional documentation on the expected format of this input model? I'm assuming it needs to be augmented with the gradient ops somehow. Is there a canonical way to export a training model from PyTorch in ONNX format, or is there a provided tool to convert a standard inference ONNX model into a training model suitable for use with this runner?

Describe the solution you'd like
Documentation on the type of input model that the ort trainer examples expect.

@pranav-prakash pranav-prakash changed the title ORT Training: Running the mnist model? ORT Training: Training the mnist model used provided example? Apr 26, 2020
@pranav-prakash pranav-prakash changed the title ORT Training: Training the mnist model used provided example? ORT Training: Training mnist used provided sample? Apr 26, 2020
@faxu faxu added the training issues related to ONNX Runtime training; typically submitted using template label Apr 28, 2020
@pranav-prakash pranav-prakash changed the title ORT Training: Training mnist used provided sample? ORT Training: Training mnist using provided sample? Apr 28, 2020
@liqunfu
Contributor

liqunfu commented Apr 28, 2020

The recommended way to use the ORT training feature is through the ORT Python front end. You can use ORTTrainer to train either a PyTorch or an ONNX model.

There is an MNIST example that trains a PyTorch model:

def testMNISTTrainingAndTesting(self):

Model conversion and augmentation with gradient ops are handled in ORTTrainer and the ORT backend.
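
For reference, here is a rough sketch of what driving ORTTrainer from Python could look like. It is based on the legacy onnxruntime.capi.ort_trainer module of this era; the toy model, the I/O names, and the exact constructor arguments are illustrative and may differ between versions:

import torch
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription

class MNISTNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

def loss_fn(pred, label):
    return torch.nn.functional.cross_entropy(pred, label)

# Describe the model's inputs/outputs so ORT can build the training graph.
# Names and shapes here are illustrative.
model_desc = ModelDescription(
    [IODescription("input1", ["batch", 784], torch.float32),
     IODescription("label", ["batch"], torch.int64)],
    [IODescription("loss", [], torch.float32),
     IODescription("probability", ["batch", 10], torch.float32)])

device = torch.device("cpu")
trainer = ORTTrainer(MNISTNet(), loss_fn, model_desc, "SGDOptimizer", None,
                     IODescription("Learning_Rate", [1], torch.float32), device)

# One training step; gradient-graph construction happens inside ORT.
data = torch.rand(64, 784)
target = torch.randint(0, 10, (64,))
loss = trainer.train_step(data, target, torch.tensor([0.01]))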

If you already have an ONNX model, there is an example too (which also uses ORTTrainer); however, in this case the ONNX model needs to output the loss as well. Please note that this approach is only used for testing purposes for now; we will see whether there is a strong enough need to fully support this use case.
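
For the ONNX-model path, the inference graph first has to be extended with a loss output. Here is a minimal sketch of one way to do that with the onnx helper API, mirroring the SoftmaxCrossEntropy/kMSDomain pair used in the C++ sample; the file names, the "predictions" output name, the label layout, and the opset version are assumptions to verify against your model:

import onnx
from onnx import helper, TensorProto

model = onnx.load("mnist_inference.onnx")  # hypothetical input model
graph = model.graph

# Append a loss node from ORT's com.microsoft training domain. Whether your
# build expects one-hot float labels or integer index labels is worth checking.
graph.node.append(helper.make_node(
    "SoftmaxCrossEntropy",
    inputs=["predictions", "labels"],  # "predictions" must match the model's real output name
    outputs=["loss"],
    domain="com.microsoft"))

graph.input.append(
    helper.make_tensor_value_info("labels", TensorProto.FLOAT, ["batch", 10]))
graph.output.append(
    helper.make_tensor_value_info("loss", TensorProto.FLOAT, []))

# Declare the com.microsoft opset so the contrib op resolves.
model.opset_import.append(helper.make_opsetid("com.microsoft", 1))
onnx.save(model, "mnist_training.onnx")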

In both cases, we would like to keep the training script as natural to PyTorch training as possible. Please give it a try and let us know.
Thanks

@pranav-prakash
Contributor Author

Thank you for the response!

"The recommended way to use the ORT training feature is through the ORT Python front end"

Our use case (training on embedded/edge devices) prevents us from using the Python API; instead we're limited to native binaries and therefore need to call into the C/C++ APIs directly. It seems like this should be possible given the existence of the training example in

("model_name", "model to be trained", cxxopts::value<std::string>())

Does this example require "the ONNX model … to output loss as well," as you mentioned? If so, how would one go about adding the necessary gradient ops to construct such a model? I noticed there's also:

TERMINATE_IF_FAILED(training_session.AddLossFuncion({"SoftmaxCrossEntropy", PREDICTION_NAME, "labels", "loss", kMSDomain}));

Is this meant to be used with orttraining/orttraining/models/mnist/main.cc?

@pranav-prakash
Contributor Author

pranav-prakash commented Apr 28, 2020

Adding on, I tried running the onnxruntime_training_mnist example binary (the code below)

("model_name", "model to be trained", cxxopts::value<std::string>())

on the Conv/Relu/MaxPool MNIST model built by the notebook provided at orttraining/tools/mnist_model_builder/mnist_conv_builder.ipynb.

This seems to throw an exception:

terminate called after throwing an instance of 'onnxruntime::OnnxRuntimeException'
  what():  /home/onnxruntime/onnxruntime/orttraining/orttraining/core/graph/gradient_builder_base.h:63 onnxruntime::training::ArgDef onnxruntime::training::GradientBuilderBase::O(size_t) const i < node_->OutputDefs().size() was false. 

at the line

GradientDef node_defs = GetGradientForOp(node, output_args_need_grad, input_args_need_grad);

which calls into

GradientDef GetGradientDefs() const {

For reference, the node type and input/output args that need grad are:

Node type MaxPool
Output args need grad:
         T3
Input args need grad:
         T2

This seems to be because the gradient builder for MaxPool is defined as:

IMPLEMENT_GRADIENT_BUILDER(GetMaxPoolGradient) {
  return std::vector<NodeDef>{
      NodeDef("MaxPoolGrad",
              {GO(0), O(1)},
              {GI(0)},
              SrcNodeAttributes())};
}

even though our node has only one output def. Given that GetAveragePoolGradient doesn't reference O(1), I'm not sure whether this is an issue with my model or a possible bug. (The ONNX operator spec marks the second MaxPool output, Indices, as optional.)
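
If the missing optional Indices output is indeed the trigger, one possible workaround when building the model (e.g. in mnist_conv_builder.ipynb) would be to declare that second output explicitly so the gradient builder's O(1) reference resolves. A sketch with the onnx helper API; the tensor names follow the log above, and kernel_shape/strides are placeholders:

from onnx import helper

# Declare MaxPool with its optional second output (Indices) so that the
# gradient builder's O(1) lookup has something to point at.
maxpool = helper.make_node(
    "MaxPool",
    inputs=["T2"],
    outputs=["T3", "T3_indices"],  # second output is the optional Indices tensor
    kernel_shape=[2, 2],           # placeholder values
    strides=[2, 2])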

@liqunfu
Contributor

liqunfu commented Apr 29, 2020

I've attached the MNIST test data and models. Please run with the provided command line. As you can see, to work with the C++ API, you need to follow the MNIST example and build the graph with a loss output so that the backprop graph can be constructed.

mnist.zip
./onnxruntime_training_mnist --model_name ~/mnist/mnist_gemm_simple --train_data_dir ~/mnist/mnist_data/

Thanks
