# LightlyTrain - Semantic Segmentation - ONNX and TensorRT Export

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lightly-ai/lightly-train/blob/main/examples/notebooks/semantic_segmentation_export.ipynb)

This notebook demonstrates how to export a semantic segmentation model to ONNX and TensorRT.

The notebook covers the following steps:
1. Install LightlyTrain
2. Export a model to ONNX
3. Export a model to TensorRT
4. Run inference with the TensorRT engine

> **Important**: When running on Google Colab, make sure to select a GPU runtime. To do so, go to `Runtime` > `Change runtime type` and select a GPU hardware accelerator.

## Installation

LightlyTrain can be installed directly via `pip`:

In [None]:
!pip install "lightly-train[onnx,onnxruntime,onnxslim]"

**Note:** This notebook requires NVIDIA TensorRT. TensorRT is not part of LightlyTrain's
dependencies and must be installed separately. Installation depends on your OS, Python
version, GPU, and NVIDIA driver/CUDA setup.

On CUDA 12.x systems you can often install the Python package via:

In [None]:
!pip install tensorrt-cu12
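
Because TensorRT is installed separately, it can help to confirm that every package is importable before continuing. The following is a minimal sketch using only the standard library; `find_missing_packages` is a hypothetical helper, and the module names in `required` are assumed to match the packages installed above.

```python
import importlib.util


def find_missing_packages(module_names):
    """Return the module names that cannot be found in the current environment."""
    return [
        name for name in module_names if importlib.util.find_spec(name) is None
    ]


# Assumed top-level module names for the packages used in this notebook.
required = ["lightly_train", "onnx", "onnxruntime", "tensorrt"]

missing = find_missing_packages(required)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are importable.")
```

If `tensorrt` is reported as missing, revisit the installation for your CUDA version before running the TensorRT cells.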

## Export to ONNX

### Load the model weights

Load the model with LightlyTrain's `load_model` function. This automatically downloads the model weights and loads the model.

In [None]:
import lightly_train

model = lightly_train.load_model("dinov3/vits16-eomt-coco")

### Download an example image

Download an example image for inference with the following command:

In [None]:
!wget -O image.jpg http://images.cocodataset.org/val2017/000000039769.jpg
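
If `wget` is not available in your environment (for example on Windows), the same download can be sketched in Python with the standard library. `download_file` is a hypothetical helper name, not part of LightlyTrain.

```python
import urllib.request
from pathlib import Path


def download_file(url, destination):
    """Download `url` to `destination` and return the destination as a Path."""
    path, _headers = urllib.request.urlretrieve(url, destination)
    return Path(path)
```

Calling `download_file("http://images.cocodataset.org/val2017/000000039769.jpg", "image.jpg")` should produce the same `image.jpg` as the command above.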

### Preprocessing

In [None]:
import torchvision.transforms.v2 as T
from PIL import Image
from torchvision.transforms.functional import pil_to_tensor

# Load image with PIL.
image_pil = Image.open("image.jpg").convert("RGB")

# Convert PIL image to a uint8 tensor for PyTorch inference and plotting.
image_tensor = pil_to_tensor(image_pil)

# Define pre-processing transforms.
w, h = image_pil.size
transforms = T.Compose(
    [
        T.Resize(model.image_size),
        T.ToTensor(),
        T.Normalize(**model.image_normalize),
    ]
)

# Apply transforms and add a batch dimension for ONNX and TensorRT inference.
image_tensor_tranformed = transforms(image_pil)[None]
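
To make the arithmetic behind these transforms concrete, here is a small pure-Python sketch of what scaling and normalization do to a single RGB pixel: the value is scaled from `[0, 255]` to `[0, 1]`, then the per-channel mean is subtracted and the result is divided by the per-channel standard deviation. The constants below are the commonly used ImageNet statistics and serve only as an illustrative assumption; the actual values come from `model.image_normalize`.

```python
def normalize_pixel(rgb, mean, std):
    """Scale an 8-bit RGB pixel to [0, 1], then normalize each channel."""
    return tuple(
        (value / 255.0 - channel_mean) / channel_std
        for value, channel_mean, channel_std in zip(rgb, mean, std)
    )


# Assumed constants (ImageNet statistics), not read from the model.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

print(normalize_pixel((128, 128, 128), mean, std))
```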

### Get the model predictions for reference

We define a helper function to visualize the predictions.
The function will be used to compare the predictions from PyTorch, ONNX and
TensorRT models.

In [None]:
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_segmentation_masks


def visualize_segmentations(image, masks):
    masks = torch.stack([masks == class_id for class_id in masks.unique()])
    image_with_masks = draw_segmentation_masks(image, masks, alpha=0.6)
    plt.imshow(image_with_masks.permute(1, 2, 0))
    plt.axis("off")
    plt.show()
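
The first line of the helper turns a class-ID map into one boolean mask per class, which is the input format `draw_segmentation_masks` expects. A minimal sketch on a toy 2x3 class map (the values are made up for illustration):

```python
import torch

# Toy class-ID map with three distinct classes: 0, 2 and 5.
class_map = torch.tensor([[0, 0, 2], [5, 2, 2]])

# `unique()` returns the sorted class IDs that are present in the map.
class_ids = class_map.unique()

# One boolean (H, W) mask per class, stacked to shape (num_classes, H, W).
binary_masks = torch.stack([class_map == class_id for class_id in class_ids])

print(class_ids.tolist())
print(binary_masks.shape)
```

Every pixel belongs to exactly one of the stacked masks, so each class is drawn in its own color.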

In [None]:
# Get predictions from the PyTorch model.
masks = model.predict(image_tensor)

# Visualize predictions from the PyTorch model.
visualize_segmentations(image_tensor, masks=masks)

### Export the model to ONNX

In [None]:
# Export the PyTorch model to ONNX.
model.export_onnx(out="model.onnx")

### Run inference with the ONNX model

In [None]:
import onnxruntime as ort
import torch.nn.functional as F

# Create an ONNX Runtime session.
sess = ort.InferenceSession("model.onnx")

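# Run inference. `sess.run` returns a list with one NumPy array per model output.
outputs = sess.run(
    output_names=None,
    input_feed={
        "images": image_tensor_tranformed.data.numpy(),
    },
)

# Assumption: the first output holds the class-ID masks with shape (batch, height, width).
masks_onnx = torch.from_numpy(outputs[0])

# The ONNX model does not resize masks to the original image size, so we do it here.
# This is only needed if your original image size differs from the model input size.
# Nearest-neighbor interpolation keeps the class IDs intact.
masks_onnx = F.interpolate(masks_onnx[None].float(), size=(h, w), mode="nearest")
masks_onnx = masks_onnx[0, 0].long()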

# Visualize predictions from the ONNX model.
visualize_segmentations(image_tensor, masks=masks_onnx)
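
When the array being resized holds integer class IDs rather than logits, nearest-neighbor interpolation is one way to keep every label valid, since averaging neighboring IDs would produce values that belong to no class. The sketch below illustrates this on a toy mask; `resize_class_mask` is a hypothetical helper, not part of LightlyTrain.

```python
import torch
import torch.nn.functional as F


def resize_class_mask(mask, size):
    """Resize an (H, W) integer class-ID mask with nearest-neighbor interpolation."""
    # F.interpolate expects a float tensor of shape (N, C, H, W).
    resized = F.interpolate(mask[None, None].float(), size=size, mode="nearest")
    # Drop the batch and channel dimensions and restore the integer dtype.
    return resized[0, 0].to(mask.dtype)


small = torch.tensor([[0, 1], [2, 3]])
large = resize_class_mask(small, size=(4, 4))
print(large)
```

The resized mask contains only class IDs that were already present in the input.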