### Save the Lightning MNIST model to ONNX format

In [10]:
import numpy as np

from torchvision import transforms

# from here we get MNIST dataset
from torchvision.datasets import MNIST

from light_mnist_cnn import LitMNISTCNN

In [11]:
# Root directory where the downloaded dataset is stored
# ("." is the current working directory).
PATH_DATASETS = "."

#### Reload the model from a checkpoint and save in ONNX format

In [12]:
# Checkpoint to restore the model from.
# NOTE: this is an absolute path tied to one specific environment;
# adjust it when running the notebook elsewhere.
CKPT_PATH = "/home/datascience/pytorch-on-oci/ch-04/checkpoint_mnist/best.ckpt"

# Restore the LightningModule from the checkpoint file.
model = LitMNISTCNN.load_from_checkpoint(CKPT_PATH)

# Switch to inference mode so the Dropout layers are disabled while the model
# is inspected and exported. eval() returns the module itself, so re-binding
# keeps the same object and avoids echoing its repr as cell output.
model = model.eval()

In [13]:
# Display the summary of the CNN architecture.

# The LightningModule defining the CNN has been moved to an external .py file
# (see the light_mnist_cnn import).

model

LitMNISTCNN(
  (model): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): ReLU()
    (5): Dropout(p=0.5, inplace=False)
    (6): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): ReLU()
    (9): Dropout(p=0.5, inplace=False)
    (10): Flatten(start_dim=1, end_dim=-1)
    (11): Linear(in_features=576, out_features=256, bias=True)
    (12): ReLU()
    (13): Dropout(p=0.1, inplace=False)
    (14): Linear(in_features=256, out_features=10, bias=True)
  )
  (val_accuracy): Accuracy()
  (test_accuracy): Accuracy()
)

In [19]:
# An input sample is needed to save the model in ONNX format.

# Transforms applied to every image: conversion to tensor, then normalization.
img_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        # the normalization constants are explained here:
        # https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Load the MNIST test split from the configured dataset directory
# (PATH_DATASETS, set in the configuration cell) instead of repeating the
# literal path, so the location is controlled from a single place.
# The transforms are applied at load time, as the model expects.
mnist_test = MNIST(
    PATH_DATASETS, train=False, download=True, transform=img_transforms
)

# Index of the test image used as the sample.
INDEX = 5

# Take one (image, label) pair from the dataset.
img_tensor, label = mnist_test[INDEX]

# Add a leading dimension so the single image becomes a batch of one.
input_batch = img_tensor.unsqueeze(0)

#### Save in ONNX format

In [20]:
# File the exported ONNX model is written to.
ONNX_FILE = "light_cnn.onnx"

# Export via the LightningModule helper, using the prepared batch as the
# input sample and storing the parameters in the file.
model.to_onnx(
    file_path=ONNX_FILE,
    input_sample=input_batch,
    export_params=True,
)

#### Test using onnxruntime

In [24]:
import onnx
import onnxruntime

In [29]:
# Load the exported file and validate it (this takes a few seconds).
onnx_model = onnx.load(ONNX_FILE)

# Run the full checker and report the outcome with a single print.
try:
    onnx.checker.check_model(onnx_model, full_check=True)
    check_report = "The model is valid!"
except onnx.checker.ValidationError as err:
    check_report = "The model is invalid: %s" % err

print(check_report)

The model is valid!


In [22]:
# Done once, at setup: create the inference session for the exported model.
# The execution providers are passed explicitly: ONNX Runtime builds that ship
# more than one provider (e.g. CUDA builds) reject a session created without
# the `providers` argument. get_available_providers() lists them in priority
# order for the installed build.
ort_session = onnxruntime.InferenceSession(
    ONNX_FILE, providers=onnxruntime.get_available_providers()
)

# Name of the (first) graph input, needed to build the feed dictionary.
input_name = ort_session.get_inputs()[0].name

In [23]:
# The input must be passed as a NumPy array, not as a torch tensor:
# the ONNX runtime is independent from the training framework.
ort_inputs = {input_name: input_batch.numpy()}

# run() returns a list with one array per model output; take the first output
# explicitly instead of applying argmax to the whole list.
logits = ort_session.run(None, ort_inputs)[0]

# Highest score along the class axis, for the only sample in the batch.
out_class = np.argmax(logits, axis=1)[0]

print()
print(f"Predicted label is: {out_class}, expected label is: {label}")


Predicted label is: 1, expected label is: 1
