# ONNX Format
In this demo, we export our best model to the ONNX format and then run inference on it with ONNX Runtime.

Why do this?
ONNX is an open standard format that enables model interoperability across different frameworks and platforms, making it easier to deploy models in diverse environments such as cloud, edge, or mobile devices. 

ONNX Runtime is built for fast inference: it applies graph optimizations such as constant folding and operator fusion, and it can run models on hardware-specific execution providers for CPUs, GPUs, and specialized inference chips. Together, the format and the runtime make models portable and scalable, can improve inference performance compared with running them directly in PyTorch, and simplify integration into non-PyTorch ecosystems.

In [None]:
# Install the required modules
!pip install onnx onnxruntime

In [None]:
# RESTART YOUR NOTEBOOK KERNEL FOR THE CHANGES TO TAKE EFFECT

## Load our best model
Before we begin, we must load our best model from its training checkpoint.


In [None]:
# Import modules
import torch
import torch.nn as nn
from torchvision import models

In [None]:

# Load the mobilenet_v3_large model with default weights
model = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.DEFAULT)

In [None]:
# Modify last layer of the model for 2 classes as output
model.classifier[-1] = nn.Linear(1280, 2)

In [None]:
# Load the checkpoint; map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine
checkpoint = torch.load('mobilenet_checkpoint.tar', map_location='cpu', weights_only=True)

In [None]:
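# Load the parameters from the checkpoint
model.load_state_dict(checkpoint['model_state_dict'])
# Switch to evaluation mode (dropout off, batch norm uses running statistics) before exporting
model.eval()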
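The cells above assume `mobilenet_checkpoint.tar` is a dictionary with a `model_state_dict` key. The training code is not shown here, so the following is only a sketch of how such a checkpoint is commonly written and read back. The extra `epoch` key, the tiny stand-in model, and the file name `toy_checkpoint.tar` are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the fine-tuned MobileNet; any nn.Module is saved the same way
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(12, 2))

path = os.path.join(tempfile.mkdtemp(), "toy_checkpoint.tar")

# Save a plain dict of tensors and Python primitives, which weights_only=True can load
torch.save({"epoch": 5, "model_state_dict": toy_model.state_dict()}, path)

# Load on CPU so a checkpoint written on a GPU machine still opens without CUDA
checkpoint = torch.load(path, map_location="cpu", weights_only=True)

# Rebuild the same architecture, then copy the saved parameters into it
restored = nn.Sequential(nn.Flatten(), nn.Linear(12, 2))
restored.load_state_dict(checkpoint["model_state_dict"])
restored.eval()  # inference mode before export or prediction

print(checkpoint["epoch"])
```

The architecture has to be rebuilt in code before `load_state_dict`, because a state dict stores only parameter tensors, not the model structure.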

## Export our Model to ONNX format

In [None]:
# Import the exporter module: NOTE that the ONNX exporter is built into PyTorch!
import torch.onnx

In [None]:
# Read the documentation for the export function
help(torch.onnx.export)

In [None]:
# Create an example input (a batch of one 3-channel 224x224 image) used to trace the model
example_input = torch.randn(1, 3, 224, 224)

In [None]:
# Invoke export
torch.onnx.export(model, example_input, "image_classifier.onnx")

In [None]:
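# Check the model's consistency
import onnx

# Load the exported model with ONNX
onnx_model = onnx.load("image_classifier.onnx")
# check_model raises an exception if the model is invalid and returns None otherwise
onnx.checker.check_model(onnx_model)
print("The ONNX model is valid")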


## Load an example image for inference

In [None]:
# The same preprocessing is still required: the ONNX graph contains only the model, not these transforms
from PIL import Image
from torchvision.transforms import v2

transform = v2.Compose([
    v2.Resize((224, 224)),
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], 
                 std=[0.229, 0.224, 0.225])
])
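ONNX Runtime does not run torchvision transforms, so a deployment that drops PyTorch has to reproduce them. For each channel c, `ToDtype(torch.float32, scale=True)` maps a uint8 pixel p to p / 255, and `Normalize` then computes (p / 255 - mean_c) / std_c. A minimal NumPy sketch of those two steps plus the HWC to CHW transpose that `ToImage` performs on a PIL image, assuming the image is already an RGB array of the right size; `Resize` is not reproduced here, and `preprocess_np` is a hypothetical helper.

```python
import numpy as np

# Per-channel statistics copied from the Normalize transform above
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess_np(rgb_uint8: np.ndarray) -> np.ndarray:
    """Map an HxWx3 uint8 RGB array to a normalized 3xHxW float32 array."""
    x = rgb_uint8.astype(np.float32) / 255.0  # ToDtype(float32, scale=True)
    x = (x - MEAN) / STD                      # Normalize, broadcast over the channel axis
    return x.transpose(2, 0, 1)               # HWC -> CHW, the layout ToImage produces


# Tiny 1x2 "image": one white pixel and one black pixel
pixels = np.array([[[255, 255, 255], [0, 0, 0]]], dtype=np.uint8)
out = preprocess_np(pixels)
print(out.shape)  # (3, 1, 2)
print(out[0, 0])  # channel 0, approx [2.2489, -2.1179]
```

For a white pixel in channel 0 this gives (1 - 0.485) / 0.229 ≈ 2.2489, and for a black pixel (0 - 0.485) / 0.229 ≈ -2.1179.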

In [None]:
# Open an image and convert it to 3-channel RGB (handles grayscale or RGBA files)
image_path = 'sample-input.jpg'
image = Image.open(image_path).convert("RGB")

In [None]:
# Apply the transformation
transformed_image = transform(image)
transformed_image.shape

In [None]:
# Add a batch dimension, since the model expects [batch_size, channels, height, width]
transformed_image = transformed_image.unsqueeze(0)
transformed_image.shape
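`unsqueeze(0)` handles a single image. Several preprocessed images can instead be stacked along a new leading axis, although the model exported above from a single `example_input` records a batch size of 1, so a larger stack would only be accepted by a model exported with a dynamic batch dimension. A NumPy sketch, with random arrays standing in for real preprocessed images:

```python
import numpy as np

# Stand-ins for three preprocessed images, each shaped [channels, height, width]
images = [np.random.rand(3, 224, 224).astype(np.float32) for _ in range(3)]

single = np.expand_dims(images[0], axis=0)  # same effect as tensor.unsqueeze(0)
batch = np.stack(images, axis=0)            # new leading axis: [batch_size, C, H, W]

print(single.shape)  # (1, 3, 224, 224)
print(batch.shape)   # (3, 3, 224, 224)
```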

In [None]:
# Convert the transformed image to a NumPy array, the input type ONNX Runtime expects
import numpy as np

image_np = np.array(transformed_image, dtype=np.float32)

## Run inference using ONNX Runtime
The ONNX Runtime is a high-performance inference engine for models in the ONNX format. An inference session loads an `.onnx` file, applies graph-level optimizations to it, and executes it on one or more execution providers (for example, `CPUExecutionProvider` or `CUDAExecutionProvider`), so the same model file can be deployed on different hardware without changes.

In [None]:
# Import the runtime
import onnxruntime as ort

In [None]:
# Start an inference session on the runtime, using the CPU execution provider
session = ort.InferenceSession("image_classifier.onnx", providers=["CPUExecutionProvider"])

In [None]:
# Run inference

# Create input to be passed to the model
inputs = {session.get_inputs()[0].name: image_np}
# Run the inference
outputs = session.run(None, inputs)
print(outputs) # raw outputs (logits) from final layer
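The printed values are logits, not probabilities. If you want a confidence score alongside the predicted class, a softmax over the class axis converts them. A minimal NumPy sketch with a hypothetical `softmax` helper and made-up logits; subtracting the row maximum before exponentiating leaves the result unchanged but avoids overflow for large logits.

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


example_logits = np.array([[2.0, 0.0]], dtype=np.float32)  # made-up [batch=1, classes=2]
probs = softmax(example_logits)
print(probs)  # approx [[0.8808 0.1192]]
```

With the session output above, `softmax(outputs[0])` would give one row of class probabilities per image.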

In [None]:
# Get the predicted class index: outputs[0] has shape [1, 2], so take the first image's logits and their argmax
predicted = outputs[0][0].argmax(0)
print(predicted)

In [None]:
# Label encoding used during training: class name -> index
label_encoding = {"malignant": 0, "benign": 1}

In [None]:
# Invert the label_encoding dictionary to map index -> class name
index_to_class_map = {v: k for k, v in label_encoding.items()}
print(f"Predicted Class: {index_to_class_map[predicted.item()]}")
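The same lookup extends to a whole batch of outputs. A sketch with a hypothetical helper, `decode_predictions`, that takes a `[batch_size, num_classes]` logits array such as the first element returned by `session.run`; the logits below are made up. Taking the argmax of the raw logits picks the same class as taking it over softmax probabilities, because softmax preserves the order of the values in each row.

```python
import numpy as np


def decode_predictions(logits: np.ndarray, index_to_class: dict) -> list:
    """Return the predicted class name for each row of a [batch, classes] logits array."""
    return [index_to_class[int(i)] for i in logits.argmax(axis=1)]


# Mirrors the label encoding above: 0 -> malignant, 1 -> benign
index_to_class = {0: "malignant", 1: "benign"}
fake_logits = np.array([[3.0, 1.0], [-0.5, 0.5]], dtype=np.float32)  # made-up values

print(decode_predictions(fake_logits, index_to_class))  # ['malignant', 'benign']
```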