In this notebook, we will generate the offline training artifacts for the MNIST web demo.

#### Importing libraries

In [None]:
import io
import os

import netron
import onnx
import torch
from onnxruntime.training import artifacts

## 1 - Offline Step

To run your training loop, first you need to generate training, eval (optional) and optimizer graphs.

We expect users to start from a forward-only ONNX model. There are several ways to produce one; in this example we use `torch.onnx.export`.

In [None]:
# PyTorch module from which the ONNX graphs are generated.
class MNISTNet(torch.nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear."""

    def __init__(self, input_size, hidden_size, num_classes):
        """Build the layers.

        Args:
            input_size: Number of features in each input row.
            hidden_size: Width of the hidden layer.
            num_classes: Number of output values produced per input row.
        """
        super().__init__()

        # Attribute names determine the parameter names reported by
        # named_parameters() (e.g. "fc1.weight"), so they are kept stable.
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, model_input):
        """Return the output of fc2 applied to relu(fc1(model_input))."""
        hidden = self.relu(self.fc1(model_input))
        return self.fc2(hidden)

# Settings shared by the model and the export cell below.
device = "cpu"
batch_size = 64
input_size = 784  # presumably a flattened 28x28 MNIST image
hidden_size = 500
output_size = 10  # presumably one output per digit class

# Create a MNISTNet instance and place it on the chosen device.
pt_model = MNISTNet(input_size, hidden_size, output_size)
pt_model = pt_model.to(device)

### Generating forward only graph.

In [None]:
# Names assigned to the exported graph's input and output. The first axis of
# each is marked dynamic under the symbolic name "batch_size".
input_names = ["input"]
output_names = ["output"]
dynamic_axes = {name: {0: "batch_size"} for name in input_names + output_names}

# Generate a random input; it is passed to the exporter as the example input.
model_inputs = (torch.randn(batch_size, input_size, device=device),)

# Export into an in-memory buffer instead of a file on disk.
f = io.BytesIO()
torch.onnx.export(
    pt_model,
    model_inputs,
    f,
    export_params=True,
    keep_initializers_as_inputs=False,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    opset_version=14,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)

# Parse the serialized bytes back into an ONNX model object.
serialized_model = f.getvalue()
onnx_model = onnx.load_model_from_string(serialized_model)

##### After creating the forward-only graph, we can now create the training graph.


In [None]:
# Split the PyTorch parameter names by whether the parameter is trainable.
requires_grad = [name for name, param in pt_model.named_parameters() if param.requires_grad]

frozen_params = [name for name, param in pt_model.named_parameters() if not param.requires_grad]

# Write the artifacts under "data/", which is where the visualization cell
# looks for eval_model.onnx. Without an explicit artifact_directory the files
# are written to the current working directory instead.
artifact_directory = "data"
os.makedirs(artifact_directory, exist_ok=True)

# Generate the training artifacts from the forward-only model.
# additional_output_names keeps the model's own output available in the
# generated graphs alongside the loss.
artifacts.generate_artifacts(
    onnx_model,
    optimizer=artifacts.OptimType.AdamW,
    loss=artifacts.LossType.CrossEntropyLoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory=artifact_directory,
    additional_output_names=output_names,
)

#### You can use netron to visualize the graphs.
This is an example of what an eval graph looks like:

In [None]:
# Open the eval graph in Netron; this expects the artifacts to be under data/.
eval_model_path = "data/eval_model.onnx"
netron.start(eval_model_path)

![](graph.png)