In this notebook we'll build a simple neural network with the ONNX Runtime (ORT) training APIs and train it to recognize handwritten digits from the MNIST dataset.

This tutorial has two phases:

1. Offline Phase - Preparing training artifacts that will be consumed in the training phase.
2. Training Phase - Train the model on the device.


#### Importing libraries

Make sure to install onnxruntime-training's nightly version. Check [our website (click Optimize Training)](https://onnxruntime.ai/getting-started) for instructions on downloading the onnxruntime-training nightly Python package.

In [None]:
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training import artifacts
from onnxruntime import InferenceSession
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import onnx
import io
import netron
import evaluate

## 1 - Offline Step

To run your training loop, first you need to generate training, eval (optional) and optimizer graphs.

We expect users to start from a forward-only ONNX model. You can generate this model in different ways; in this example we use `torch.onnx.export`.

In [None]:
# PyTorch model that we export to ONNX and then turn into training graphs.
class MNISTNet(torch.nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: number of input features per sample.
        hidden_size: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Keep these attribute names: they determine the names returned by
        # named_parameters() (e.g. "fc1.weight"), which the artifact cell uses.
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, model_input):
        hidden = self.relu(self.fc1(model_input))
        return self.fc2(hidden)

# Hyperparameters and the CPU model instance that will be exported.
device = "cpu"
batch_size = 64
input_size = 784   # one flattened MNIST image per row
hidden_size = 500
output_size = 10   # passed as MNISTNet's num_classes
pt_model = MNISTNet(input_size, hidden_size, output_size)
pt_model = pt_model.to(device)

### Generating forward only graph.

In [None]:
# A random dummy batch, used only to trace the model during export.
model_inputs = (torch.randn(batch_size, input_size, device=device),)

# Run one forward pass and normalize the result to a list of output tensors.
model_outputs = pt_model(*model_inputs)
if isinstance(model_outputs, torch.Tensor):
    model_outputs = [model_outputs]

input_names = ["input"]
output_names = ["output"]
# Let the first (batch) dimension of every input and output vary at runtime.
dynamic_axes = {name: {0: "batch_size"} for name in input_names + output_names}

# Export into an in-memory buffer, then parse those bytes back into a ModelProto.
f = io.BytesIO()
torch.onnx.export(
    pt_model,
    model_inputs,
    f,
    input_names=input_names,
    output_names=output_names,
    opset_version=14,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    dynamic_axes=dynamic_axes,
    export_params=True,
    keep_initializers_as_inputs=False,
)
serialized_model = f.getvalue()
onnx_model = onnx.load_model_from_string(serialized_model)

##### After creating the forward-only graph, we can create the training graph. There are two ways to do this.

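**Method 1:** 

Use the `generate_artifacts` function from `onnxruntime.training.artifacts`. It automatically generates the training artifacts (training, eval and optimizer graphs and a checkpoint) from the forward-only graph, the specified loss function and the optimizer. Both methods write their artifacts into the `data` directory, so you only need to run one of them.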

In [None]:
# Partition the PyTorch parameter names by whether they are trainable.
requires_grad = []
frozen_params = []
for name, param in pt_model.named_parameters():
    if param.requires_grad:
        requires_grad.append(name)
    else:
        frozen_params.append(name)

# Write the training artifacts into "data", using AdamW with a cross-entropy
# loss, and keep the model's "output" as an extra graph output next to the loss.
artifacts.generate_artifacts(
    onnx_model,
    optimizer=artifacts.OptimType.AdamW,
    loss=artifacts.LossType.CrossEntropyLoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory="data",
    additional_output_names=["output"],
)

**Method 2:** 

The first step is to create a simple class that inherits from `onnxblock.TrainingBlock` and defines the loss function.
The `build` function defines the outputs of our model: here, the loss together with the original model output.

In [None]:
# onnxblock training block that attaches a cross-entropy loss to the model output.
class MNISTTrainingBlock(onnxblock.TrainingBlock):
    """Training block adding a CrossEntropyLoss on top of the forward graph.

    ``build`` returns the loss output together with the unchanged model
    output name, so both remain outputs of the generated graphs.
    """

    def __init__(self):
        super().__init__()
        self.loss = onnxblock.loss.CrossEntropyLoss()

    def build(self, output_name):
        loss_output = self.loss(output_name)
        return loss_output, output_name

In [None]:
# Build the onnx model with loss.
# Mark every initializer of the exported forward graph as trainable; each name
# is printed so you can see which parameters will receive gradients.
training_block = MNISTTrainingBlock()
for param in onnx_model.graph.initializer:
    print(param.name)
    training_block.requires_grad(param.name, True)

# Building training graph and eval graph.
# NOTE(review): onnxblock.base(onnx_model) presumably installs onnx_model as the
# base graph the block's ops are appended to — confirm in the onnxblock docs.
model_params = None
with onnxblock.base(onnx_model):
    # Feed the forward graph's output names into the block (build() adds the loss).
    _ = training_block(*[output.name for output in onnx_model.graph.output])
    # Yields the training graph and the matching eval graph.
    training_model, eval_model = training_block.to_model_proto()
    # Parameter handles consumed by the optimizer block below.
    model_params = training_block.parameters()

# Building the optimizer graph
# The AdamW graph is built on an empty base rather than on the model graph;
# `accessor` is not used in this cell.
optimizer_block = onnxblock.optim.AdamW()
with onnxblock.empty_base() as accessor:
    _ = optimizer_block(model_params)
    optimizer_model = optimizer_block.to_model_proto()

In [None]:
# Saving all the files to use them later for the training.
import os

# onnx.save / save_checkpoint write into "data/" but do not create it; make sure
# the directory exists so this cell works regardless of which cells ran before.
os.makedirs("data", exist_ok=True)
onnxblock.save_checkpoint(training_block.parameters(), "data/checkpoint")
onnx.save(training_model, "data/training_model.onnx")
onnx.save(optimizer_model, "data/optimizer_model.onnx")
onnx.save(eval_model, "data/eval_model.onnx")

#### You can use netron to visualize the graphs.
This is an example of what an eval graph looks like:

In [None]:
# Launch the Netron viewer on the saved eval graph.
eval_graph_path = "data/eval_model.onnx"
netron.start(eval_graph_path)

![](graph.png)

## 2 - Data Preparation
We're going to use torchvision's `datasets` to load the MNIST dataset and then wrap it in a `DataLoader`.

In [None]:
# Batch sizes for the training and test loaders.
batch_size = 64
test_batch_size = 1000
train_kwargs = {'batch_size': batch_size}
test_kwargs = {'batch_size': test_batch_size}

# Convert images to tensors, then normalize; (0.1307,) / (0.3081,) are
# presumably the MNIST training-set mean / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('../data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

## 3 - Initialize Module and Optimizer
We will use the saved files to initialize the state, model and optimizer.
Note that the eval graph is optional.

In [None]:
# Artifacts produced by the offline step.
training_model_path = "data/training_model.onnx"
eval_model_path = "data/eval_model.onnx"
optimizer_model_path = "data/optimizer_model.onnx"

# Restore the parameter state, then build the training module (the eval graph is
# optional) and the optimizer that updates that module's parameters.
state = CheckpointState.load_checkpoint("data/checkpoint")
model = Module(training_model_path, state, eval_model_path)
optimizer = Optimizer(optimizer_model_path, model)

## 4 - Run Training and Testing Loops
In this step we will define training and testing loops.
The steps for training are simple:

1 - set the model to train mode: model.train()

2 - prepare the inputs, making sure all of them are numpy arrays

3 - pass the inputs to the model: model(*inputs)

4 - call optimizer.step()

5 - reset the gradients before the next step: model.lazy_reset_grad()


In [None]:
# Helper: turn a batch of logits into predicted class indices.
def get_pred(logits):
    """Return the index of the highest-scoring class for each row of *logits*."""
    predicted_classes = np.argmax(logits, axis=1)
    return predicted_classes

# Training Loop :
def train(epoch):
    """Run one pass over ``train_loader``, updating the weights after every batch.

    Uses the module-level ``model``, ``optimizer`` and ``train_loader``.
    *epoch* is 0-based; the printed epoch number is 1-based.
    """
    model.train()
    epoch_losses = []
    for images, labels in train_loader:
        # Flatten each image into a 784-vector and cast labels to int64.
        flat_images = images.reshape(len(images), 784).numpy()
        int_labels = labels.numpy().astype(np.int64)
        batch_loss, _ = model(flat_images, int_labels)
        optimizer.step()
        # Reset the gradients before the next batch.
        model.lazy_reset_grad()
        epoch_losses.append(batch_loss)

    mean_loss = sum(epoch_losses) / len(epoch_losses)
    print(f'Epoch: {epoch+1},Train Loss: {mean_loss:.4f}')

# Test Loop :
def test(epoch):
    """Evaluate the model on ``test_loader`` and print mean loss and accuracy.

    Uses the module-level ``model`` and ``test_loader``. *epoch* is 0-based;
    the printed epoch number is 1-based.
    """
    model.eval()
    losses = []
    metric = evaluate.load('accuracy')

    # Evaluate on the held-out test split, not on the training data.
    for data, target in test_loader:
        forward_inputs = [data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64)]
        test_loss, logits = model(*forward_inputs)
        # Pass references as numpy, matching the numpy predictions from get_pred.
        metric.add_batch(references=target.numpy(), predictions=get_pred(logits))
        losses.append(test_loss)

    metrics = metric.compute()
    print(f'Epoch: {epoch+1}, Test Loss: {sum(losses)/len(losses):.4f}, Accuracy : {metrics["accuracy"]:.2f}')
    


In [None]:
# Alternate one training pass and one evaluation pass per epoch.
num_epochs = 5
for epoch in range(num_epochs):
    train(epoch)
    test(epoch)

## 5 - Run Inference
In this step we export an inference-only model and use InferenceSession to run inference with it.

In [None]:
# Export an inference-only graph with "output" as its graph output, then load
# it into a plain CPU InferenceSession.
inference_model_path = 'data/inference_model.onnx'
model.export_model_for_inferencing(inference_model_path, ["output"])
session = InferenceSession(inference_model_path, providers=['CPUExecutionProvider'])

In [None]:
# Take the first image of the first test batch to try inference on.
first_batch_images, _ = next(iter(test_loader))
data = first_batch_images[0]

# Look up the graph's input/output names instead of hard-coding them.
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
flat_input = data.reshape(1, 784).numpy()
output = session.run([output_name], {input_name: flat_input})

# Plot the picture; data is presumably (1, 28, 28), so data[0] is the 2-D image.
plt.imshow(data[0], cmap='gray')
plt.show()

print("Predicted Label : ", get_pred(output[0]))
    