## Offline Step - Generate the Training Artifacts

We start with a pre-trained PyTorch model and export it to ONNX. For this demo, we use the `MobileNetV2` image classification model, which has been pretrained on the ImageNet dataset with its 1000 categories.

For our image classification task, we only want to classify images into 4 classes, so we replace the last layer of the model with one that outputs 4 logits instead of 1000.
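
To get a feel for how small the new output layer is, the following sketch counts the parameters of a fully connected layer before and after the change. It assumes the 1280 input features that the `torch.nn.Linear(1280, 4)` replacement in this section uses; `linear_param_count` is a hypothetical helper written for this illustration, not part of PyTorch.

```python
def linear_param_count(in_features, out_features):
    """Number of weights plus biases in a fully connected layer."""
    return in_features * out_features + out_features


# Assumption: the classifier layer has 1280 input features.
original_head = linear_param_count(1280, 1000)  # 1000 ImageNet classes
new_head = linear_param_count(1280, 4)          # our 4 classes

print(original_head)  # 1281000
print(new_head)       # 5124
```

Only these few thousand parameters of the new layer start from random initialization; the rest of the network keeps its pretrained weights.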

In [None]:
import os

import torch
import torchvision

model = torchvision.models.mobilenet_v2(
    weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2)

# The original model is trained on ImageNet, which has 1000 classes.
# For our image classification scenario, we need to classify among 4 categories,
# so we replace the last layer of the model with one that has 4 outputs.
model.classifier[1] = torch.nn.Linear(1280, 4)

# torch.onnx.export does not create missing directories, so create it first.
os.makedirs("training_artifacts", exist_ok=True)

# Export the model to ONNX with a dynamic batch dimension.
model_name = "mobilenetv2"
torch.onnx.export(model, torch.randn(1, 3, 224, 224),
                  f"training_artifacts/{model_name}.onnx",
                  input_names=["input"], output_names=["output"],
                  dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})

Now that the MobileNetV2 model has been exported to ONNX, we need to generate the training artifacts:
  - The training ONNX model, built as `gradient(optimize(stack(inference ONNX model, loss node)))`.
  - The eval ONNX model, built as `optimize(stack(inference ONNX model, loss node))`.
  - The optimizer ONNX model: a new ONNX model that takes the model parameters as input and updates them based on their gradients.
  - The model parameter checkpoint file: the extracted and serialized model parameters.
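
To make the role of the optimizer model more concrete, here is a minimal sketch of a single AdamW update for one scalar parameter in plain Python. It only illustrates the kind of computation such a model performs; it is not the graph that ONNX Runtime generates, and the hyperparameter defaults and the `adamw_step` helper are assumptions chosen for this illustration.

```python
import math


def adamw_step(param, grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """Return (param, m, v) after one AdamW update of a scalar parameter."""
    # Update the running first and second moment estimates of the gradient.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    # Correct the bias caused by initializing both moments at zero.
    m_hat = m / (1.0 - beta1 ** step)
    v_hat = v / (1.0 - beta2 ** step)
    # Decoupled weight decay: shrink the parameter directly, not via the gradient.
    param = param * (1.0 - lr * weight_decay)
    # Adam step scaled by the moment estimates.
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# Toy usage: a few steps on the loss weight**2, whose gradient is 2 * weight.
weight, m, v = 1.0, 0.0, 0.0
for step in range(1, 4):
    grad = 2.0 * weight
    weight, m, v = adamw_step(weight, grad, m, v, step)
print(weight)
```

The parameter and both moment estimates go in, and their updated values come out, which mirrors the description of the optimizer model taking parameters as input and updating them from their gradients.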

For this task, we use the `artifacts` utility from the ONNX Runtime training Python package (`onnxruntime.training`).

In [None]:
import onnx
from onnxruntime.training import artifacts

# Load the exported ONNX model.
onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx")

# Only the new classifier layer is trained; every other parameter is frozen.
requires_grad = ["classifier.1.weight", "classifier.1.bias"]
frozen_params = [
    param.name
    for param in onnx_model.graph.initializer
    if param.name not in requires_grad
]

# Generate the training artifacts.
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts"
)

All training artifacts have been saved to the folder [training_artifacts](training_artifacts). These artifacts are now ready to be deployed on the edge device.
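
Before copying the folder to a device, it can be useful to confirm that every artifact was written. The sketch below uses only the standard library. The file names in `EXPECTED_ARTIFACTS` are an assumption based on the default names `generate_artifacts` uses when no prefix is given, so adjust them if your version of ONNX Runtime names the files differently. `missing_artifacts` is a hypothetical helper, not part of ONNX Runtime.

```python
from pathlib import Path

# Assumed default file names written by generate_artifacts (no prefix).
EXPECTED_ARTIFACTS = (
    "training_model.onnx",
    "eval_model.onnx",
    "optimizer_model.onnx",
    "checkpoint",
)


def missing_artifacts(artifact_dir):
    """Return the expected artifact names that are absent from artifact_dir."""
    root = Path(artifact_dir)
    return [name for name in EXPECTED_ARTIFACTS if not (root / name).exists()]


# An empty list means all expected artifacts are present.
print(missing_artifacts("training_artifacts"))
```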