## Offline Step - Generate the Training Artifacts

We start with a pre-trained PyTorch model and export it to ONNX. For this demo, we use the `MobileNetV2` image classification model, which has been pretrained on the ImageNet dataset with 1000 categories.

Our task only needs to classify images into 4 classes, so we replace the last layer of the model to output 4 logits instead of 1000.

In [None]:
import torch
import torchvision

model = torchvision.models.mobilenet_v2(
    weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2)

# The original model is trained on ImageNet, which has 1000 classes.
# For our image classification scenario, we only need to classify among 4 categories,
# so we replace the last layer of the model with one that has 4 outputs.
model.classifier[1] = torch.nn.Linear(1280, 4)

# Export the model to ONNX.
# Create the output folder first in case it does not exist yet.
import os
os.makedirs("training_artifacts", exist_ok=True)

model_name = "mobilenetv2"
torch.onnx.export(model, torch.randn(1, 3, 224, 224),
                  f"training_artifacts/{model_name}.onnx", input_names=["input"],
                  output_names=["output"], dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
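To get a feel for how small the replaced layer is, the parameter counts of the two classifier heads can be worked out by hand. The sketch below is only arithmetic; the helper name `linear_param_count` is made up for this illustration, and it assumes the original head is a fully connected layer with the same 1280 input features and 1000 outputs.

```python
# Parameter count of a fully connected layer: weights (in * out) plus biases (out).
def linear_param_count(in_features: int, out_features: int) -> int:
    return in_features * out_features + out_features

original_head = linear_param_count(1280, 1000)  # head with 1000 ImageNet outputs
new_head = linear_param_count(1280, 4)          # replaced head with 4 outputs

print(original_head)  # 1281000
print(new_head)       # 5124
```

These 5124 values are the only parameters that will be trained later on, which keeps the training workload on the device small.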

Now that the MobileNetV2 model has been exported to ONNX, we need to generate the training artifacts:
  - The training ONNX model = _gradient_(_optimize_(_stack_(inference ONNX model, loss node)))
  - The eval ONNX model = _optimize_(_stack_(inference ONNX model, loss node))
  - The optimizer ONNX model: a new ONNX model that takes the model parameters as input and updates them based on their gradients.
  - The model parameter checkpoint file: the extracted and serialized model parameters.
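The roles of these four artifacts can be illustrated with a toy example in plain Python. This is a conceptual sketch only, not how ONNX Runtime builds or runs the graphs: the functions `eval_step`, `train_step`, and `optimizer_step` are hypothetical names, the "model" is a single linear layer with 3 inputs and 4 classes, and plain gradient descent stands in for the optimizer to keep the code short.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def eval_step(weights, bias, x, label):
    """Role of the eval model: the inference model stacked with the loss node."""
    logits = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]
    probs = softmax(logits)
    return -math.log(probs[label]), probs

def train_step(weights, bias, x, label):
    """Role of the training model: the eval graph plus gradient outputs."""
    loss, probs = eval_step(weights, bias, x, label)
    # For softmax cross-entropy, d(loss)/d(logit_k) = p_k - 1[k == label].
    dlogits = [p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)]
    grad_w = [[d * xi for xi in x] for d in dlogits]
    grad_b = dlogits
    return loss, grad_w, grad_b

def optimizer_step(weights, bias, grad_w, grad_b, lr=0.1):
    """Role of the optimizer model: parameters and gradients in, updated parameters out."""
    new_w = [[w - lr * g for w, g in zip(w_row, g_row)] for w_row, g_row in zip(weights, grad_w)]
    new_b = [b - lr * g for b, g in zip(bias, grad_b)]
    return new_w, new_b

# Role of the checkpoint: the parameters, stored separately from the graphs.
weights = [[0.0, 0.0, 0.0] for _ in range(4)]
bias = [0.0] * 4
x, label = [1.0, 2.0, -1.0], 2

initial_loss, _ = eval_step(weights, bias, x, label)
for _ in range(20):
    _, grad_w, grad_b = train_step(weights, bias, x, label)
    weights, bias = optimizer_step(weights, bias, grad_w, grad_b)
final_loss, _ = eval_step(weights, bias, x, label)

print(initial_loss, final_loss)
```

Keeping the four pieces separate means a device can run a training step, an optimizer step, or an evaluation pass independently, and the parameters can be saved or restored without touching the graphs.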

To generate these artifacts, we will use the `onnxblock` utility from the ONNX Runtime Python training package (`onnxruntime.training.onnxblock`).

In [None]:
import onnx
import onnxruntime.training.onnxblock as onnxblock

# Define what the training model should look like.
# In this case, we stack the loss function on top of the original model.
class MobileNetV2BlockWithLoss(onnxblock.TrainingModel):
    def __init__(self):
        super().__init__()
        self.loss = onnxblock.loss.CrossEntropyLoss()

    def build(self, loss_node_input_name):
        return self.loss(loss_node_input_name)

training_block = MobileNetV2BlockWithLoss()

# This is a transfer learning task: we only want to train the last layer of the model,
# so we mark the parameters of all other layers as non-trainable.
for name, param in model.named_parameters():
    if not (name == "classifier.1.weight" or name == "classifier.1.bias"):
        training_block.requires_grad(name, False)

# Load the model from the exported inference ONNX file.
onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx")
eval_model = None
optimizer_model = None

inference_model_output_name = "output"
with onnxblock.onnx_model(onnx_model) as model_accessor:
    loss_output_name = training_block(inference_model_output_name)
    eval_model = model_accessor.eval_model

optimizer_block = onnxblock.optim.AdamW()
with onnxblock.onnx_model() as model_accessor:
    optimizer_outputs = optimizer_block(training_block.parameters())
    optimizer_model = model_accessor.model
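The freezing loop in the cell above boils down to splitting the parameter names into a trainable set and a frozen set. A minimal sketch of that selection logic, using set membership instead of chained comparisons, might look as follows. The names in `parameter_names` are a small assumed sample for illustration, not the full list of the model's parameters.

```python
# Assumed sample of parameter names, for illustration only.
parameter_names = [
    "features.0.0.weight",
    "features.18.0.weight",
    "classifier.1.weight",
    "classifier.1.bias",
]

# Only the replaced classification layer should be trained.
trainable_names = {"classifier.1.weight", "classifier.1.bias"}

frozen = [name for name in parameter_names if name not in trainable_names]
trainable = [name for name in parameter_names if name in trainable_names]

print(frozen)     # ['features.0.0.weight', 'features.18.0.weight']
print(trainable)  # ['classifier.1.weight', 'classifier.1.bias']
```

Using a set makes it easy to extend the list of trainable parameters later, for example if more than the last layer should be fine-tuned.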

All training artifacts have now been generated, so we can save them to disk.

In [None]:
onnxblock.save_checkpoint(training_block.parameters(), f"training_artifacts/{model_name}.ckpt")
onnx.save(onnx_model, f"training_artifacts/{model_name}_training.onnx")
onnx.save(eval_model, f"training_artifacts/{model_name}_eval.onnx")
onnx.save(optimizer_model, f"training_artifacts/{model_name}_optimizer.onnx")

All training artifacts have been saved to the folder [training_artifacts](training_artifacts). These artifacts are now ready to be deployed on the edge device.
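Before copying the folder to a device, it can be useful to confirm that all four files are present. The helper below is a small sketch using only the standard library; `missing_artifacts` is a hypothetical name, and the expected file names are taken from the save calls above.

```python
import os

def missing_artifacts(directory: str, model_name: str) -> list:
    """Return the expected artifact file names that are not present in `directory`."""
    expected = [
        f"{model_name}.ckpt",
        f"{model_name}_training.onnx",
        f"{model_name}_eval.onnx",
        f"{model_name}_optimizer.onnx",
    ]
    return [name for name in expected
            if not os.path.isfile(os.path.join(directory, name))]

# An empty list means every artifact was found.
print(missing_artifacts("training_artifacts", "mobilenetv2"))
```

This only checks that the files exist; it does not validate their contents.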