## Learning on the edge: Offline Step - Training validation

In this step, we validate the concept of transfer learning on an example dataset and evaluate the ORT training API.

We use an animal dataset and train the model to classify images into four categories: `Dog`, `Cat`, `Elephant`, and `Cow`.

First, we load the dataset.

In [None]:
import glob

# load the dataset files into a dictionary
def load_dataset_files():
    animals = {
        "dog": [],
        "cat": [],
        "elephant": [],
        "cow": []
    }

    for animal in animals:
        # Sort the paths because glob.glob returns files in arbitrary order
        animals[animal] = sorted(glob.glob(
            f"data/images/{animal}/*"))

    return animals

animals = load_dataset_files()
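`load_dataset_files` assumes the images are stored as `data/images/<class>/<file>`, and the training loop later reads the first `num_samples_per_class` files of every class. The sketch below is a hypothetical helper, `check_dataset`, that is not part of the original notebook: it collects each class's files in sorted order (since `glob.glob` makes no ordering guarantee) and fails fast if a class folder holds too few files.

```python
import glob
import os


def check_dataset(root, classes, min_files):
    """Hypothetical helper: return {class: sorted file paths}, raising
    ValueError if any class folder under root has fewer than min_files files."""
    files = {}
    for name in classes:
        # Sort for a reproducible order; glob.glob order is filesystem-dependent
        found = sorted(glob.glob(os.path.join(root, name, "*")))
        if len(found) < min_files:
            raise ValueError(
                f"class '{name}' has {len(found)} files, expected at least {min_files}")
        files[name] = found
    return files
```

For example, once the training metadata is defined, `check_dataset("data/images", list(animals), num_samples_per_class)` would raise before training if any class has fewer than 20 images.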

Next, we define a function that preprocesses an input image into the tensor format expected by the model used with the ONNX Runtime training API.

This is done by:
  - Loading the image
  - Center-cropping the longer dimension to obtain a square image
  - Resizing the square image to 224 x 224 pixels
  - Scaling the pixel values to [0, 1] and reordering them into a channels-first tensor of shape [3 x 224 x 224]
  - Normalizing each channel by subtracting the mean (0.485, 0.456, 0.406) and dividing by the standard deviation (0.229, 0.224, 0.225)

In [None]:
import numpy as np

# Preprocess an image file into the tensor expected by the MobileNetV2 model:
# center-crop to a square, resize to 224x224, scale to [0, 1], reorder to
# channels-first (3x224x224), and normalize each channel with the
# mean (0.485, 0.456, 0.406) and standard deviation (0.229, 0.224, 0.225)
def image_file_to_tensor(file):
    # Imported locally so it does not clash with IPython.display.Image, imported later
    from PIL import Image

    # Convert to RGB so grayscale or RGBA files also yield 3 channels
    image = Image.open(file).convert("RGB")
    width, height = image.size
    # Center-crop the longer dimension to get a square image
    if width > height:
        left = (width - height) // 2
        right = (width + height) // 2
        top = 0
        bottom = height
    else:
        left = 0
        right = width
        top = (height - width) // 2
        bottom = (height + width) // 2
    image = image.crop((left, top, right, bottom)).resize((224, 224))

    # HxWx3 -> 3xHxW, then scale pixel values to [0, 1]
    pix = np.transpose(np.array(image, dtype=np.float32), (2, 0, 1))
    pix = pix / 255.0
    # Normalize each channel with its mean and standard deviation
    pix[0] = (pix[0] - 0.485) / 0.229
    pix[1] = (pix[1] - 0.456) / 0.224
    pix[2] = (pix[2] - 0.406) / 0.225
    return pix
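The crop box and the per-channel normalization can be checked in isolation, without image files. The sketch below uses only NumPy; `center_crop_box` and `normalize_hwc` are hypothetical helpers, not part of the notebook. `center_crop_box` returns the same (left, top, right, bottom) box as the branches above, and `normalize_hwc` applies the normalization with broadcasting instead of one line per channel.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def center_crop_box(width, height):
    """Hypothetical helper: (left, top, right, bottom) of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side


def normalize_hwc(image_uint8):
    """Scale an HxWx3 uint8 image to [0, 1], move channels first (3xHxW),
    and normalize each channel with the mean and standard deviation above."""
    pix = np.transpose(image_uint8.astype(np.float32) / 255.0, (2, 0, 1))
    # Reshape mean/std to (3, 1, 1) so they broadcast over height and width
    return (pix - MEAN[:, None, None]) / STD[:, None, None]


# An all-white image maps to (1 - mean) / std in every channel
white = np.full((224, 224, 3), 255, dtype=np.uint8)
out = normalize_hwc(white)
print(center_crop_box(300, 200))  # (50, 0, 250, 200)
print(out.shape)                  # (3, 224, 224)
```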

The next cell defines the training metadata: the class labels, the mapping from label to class index, the number of training samples, and the number of epochs.

For this demo, we use 20 training samples per class (20 dogs, 20 cats, 20 elephants, and 20 cows) and train for 5 epochs.

In [None]:
# Training metadata
dog, cat, elephant, cow = "dog", "cat", "elephant", "cow" # labels
label_to_id_map = {
    "dog": 0,
    "cat": 1,
    "elephant": 2,
    "cow": 3
} # label to index mapping

num_samples_per_class = 20
num_epochs = 5
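Because the training loop builds each batch from one image per class, these settings fix the batch size and the number of optimizer steps. The short sketch below makes that arithmetic explicit. It also derives the label map from the class order instead of writing it by hand; that is a suggested alternative, not what the notebook does.

```python
classes = ["dog", "cat", "elephant", "cow"]
num_samples_per_class = 20
num_epochs = 5

# Derive label ids from the class order instead of hard-coding them
label_to_id_map = {name: idx for idx, name in enumerate(classes)}

batch_size = len(classes)                 # one image per class in every batch
steps_per_epoch = num_samples_per_class   # one optimizer step per sample index
total_steps = steps_per_epoch * num_epochs
images_per_epoch = batch_size * steps_per_epoch

print(label_to_id_map)  # {'dog': 0, 'cat': 1, 'elephant': 2, 'cow': 3}
print(batch_size, steps_per_epoch, total_steps, images_per_epoch)  # 4 20 100 80
```

So 5 epochs amount to 100 optimizer steps, each over a batch of 4 images, revisiting the same 80 images every epoch.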

Now we can define the training loop. This is where we interact with the ORT training API.

In particular, we instantiate three tightly coupled objects:
1. The checkpoint state: holds the values of the model parameters at any given time.
2. The training module (named `model` in the code): executes the training and eval graphs.
   - Executing the training graph computes the training loss and the gradients of the model parameters.
   - Executing the eval graph computes the eval loss.
   - Switching between train and eval mode is done by calling `module.train()` or `module.eval()`.
3. The optimizer: updates the model parameters using the computed gradients, stepping against the gradient direction to reduce the loss.

Each training step builds a batch containing one image of each class, runs the training graph, updates the parameters, and resets the gradients.
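The division of labour between these objects can be illustrated without ONNX Runtime. The NumPy sketch below is a stand-in, not the ORT implementation: a hypothetical linear classifier `W` plays the role of the checkpointed parameters, `train_step` plays the training graph (it returns the cross-entropy loss and accumulates gradients), `optimizer_step` applies plain SGD, and `reset_grad` mirrors `lazy_reset_grad()`. The learning rate and the choice of plain SGD are assumptions; the optimizer graph in `mobilenetv2_optimizer.onnx` may use a different algorithm. Also, `reset_grad` zeroes the buffer immediately, whereas `lazy_reset_grad()`, as its name suggests, may defer the reset until the next training step.

```python
import numpy as np

rng = np.random.default_rng(0)
num_features, num_classes = 8, 4

# Stand-in for the checkpoint state: the trainable parameters
W = np.zeros((num_features, num_classes))
# Gradient buffer, accumulated by train_step and cleared by reset_grad
grad = np.zeros_like(W)


def train_step(x, labels):
    """Stand-in for the training graph: softmax cross-entropy loss plus gradients."""
    logits = x @ W
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    n = x.shape[0]
    loss = -np.log(probs[np.arange(n), labels]).mean()
    # d(loss)/d(logits) = probs - one_hot(labels), averaged over the batch
    dlogits = probs.copy()
    dlogits[np.arange(n), labels] -= 1.0
    grad[...] += x.T @ dlogits / n  # accumulate in place
    return loss


def optimizer_step(lr=0.1):
    """Stand-in for the optimizer: plain SGD moves W against the gradient."""
    W[...] -= lr * grad


def reset_grad():
    """Stand-in for lazy_reset_grad(): zero the gradient buffer."""
    grad[...] = 0.0


x = rng.normal(size=(4, num_features))  # a batch of 4 random feature vectors
labels = np.array([0, 1, 2, 3])         # one sample per class, like the notebook's batches
losses = []
for _ in range(20):
    losses.append(train_step(x, labels))
    optimizer_step()
    reset_grad()
print(losses[0], losses[-1])  # the first loss is log(4) because W starts at zero
```

Skipping `reset_grad()` in this sketch would make each step apply the sum of all previous gradients, which is why the notebook resets the gradients after every `optimizer.step()`.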

   

In [None]:
import onnxruntime.training.api as orttraining

# Instantiate the training session by defining the checkpoint state, module, and optimizer
# The checkpoint state contains the state of the model parameters at any given time.
checkpoint_state = orttraining.CheckpointState(
    "training_artifacts/mobilenetv2.ckpt")

# The training module executes the training graph (and, after model.eval(), the eval graph)
model = orttraining.Module(
    "training_artifacts/mobilenetv2_training.onnx",
    checkpoint_state,
    "training_artifacts/mobilenetv2_eval.onnx",
)

# The optimizer updates the parameters held in the checkpoint state using the computed gradients
optimizer = orttraining.Optimizer(
    "training_artifacts/mobilenetv2_optimizer.onnx", model
)

# Training loop
for epoch in range(num_epochs):
    model.train()
    loss = 0
    for index in range(num_samples_per_class):
        batch = []
        labels = []
        for animal in animals:
            batch.append(image_file_to_tensor(animals[animal][index]))
            labels.append(label_to_id_map[animal])
        batch = np.stack(batch)
        labels = np.array(labels, dtype=np.int32)

        # ort training api - run the training graph; it outputs the training loss and computes the parameter gradients
        loss += model([batch, labels])[0]
        # ort training api - update the model parameters by stepping against the gradients
        optimizer.step()
        # ort training api - reset the gradients to zero so that new gradients can be computed in the next step
        model.lazy_reset_grad()

    # Report the average training loss over the steps in this epoch
    print(f"Epoch {epoch+1} Loss {loss/num_samples_per_class}")

Run inference on unseen images to check the result of training.

In [None]:
import onnxruntime
from onnxruntime import InferenceSession

# IPython's Image is used only for display; image_file_to_tensor imports PIL's Image locally
from IPython.display import Image, display

# ort training api - export the trained model so that it can be used for inference
model.export_model_for_inferencing("inference_artifacts/inference.onnx", ["output"])

# Run inference on the exported model with every execution provider available in this build
session = InferenceSession("inference_artifacts/inference.onnx", providers=onnxruntime.get_available_providers())

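def softmax(logits):
    # Subtract the maximum logit for numerical stability; the probabilities are unchanged
    logits = np.asarray(logits)
    exp = np.exp(logits - np.max(logits))
    return exp / exp.sum()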

def predict(test_file, test_name):
    logits = session.run(["output"], {"input": np.stack([image_file_to_tensor(test_file)])})
    probabilities = softmax(logits) * 100
    display(Image(filename=test_file))
    print_prediction(probabilities, test_name)

def print_prediction(prediction, test_name):
    print(f"test\t{dog}\t{cat}\t{elephant}\t{cow}")
    print("-------------------------------------------------")
    print(f"{test_name}\t{prediction[0][0][0]:.2f}\t{prediction[0][0][1]:.2f}\t{prediction[0][0][2]:.2f}\t\t{prediction[0][0][3]:.2f}")
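Softmax is unchanged when the same constant is subtracted from every logit, so subtracting the row maximum before exponentiating avoids overflow without changing the probabilities. The sketch below applies that idea per row, so it also works for a batch of images, and adds a hypothetical `top_prediction` helper (not part of the notebook) that maps the most likely class index back to its name. It assumes the class order dog, cat, elephant, cow used in `label_to_id_map`.

```python
import numpy as np

LABELS = ["dog", "cat", "elephant", "cow"]  # assumed to match label_to_id_map order


def stable_softmax(logits):
    """Softmax over the last axis; shifting by the row max avoids overflow in exp."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def top_prediction(logits):
    """Hypothetical helper: (label, probability in %) for a single image's logits."""
    probs = stable_softmax(logits).reshape(-1)
    idx = int(np.argmax(probs))
    return LABELS[idx], float(probs[idx] * 100)


print(top_prediction([[1.0, 3.0, 2.0, 0.0]])[0])  # cat
print(stable_softmax([1000.0, 1000.0]))           # [0.5 0.5], no overflow
```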


In [None]:
# Test on sample image (test1.jpg)
predict("inference_artifacts/test1.jpg", "test1")

In [None]:
# Test on another sample image (test2.jpg)
predict("inference_artifacts/test2.jpg", "test2")