# MobileViT Offline Processing

On-device training requires steps that happen off the device, referred to as "offline" steps. 

This notebook contains the offline processing steps for an example using MobileViT for facial expression recognition.

## Training artifact generation

Training on the device requires four files: a model checkpoint with the weights, the training model ONNX file, the optimizer ONNX file, and the evaluation model ONNX file.

The generate_artifacts function simplifies this process: you pass in an initial ONNX model (for example, one exported from a HuggingFace Transformers model), specify a loss type and an optimizer (both are needed for this workflow), and it generates the required training artifacts for you.

The parameters that require gradient and the frozen parameters are taken from the initial PyTorch model.
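The split between trainable and frozen parameters can be sketched as a small helper. This is a minimal illustration, not part of the ONNX Runtime API: the function name `split_parameter_names` and the parameter names in the example are hypothetical, and the `SimpleNamespace` objects merely stand in for PyTorch tensors so that the snippet runs without a model. With a real model, you would pass `model.named_parameters()` and `model.named_buffers()` instead.

```python
from types import SimpleNamespace


def split_parameter_names(named_parameters, named_buffers=()):
    """Return (requires_grad, frozen_params) name lists.

    Accepts any iterables of (name, tensor-like) pairs, such as the ones
    returned by a PyTorch module's named_parameters() and named_buffers().
    """
    requires_grad, frozen_params = [], []
    for name, param in named_parameters:
        # Trainable parameters go in one list, frozen ones in the other.
        if param.requires_grad:
            requires_grad.append(name)
        else:
            frozen_params.append(name)
    # Buffers are not updated by the optimizer, so they count as frozen.
    frozen_params.extend(name for name, _ in named_buffers)
    return requires_grad, frozen_params


# Hypothetical stand-ins that mimic the requires_grad attribute of tensors.
params = [
    ("backbone.weight", SimpleNamespace(requires_grad=False)),
    ("classifier.weight", SimpleNamespace(requires_grad=True)),
    ("classifier.bias", SimpleNamespace(requires_grad=True)),
]
buffers = [("backbone.running_mean", SimpleNamespace())]

trainable, frozen = split_parameter_names(params, buffers)
print("trainable:", trainable)
print("frozen:", frozen)
```

Keeping the split in one function makes it easy to inspect both lists before they are handed to generate_artifacts.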

Before it is passed to the generate_artifacts function, the model is configured to suit the dataset: for example, the random input used to export the model has the same image dimensions as the dataset images, and the number of labels is set to match the dataset.

Although this example uses PyTorch to export a HuggingFace Transformers model into an ONNX file to be passed to generate_artifacts, any method of creating or exporting an ONNX file can be used with generate_artifacts.

In [None]:
import torch
import transformers
import onnx
from onnxruntime.training import artifacts

In [None]:
# change the configuration to reflect the number of labels used in the dataset
config = transformers.MobileViTConfig.from_pretrained("apple/mobilevit-xx-small", num_labels=7)
model = transformers.MobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small", config=config, ignore_mismatched_sizes=True)

In [None]:
onnx_name = "mobilevit_init_test.onnx"

# Generate random pixel values and labels for a batch of 5 images.
# The images are 512 by 512 to match the dataset (originally they were 256 by 256).
random_input = {"pixel_values": torch.rand(5, 3, 512, 512),
                # torch.randint excludes the upper bound, so 7 covers labels 0 through 6
                "labels": torch.randint(0, 7, (5,))}

torch.onnx.export(model, random_input, onnx_name,
                    input_names=["pixel_values", "labels"], output_names=["output"],
                    export_params=True,
                    dynamic_axes={
                        "pixel_values": {0: "batch_size"},
                        "labels": {0: "batch_size"},
                        "output": {0: "batch_size"}
                    },
                    do_constant_folding=False,
                    training=torch.onnx.TrainingMode.TRAINING)

In [None]:
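# Split the parameter names into trainable and frozen lists
requires_grad = []
frozen_params = []
for name, param in model.named_parameters():
    if param.requires_grad:
        requires_grad.append(name)
    else:
        frozen_params.append(name)

# Buffers are not updated by the optimizer, so treat them as frozen
for name, param in model.named_buffers():
    frozen_params.append(name)

# Load the exported ONNX model from disk
onnx_model = onnx.load(onnx_name)

# Generate the training, evaluation, and optimizer models plus the checkpoint
artifacts.generate_artifacts(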
    onnx_model,
    optimizer=artifacts.OptimType.AdamW,
    loss=artifacts.LossType.CrossEntropyLoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params
)
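After generation, it can help to confirm that all four files exist before copying them to the device. The sketch below assumes the default behavior of generate_artifacts, which to my understanding writes `training_model.onnx`, `eval_model.onnx`, `optimizer_model.onnx`, and `checkpoint` to the current working directory when no output directory or prefix is given. The helper name `missing_artifacts` is hypothetical, and the file names should be adjusted if you passed a custom directory or prefix.

```python
from pathlib import Path

# Assumed default file names written by generate_artifacts (no prefix).
EXPECTED_ARTIFACTS = (
    "training_model.onnx",
    "eval_model.onnx",
    "optimizer_model.onnx",
    "checkpoint",
)


def missing_artifacts(directory="."):
    """Return the expected artifact names that are not present in directory."""
    base = Path(directory)
    return [name for name in EXPECTED_ARTIFACTS if not (base / name).exists()]


missing = missing_artifacts(".")
if missing:
    print("Missing artifacts:", ", ".join(missing))
else:
    print("All training artifacts are present.")
```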