### Test PyTorch Convolutional Network using ONNX conda env

In this second example on using ONNX together with PyTorch and PyTorch Lightning, we see how **easy deployment really is**.

1. Once the model has been exported to ONNX format, we no longer need the code of the LightningModule.
2. We only need an environment with the ONNX Runtime.
3. torchvision is used here only for quick access to the MNIST images and to implement the required transforms.
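
Point 3 says torchvision is only a convenience. The same `ToTensor()` + `Normalize((0.1307,), (0.3081,))` preprocessing can be written with NumPy alone. This is a minimal sketch that assumes the input is a single 28x28 grayscale image already loaded as a `uint8` NumPy array. `preprocess` is a hypothetical helper, not part of any library.

```python
import numpy as np

# MNIST mean and std, the same values passed to transforms.Normalize in this notebook
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def preprocess(img_uint8: np.ndarray) -> np.ndarray:
    """Hypothetical helper: mimic ToTensor() + Normalize() for one 28x28
    grayscale uint8 image and return a float32 batch of shape (1, 1, 28, 28)."""
    if img_uint8.ndim != 2:
        raise ValueError("expected a 2-D (H, W) grayscale image")
    # ToTensor: scale [0, 255] to [0.0, 1.0] and add the channel axis -> (C, H, W)
    x = img_uint8.astype(np.float32) / 255.0
    x = x[np.newaxis, :, :]
    # Normalize: (x - mean) / std (MNIST has a single channel)
    x = (x - MNIST_MEAN) / MNIST_STD
    # add the batch axis, like unsqueeze(0) -> (N, C, H, W)
    return x[np.newaxis, ...].astype(np.float32)


batch = preprocess(np.zeros((28, 28), dtype=np.uint8))
print(batch.shape, batch.dtype)
```

The result is `float32` on purpose. ONNX Runtime checks input dtypes, and a `float32` model input would reject NumPy's default `float64`.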


This notebook can be run in the ONNX 1.10 (CPU) conda env, on Python 3.7.

In that env you only need to install torch and torchvision, for quick access to the MNIST images and transforms.

```python
import numpy as np

from torchvision import transforms
from torchvision.datasets import MNIST

import onnxruntime
```

```python
# where we're storing the downloaded dataset
PATH_DATASETS = "."

# the model previously exported in ONNX format
ONNX_FILE = "light_cnn.onnx"
```

```python
# we need an input sample to test the ONNX model

# the definition of transforms over images
img_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        # normalization with MNIST mean and std is explained here
        # https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# we take an input image from the test set;
# when we load the dataset we apply the transforms expected by the model
mnist_test = MNIST(PATH_DATASETS, train=False, download=True, transform=img_transforms)

# index of the image used for the test
INDEX = 10

# take a sample: a (1, 28, 28) image tensor and its integer label
img_tensor, label = mnist_test[INDEX]

# make it a batch of one image: shape (1, 1, 28, 28)
input_batch = img_tensor.unsqueeze(0)
```
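
The model takes a batch of images in (N, C, H, W) layout, and `img_tensor.unsqueeze(0)` adds the leading batch axis. The sketch below shows the NumPy equivalent, which is useful once torch is dropped from the deployment env. It also shows how several samples could be stacked into one batch. Stacking only works if `light_cnn.onnx` was exported with a dynamic batch dimension (for example via the `dynamic_axes` argument of `torch.onnx.export`), which is an assumption not confirmed here.

```python
import numpy as np

# a single preprocessed MNIST sample has shape (C, H, W) = (1, 28, 28)
sample = np.zeros((1, 28, 28), dtype=np.float32)

# NumPy equivalent of img_tensor.unsqueeze(0): add a leading batch axis
single_batch = np.expand_dims(sample, axis=0)

# stack several samples along a new batch axis
# (assumes the exported model accepts a variable batch size)
multi_batch = np.stack([sample, sample, sample], axis=0)

print(single_batch.shape, multi_batch.shape)
```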

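```python
# create the inference session only once, at setup
# (CPUExecutionProvider matches the CPU build of ONNX Runtime used here)
ort_session = onnxruntime.InferenceSession(ONNX_FILE, providers=["CPUExecutionProvider"])

# name of the (single) model input, needed to build the input feed
input_name = ort_session.get_inputs()[0].name
```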
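
Besides `.name`, each entry returned by `ort_session.get_inputs()` also exposes a `.shape`. Fixed dimensions appear there as ints, and symbolic or unknown ones appear as strings or `None`. A minimal sketch of checking a batch against that declared shape before calling `run` could look like the following. `shape_matches` is a hypothetical helper, and in the notebook it would be called as `shape_matches(ort_session.get_inputs()[0].shape, input_batch.shape)`.

```python
from typing import Optional, Sequence, Union

# a declared dimension: fixed (int), symbolic (str) or unknown (None)
Dim = Optional[Union[int, str]]


def shape_matches(declared: Sequence[Dim], actual: Sequence[int]) -> bool:
    """Hypothetical helper: True if `actual` fits the shape declared by the model.
    Only fixed (int) dimensions are compared; str/None dimensions accept any size."""
    if len(declared) != len(actual):
        return False
    for d, a in zip(declared, actual):
        if isinstance(d, int) and d != a:
            return False
    return True


print(shape_matches(["batch_size", 1, 28, 28], (4, 1, 28, 28)))
```

Failing early with a clear message is usually easier to debug than the error raised inside `run` on a shape mismatch.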

```python
# as input you need to pass a NumPy array, not a torch tensor;
# otherwise the deployment wouldn't be... independent of the framework!
ort_inputs = {input_name: input_batch.numpy()}

# run() returns a list with one array per model output;
# the first output holds the class scores for our single image
outputs = ort_session.run(None, ort_inputs)
out_class = int(np.argmax(outputs[0]))

print(f"Predicted label is: {out_class}, expected label is: {label}")
```

```
Predicted label is: 0, expected label is: 0
```
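
`InferenceSession.run` returns a list with one NumPy array per model output. As a minimal sketch, assuming the first output holds one row of 10 class scores per image (shape `(N, 10)`), the scores can be turned into predicted classes and probabilities for a whole batch. The source does not say whether `light_cnn.onnx` outputs raw logits or log-probabilities. A softmax recovers the probabilities in both cases, and the argmax is the same either way. `predict` and `fake_outputs` are hypothetical stand-ins for the real session output.

```python
import numpy as np


def predict(outputs):
    """Hypothetical helper: map the list returned by InferenceSession.run
    to (classes, probabilities), assuming outputs[0] has shape (N, 10)."""
    scores = np.asarray(outputs[0], dtype=np.float64)
    # numerically stable softmax over the class axis; applied to
    # log-probabilities (e.g. from log_softmax) it returns the probabilities
    shifted = scores - scores.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    classes = probs.argmax(axis=1)
    return classes, probs


# fake scores for a batch of one image, standing in for ort_session.run(None, ort_inputs)
fake_outputs = [np.array([[5.0, 0.1, 0.2, 0.0, 0.3, 0.1, 0.0, 0.4, 0.2, 0.1]], dtype=np.float32)]
classes, probs = predict(fake_outputs)
print(int(classes[0]), float(probs[0, classes[0]]))
```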
