Copyright (C) 2022, Microsoft.

## In this notebook we'll build a simple fully connected neural network and train it with ONNX Runtime (ORT) to recognize handwritten digits from the MNIST dataset.

### What are the ONNX Runtime Training APIs?
The ONNX Runtime Training APIs let you run end-to-end training loops using only onnxruntime. You still need PyTorch to generate the training, eval and optimizer graphs first, but after that everything depends on onnxruntime alone.

These APIs were introduced mainly for on-device training.

#### Importing libraries

Make sure to install the nightly version of onnxruntime-training:

```pip3 install onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_cu116.html```

In [1]:
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime import InferenceSession
from torchvision import datasets, transforms
import numpy as np
import torch
import onnx
import io
import netron
import evaluate

  from .autonotebook import tqdm as notebook_tqdm


## Graph Generation
To run your training loop, first you need to generate training, eval (optional) and optimizer graphs.

In [60]:
# PyTorch model used to generate the ONNX graphs: a two-layer MLP
# (Linear -> ReLU -> Linear).
class SimpleNet(torch.nn.Module):
    """Fully connected classifier: input -> hidden layer (ReLU) -> class logits.

    Args:
        input_size: number of features per example (784 for flattened MNIST here).
        hidden_size: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Attribute names determine the parameter names (fc1.weight, fc1.bias, ...),
        # which the exported graph presumably reuses for its initializers.
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, model_input):
        """Return unnormalized class scores for a batch of flattened inputs."""
        hidden = self.relu(self.fc1(model_input))
        return self.fc2(hidden)

# Create a SimpleNet instance. Seed torch first so the initial weights are
# reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
device = "cpu"
batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10
pt_model = SimpleNet(input_size, hidden_size, output_size).to(device)

### Generating the forward-only graph

In [61]:
# Run the model once on a random batch; the tensors are only used as
# example inputs for tracing the export.
model_inputs = (torch.randn(batch_size, input_size, device=device),)

model_outputs = pt_model(*model_inputs)
if isinstance(model_outputs, torch.Tensor):
    # Wrap a single output in a list so inputs and outputs are handled alike.
    model_outputs = [model_outputs]

input_names = [f"input-{idx}" for idx in range(len(model_inputs))]
output_names = [f"output-{idx}" for idx in range(len(model_outputs))]

# Mark every dimension of every input/output as dynamic (named "<tensor>_dim<k>"),
# so the batch size used here is not baked into the exported graph.
dynamic_axes = {}
for tensor_name, tensor in zip(input_names + output_names, [*model_inputs, *model_outputs]):
    dynamic_axes[tensor_name] = {axis: f"{tensor_name}_dim{axis}" for axis in range(len(tensor.shape))}

# Export into an in-memory buffer, then parse it back as an onnx ModelProto.
buffer = io.BytesIO()
torch.onnx.export(
    pt_model,
    model_inputs,
    buffer,
    input_names=input_names,
    output_names=output_names,
    opset_version=14,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    dynamic_axes=dynamic_axes,
    export_params=True,
    keep_initializers_as_inputs=False,
)
onnx_model = onnx.load_model_from_string(buffer.getvalue())

##### After creating the forward-only graph, we can create the training graph.

The first step is to create a simple class that inherits from `onnxblock.TrainingModel` and defines the loss function.
Its `build` function defines the outputs of the model: here, the loss and the original model output.

In [62]:
# Wrap the exported forward graph with a cross-entropy loss.
class SimpleModelWithCrossEntropyLoss(onnxblock.TrainingModel):
    """onnxblock training model that appends CrossEntropyLoss to the forward graph.

    build() returns both the loss and the original model output, so the generated
    graphs yield (loss, logits) when run.
    """

    def __init__(self):
        super().__init__()
        self.loss = onnxblock.loss.CrossEntropyLoss()

    def build(self, output_name):
        """Return (loss output, model output) for the graph output named ``output_name``."""
        loss_output = self.loss(output_name)
        return loss_output, output_name

In [63]:
# Build the ONNX model with loss.
simple_model = SimpleModelWithCrossEntropyLoss()

# Build the training graph and the eval graph. Inside the context, calling
# simple_model appears to add the loss to onnx_model in place (onnx_model is
# what gets saved as the training model), and the accessor exposes the eval graph.
with onnxblock.onnx_model(onnx_model) as accessor:
    _ = simple_model(onnx_model.graph.output[0].name)
    eval_model = accessor.eval_model

# Build the optimizer graph. Use a distinct name for the onnxblock AdamW builder:
# `optimizer` is later bound to the onnxruntime Optimizer whose .step() train()
# calls, so reusing that name here would silently break training whenever this
# cell is re-run after the Optimizer has been created.
adamw_block = onnxblock.optim.AdamW()
with onnxblock.onnx_model() as accessor:
    _ = adamw_block(simple_model.parameters())
    optimizer_model = accessor.model

2022-11-09 20:32:20.394830072 [I:onnxruntime:Default, reshape_fusion.cc:42 ApplyImpl] Total fused reshape node count: 0
2022-11-09 20:32:20.394877470 [I:onnxruntime:Default, concat_slice_elimination.cc:36 ApplyImpl] Total fused concat node count: 0


In [64]:
# Save all the files needed later for training: the checkpoint (parameters)
# plus the training, optimizer and eval graphs.
import os

# Create the output directory up front so a fresh checkout does not fail
# when onnx.save writes into "data/".
os.makedirs("data", exist_ok=True)

trainable_params, non_trainable_params = simple_model.parameters()
onnxblock.save_checkpoint((trainable_params, non_trainable_params), "data/checkpoint.ckpt")
onnx.save(onnx_model, "data/training_model.onnx")
onnx.save(optimizer_model, "data/optimizer.onnx")
onnx.save(eval_model, "data/eval_model.onnx")

2022-11-09 20:32:21.143151087 [W:onnxruntime:Default, checkpoint.cc:187 OrtSaveInternal] Checkpoint directory exists - data may be overwritten.


## You can use Netron to visualize the graphs.
This is an example of what an eval graph looks like:

In [65]:
# Launch a local Netron server for the saved eval graph; open the URL it
# prints in a browser to inspect the graph.
netron.start("data/eval_model.onnx")

Serving 'data/eval_model.onnx' at http://localhost:8081


('localhost', 8081)

![](https://i.imgur.com/C7W7cw2.png)

#### After generating the files required for training, let's start training our model.
##### Let's load the datasets from torchvision.

In [66]:
# MNIST loaders: training batches of 64, evaluation batches of 1000.
batch_size = 64
# Shuffle the training set so each epoch sees batches in a different order.
train_kwargs = {'batch_size': batch_size, 'shuffle': True}
test_batch_size = 1000
test_kwargs = {'batch_size': test_batch_size}

# Convert images to tensors, then normalize with the commonly used MNIST
# mean (0.1307) and std (0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])

train_dataset = datasets.MNIST('../data', train=True, download=True,
                               transform=transform)
# download=True keeps this working even if the test split is fetched separately.
test_dataset = datasets.MNIST('../data', train=False, download=True,
                              transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)


#### Use the saved files to initialize the state, model and optimizer.

In [67]:
# Artifacts written by the graph-generation cells.
checkpoint_path = "data/checkpoint.ckpt"
training_model_path = "data/training_model.onnx"
eval_model_path = "data/eval_model.onnx"
optimizer_model_path = "data/optimizer.onnx"

# Training state (parameter values) restored from the checkpoint.
state = CheckpointState(checkpoint_path)

# Module wraps the training graph together with the eval graph used after model.eval().
model = Module(training_model_path, state, eval_model_path)

# Optimizer wraps the optimizer graph; train() calls optimizer.step() after each batch.
optimizer = Optimizer(optimizer_model_path, model)

#### Defining the training and testing loops

In [76]:
# Helper: turn a batch of logits into predicted class indices.
def get_pred(logits):
    """Return the index of the largest logit in each row (the predicted class)."""
    predicted = np.argmax(logits, axis=1)
    return predicted

# Training loop: one pass over train_loader, printing the mean training loss.
def train(epoch):
    """Train the global ORT ``model`` for one epoch over ``train_loader``.

    Each batch is flattened to a (batch, 784) array and paired with int32
    labels, then run through the training Module. Each batch is followed by an
    optimizer step and a gradient reset.

    Args:
        epoch: zero-based epoch index; printed as ``epoch + 1``.
    """
    model.train()
    batch_losses = []
    for images, labels in train_loader:
        inputs = images.reshape(len(images), 784).numpy()
        targets = labels.numpy().astype(np.int32)
        loss, _logits = model([inputs, targets])
        optimizer.step()
        # Clear gradients before the next batch (presumably they accumulate otherwise).
        model.reset_grad()
        batch_losses.append(loss)

    mean_loss = sum(batch_losses) / len(batch_losses)
    print('Epoch: {},Train Loss: {:.4f}'.format(epoch+1, mean_loss))

# Test Loop :
def test(epoch):
    model.eval()
    losses = []
    metric = evaluate.load('accuracy')

    for batch_idx, (data, target) in enumerate(train_loader):
        forward_inputs = [data.reshape(len(data),784).numpy(),target.numpy().astype(np.int32)]
        test_loss, logits = model(forward_inputs)
        metric.add_batch(references=target, predictions=get_pred(logits))
        losses.append(test_loss)

    metrics = metric.compute()
    
    print('Epoch: {}, Test Loss: {:.4f}, Accuracy : {:.2f}'.format(epoch+1, sum(losses)/len(losses), metrics['accuracy']*100))

In [78]:
# Alternate one training epoch with one evaluation pass.
NUM_EPOCHS = 5
for epoch in range(NUM_EPOCHS):
    train(epoch)
    test(epoch)

Epoch: 1,Train Loss: 0.2199
Epoch: 1, Test Loss: 0.1269, Accuracy : 95.92
Epoch: 2,Train Loss: 0.0901
Epoch: 2, Test Loss: 0.0859, Accuracy : 97.03
Epoch: 3,Train Loss: 0.0565
Epoch: 3, Test Loss: 0.0596, Accuracy : 97.92
Epoch: 4,Train Loss: 0.0372
Epoch: 4, Test Loss: 0.0373, Accuracy : 98.75
Epoch: 5,Train Loss: 0.0281
Epoch: 5, Test Loss: 0.0539, Accuracy : 98.29


#### Save the inference model and run inference

In [79]:
# Export an inference-only graph keeping just the logits output ("output-0"),
# then load it in a CPU InferenceSession.
inference_model_path = 'data/inference_model.onnx'
model.export_model_for_inferencing(inference_model_path, ["output-0"])
session = InferenceSession(inference_model_path, providers=['CPUExecutionProvider'])

In [80]:
# getting one example from test list to try inference.
data = next(iter(test_loader))[0][0].reshape(1,784).numpy()

input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name 
output = session.run([output_name], {input_name: data})

print("Predicted Label : ",get_pred(output[0]))
    

Predicted Label :  [7]
