In [None]:
import torch
from torchvision import models
import onnxruntime as ort
import numpy as np

In [None]:
# Load ResNet-18 with ImageNet-pretrained weights and switch it to inference
# mode. eval() makes BatchNorm layers use their stored running statistics
# instead of per-batch statistics; that is the behavior the exported ONNX graph
# should capture.
# eval() returns the module itself, so chaining it into the assignment keeps
# this cell from echoing the full (very long) module repr as its output.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval()

## ONNX

ONNX (Open Neural Network Exchange) is an open, framework-agnostic format that lets you export trained models so they can be run and deployed across different frameworks, runtimes, and hardware platforms.

<p align="center">
  <img src="../../assets/img/deployment/onnx_summary.png" width="400">
</p>

In [None]:
# Example tensor used only to trace the graph during export: one 3x224x224
# image. Random values are enough here; only its shape and dtype are recorded.
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "resnet18.onnx",
    input_names=["input"],
    output_names=["output"],
    # Make dim 0 of both the input and the output symbolic, so the exported
    # graph accepts any batch size rather than only the traced size of 1.
    dynamic_axes={tensor_name: {0: "batch_size"} for tensor_name in ("input", "output")},
    opset_version=11,           # older opset, chosen for broad runtime support
    export_params=True,         # embed the trained weights in the .onnx file
    do_constant_folding=True,   # pre-compute constant subgraphs at export time
)

In [None]:
# Load the exported graph. Pass the execution provider explicitly: since
# onnxruntime 1.9, builds that ship additional providers (e.g. onnxruntime-gpu)
# raise an error when `providers` is omitted. The CPU provider exists in every build.
session = ort.InferenceSession("resnet18.onnx", providers=["CPUExecutionProvider"])

input_name = session.get_inputs()[0].name

# Seeded random input so re-running the notebook gives identical results.
# This is noise, not a real image: it is good for checking the export, but the
# "predicted class" below carries no meaning. Feed a preprocessed image for a
# real prediction.
rng = np.random.default_rng(0)
input_data = rng.standard_normal((1, 3, 224, 224), dtype=np.float32)

# inference: `None` asks ONNX Runtime for every graph output
outputs = session.run(None, {input_name: input_data})

# Verify the export: ONNX Runtime should reproduce the PyTorch model's logits
# on the same input, up to small floating-point differences.
with torch.no_grad():
    torch_logits = model(torch.from_numpy(input_data)).numpy()
np.testing.assert_allclose(outputs[0], torch_logits, rtol=1e-3, atol=1e-5)

prediction = np.argmax(outputs[0], axis=1)
print("Predicted class:", prediction)