# LightlyTrain - Instance Segmentation - ONNX Export

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lightly-ai/lightly-train/blob/main/examples/notebooks/object_detection_export.ipynb)

This notebook demonstrates how to export an instance segmentation model to ONNX and run inference with the exported model.

The notebook covers the following steps:
1. Install LightlyTrain
2. Export a trained EoMT model to ONNX
3. Run inference with the ONNX model

> **Important**: When running on Google Colab, make sure to select a GPU runtime. To do this, go to `Runtime` > `Change runtime type` and select a GPU hardware accelerator.

## Installation

LightlyTrain can be installed directly via `pip`:

In [None]:
!pip install "lightly-train[onnx,onnxruntime,onnxslim]"

## Export to ONNX

### Load the model weights

Load the model with LightlyTrain's `load_model` function. It downloads the model weights automatically and returns the loaded model.

In [None]:
import lightly_train

model = lightly_train.load_model("dinov3/vits16-eomt-inst-coco")

### Download an example image

Download an example image for inference with the following command:

In [None]:
!wget -O image.jpg http://images.cocodataset.org/val2017/000000039769.jpg

### Preprocessing

In [None]:
import torchvision.transforms.v2 as T
from PIL import Image
from torchvision.transforms.functional import pil_to_tensor

# Load image with PIL.
image_pil = Image.open("image.jpg").convert("RGB")

# Convert PIL image to a uint8 tensor for `model.predict` and plotting.
image_tensor = pil_to_tensor(image_pil)

# Store the original image size (used later to resize the ONNX masks).
w, h = image_pil.size

# Define pre-processing transforms.
transforms = T.Compose(
    [
        T.Resize(model.image_size),
        T.ToTensor(),
        T.Normalize(**model.image_normalize),
    ]
)

# Apply transforms for ONNX inference.
image_tensor_tranformed = transforms(image_pil)[None]
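
For readers who want to see what the scaling and normalization steps amount to numerically, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not LightlyTrain's implementation: the function name `to_model_input` is hypothetical, the mean and standard deviation are the common ImageNet statistics used as stand-ins (the notebook reads the actual values from `model.image_normalize`), and the resize step is left out.

```python
import numpy as np

# Assumed normalization statistics (ImageNet). The notebook reads the actual
# values from `model.image_normalize`.
ASSUMED_MEAN = (0.485, 0.456, 0.406)
ASSUMED_STD = (0.229, 0.224, 0.225)


def to_model_input(image_hwc_uint8, mean=ASSUMED_MEAN, std=ASSUMED_STD):
    """Convert an already resized HWC uint8 image to a normalized NCHW float32 batch."""
    # Scale pixel values from [0, 255] to [0, 1], as T.ToTensor() does.
    x = image_hwc_uint8.astype(np.float32) / 255.0
    # Normalize each channel, as T.Normalize() does.
    x = (x - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)
    # Reorder HWC -> CHW and add the batch dimension, mirroring `[None]`.
    return np.ascontiguousarray(x.transpose(2, 0, 1))[None]


# Stand-in for a resized RGB image (height 4, width 6).
dummy = np.full((4, 6, 3), 255, dtype=np.uint8)
batch = to_model_input(dummy)
print(batch.shape, batch.dtype)
```

The result has the same layout (batch, channels, height, width) and dtype as the array passed to ONNX Runtime later in this notebook.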

### Get the model predictions for reference

The helper function below visualizes the predictions.
It is used to compare the predictions of the PyTorch and ONNX models.

In [None]:
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_segmentation_masks


def visualize_predictions(image, masks, labels, classes):
    # Convert to boolean masks if needed.
    if masks.dtype != torch.bool:
        masks = masks > 0.5

    # Draw segmentation masks.
    image_with_masks = draw_segmentation_masks(image, masks, alpha=0.5)

    plt.imshow(image_with_masks.permute(1, 2, 0))
    plt.axis("off")
    plt.show()

    # Print detected classes.
    print("Detected:", [classes[label.item()] for label in labels])

In [None]:
# Get predictions from the PyTorch model.
prediction = model.predict(image_tensor, threshold=0.5)

# Visualize predictions from the PyTorch model.
visualize_predictions(
    image_tensor,
    masks=prediction["masks"],
    labels=prediction["labels"],
    classes=model.classes,
)

### Export the model to ONNX

In [None]:
# Export the PyTorch model to ONNX.
model.export_onnx(
    out="model.onnx",
    validate=False,
    # batch_size=1, # Set custom batch size, default is 1.
    # height=512, # Set custom height and width, default is model.image_size.
    # width=512,
    # opset_version=15, # Set custom ONNX opset version, default is automatically determined by ONNX.
    # verify=False, # Disable ONNX model verification, default is True.
    # simplify=False, # Disable model simplification, default is True.
)

### Run inference with the ONNX model

In [None]:
import onnxruntime as ort
import torch.nn.functional as F

# Create an ONNX Runtime session.
sess = ort.InferenceSession("model.onnx")

# Run inference.
labels_onnx, masks_onnx, scores_onnx = sess.run(
    output_names=None,
    input_feed={
        "images": image_tensor_tranformed.numpy(),
    },
)

# Remove batch dimension.
labels_onnx = labels_onnx[0]
masks_onnx = masks_onnx[0]
scores_onnx = scores_onnx[0]

# Filter by score (ONNX export returns all queries).
keep = scores_onnx > 0.5
labels_onnx = labels_onnx[keep]
masks_onnx = masks_onnx[keep]
scores_onnx = scores_onnx[keep]

# Resize the ONNX masks to the original image size.
# This is only needed if the original image size differs from the model input size.
masks_onnx = torch.from_numpy(masks_onnx)
masks_onnx = F.interpolate(
    masks_onnx.float().unsqueeze(1), size=(h, w), mode="nearest"
).squeeze(1)

# Visualize predictions from the ONNX model.
visualize_predictions(
    image_tensor,
    masks=masks_onnx,
    labels=torch.from_numpy(labels_onnx),
    classes=model.classes,
)
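
A deployment that ships only ONNX Runtime may not have PyTorch available for post-processing. The following is a minimal NumPy-only sketch of the same two steps, score filtering and nearest-neighbour mask resizing. The helper name `filter_and_resize` is hypothetical, and the index rule (floor of the destination index times the input/output size ratio) is an assumption that mirrors typical nearest-neighbour interpolation; results can differ from `F.interpolate` in edge cases.

```python
import numpy as np


def filter_and_resize(labels, masks, scores, out_hw, threshold=0.5):
    """Keep queries above `threshold` and nearest-neighbour resize their masks."""
    # Boolean index over the query dimension.
    keep = scores > threshold
    labels, masks, scores = labels[keep], masks[keep], scores[keep]

    in_h, in_w = masks.shape[-2:]
    out_h, out_w = out_hw
    # Source index for each destination pixel: floor(dst * in_size / out_size).
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    # Gather rows and columns for every kept mask at once.
    resized = masks[:, rows[:, None], cols[None, :]]
    return labels, resized, scores


# Toy outputs for three queries with 2x2 masks (already without batch dimension).
toy_labels = np.array([3, 7, 1])
toy_scores = np.array([0.9, 0.2, 0.6], dtype=np.float32)
toy_masks = np.zeros((3, 2, 2), dtype=np.float32)
toy_masks[0, 0, 0] = 1.0

kept_labels, kept_masks, kept_scores = filter_and_resize(
    toy_labels, toy_masks, toy_scores, out_hw=(4, 4)
)
print(kept_labels, kept_masks.shape)
```

In the notebook's terms, `out_hw` would be `(h, w)` and the inputs would be `labels_onnx`, `masks_onnx`, and `scores_onnx` after removing the batch dimension.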

**Note**: There might be small visual differences between the masks predicted by
the PyTorch model with `model.predict` and those predicted by the ONNX model. This is
because `model.predict` applies its own preprocessing, which resizes the image slightly
differently from the fixed-size resize used for ONNX.
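
One way to put a number on such differences is the intersection over union (IoU) of two masks. The sketch below is an illustrative helper, not part of LightlyTrain: the name `mask_iou` is hypothetical, and it assumes both masks are boolean (or already thresholded) arrays of the same shape.

```python
import numpy as np


def mask_iou(mask_a, mask_b):
    """Intersection over union of two boolean masks of the same shape."""
    mask_a = np.asarray(mask_a, dtype=bool)
    mask_b = np.asarray(mask_b, dtype=bool)
    union = np.logical_or(mask_a, mask_b).sum()
    if union == 0:
        # Both masks are empty, so treat them as identical.
        return 1.0
    return float(np.logical_and(mask_a, mask_b).sum() / union)


# Two toy masks: the first covers 8 pixels, the second covers 12 pixels
# and fully contains the first.
mask_small = np.zeros((4, 4), dtype=bool)
mask_small[:2, :] = True
mask_large = np.zeros((4, 4), dtype=bool)
mask_large[:3, :] = True

iou = mask_iou(mask_small, mask_large)
print(f"IoU: {iou:.3f}")
```

To compare the two models this way, the PyTorch and ONNX instances would first need to be matched to each other, since the order of the predictions is not guaranteed to agree.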