# LightlyTrain - Panoptic Segmentation - ONNX and TensorRT Export

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lightly-ai/lightly-train/blob/main/examples/notebooks/panoptic_segmentation_export.ipynb)

This notebook demonstrates how to export a panoptic segmentation model to ONNX and TensorRT.

The notebook covers the following steps:
1. Install LightlyTrain
2. Export a trained EoMT model to ONNX
3. Export a trained EoMT model to TensorRT
4. Run inference with the TensorRT engine

> **Important**: When running on Google Colab, make sure to select a GPU runtime. You can do this by going to `Runtime` > `Change runtime type` and selecting a GPU hardware accelerator.

## Installation

LightlyTrain can be installed directly via `pip`:

In [None]:
!pip install "lightly-train[onnx,onnxruntime,onnxslim]"

## Export to ONNX

### Load the model weights

Load the model with LightlyTrain's `load_model` function. This automatically downloads the model weights and loads them into the model.

In [None]:
import lightly_train

model = lightly_train.load_model("dinov3/vits16-eomt-panoptic-coco")

### Download an example image

Download an example image for inference with the following command:

In [None]:
!wget -O image.jpg http://images.cocodataset.org/val2017/000000039769.jpg

### Preprocessing

In [None]:
import torch
import torchvision.transforms.v2 as T
from PIL import Image
from torchvision.transforms.functional import pil_to_tensor

# Load image with PIL.
image_pil = Image.open("image.jpg").convert("RGB")

# Convert PIL image to a uint8 tensor for PyTorch inference and plotting.
image_tensor = pil_to_tensor(image_pil)

# Define pre-processing transforms.
w, h = image_pil.size
transforms = T.Compose(
    [
        T.Resize((model.image_size)),
        T.ToTensor(),
        T.Normalize(**model.image_normalize),
    ]
)

# Apply transforms for ONNX and TensorRT inference.
image_tensor_transformed = transforms(image_pil)[None]
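
The scaling and normalization steps above can be illustrated with a minimal NumPy sketch. It leaves out the resize step, and the mean and standard deviation are the common ImageNet statistics, used here only as an assumption. The values the model actually uses come from `model.image_normalize`. The function name `preprocess` is hypothetical.

```python
import numpy as np

# Assumed ImageNet statistics; the real values come from model.image_normalize.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(image_hwc_uint8: np.ndarray) -> np.ndarray:
    """Scale to [0, 1], normalize per channel, and return a (1, 3, H, W) batch."""
    x = image_hwc_uint8.astype(np.float32) / 255.0  # (H, W, 3) in [0, 1]
    x = (x - MEAN) / STD  # broadcasts over the trailing channel axis
    x = np.transpose(x, (2, 0, 1))  # (3, H, W)
    return x[None]  # add the batch dimension


dummy = np.full((4, 6, 3), 255, dtype=np.uint8)  # all-white toy image
batch = preprocess(dummy)
print(batch.shape, batch.dtype)
```

The leading batch dimension plays the same role as the `[None]` indexing in the cell above.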

### Get the model predictions for reference

We define a helper function to visualize the predictions.
The function will be used to compare the predictions from PyTorch, ONNX and
TensorRT models.

In [None]:
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks


def visualize_segmentations(image, masks, segment_ids):
    # masks is (H, W, 2)
    # segment_ids is (num_segments)

    # Convert masks to boolean masks for each segment
    # masks[..., 1] contains the segment id for each pixel.
    # We add a mask for unassigned pixels (segment id -1).
    masks_bool = torch.stack(
        [masks[..., 1] == -1]
        + [masks[..., 1] == segment_id for segment_id in segment_ids]
    )

    # Create colors for visualization
    colors = [(0, 0, 0)] + [
        [int(color * 255) for color in plt.cm.tab20c(i / len(segment_ids))[:3]]
        for i in range(len(segment_ids))
    ]

    image_with_masks = draw_segmentation_masks(
        image, masks_bool, colors=colors, alpha=1.0
    )

    fig, axs = plt.subplots(1, 2, figsize=(12, 8))
    axs[0].imshow(image.permute(1, 2, 0))
    axs[0].axis("off")
    axs[1].imshow(image_with_masks.permute(1, 2, 0))
    axs[1].axis("off")
    fig.show()
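
To make the mask conversion in the helper concrete, here is a small sketch on a toy 2x3 panoptic map, written with NumPy instead of PyTorch. The meaning of channel 0 (a class label) is an assumption for illustration. Only channel 1, the segment id, is used by the helper.

```python
import numpy as np

# Toy 2x3 panoptic map: channel 0 = class label (assumed), channel 1 = segment id.
masks = np.array(
    [
        [[5, 0], [5, 0], [7, 1]],
        [[-1, -1], [7, 1], [7, 1]],
    ]
)
segment_ids = [0, 1]

# First mask marks unassigned pixels (-1), then one boolean mask per segment.
masks_bool = np.stack(
    [masks[..., 1] == -1] + [masks[..., 1] == sid for sid in segment_ids]
)
print(masks_bool.shape)  # (3, 2, 3)
print(masks_bool.sum(axis=(1, 2)))  # pixels per mask: [1 2 3]
```

Every pixel ends up in exactly one of the stacked masks, which is why the overlay can be drawn with `alpha=1.0` without segments blending into each other.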

In [None]:
# Get predictions from the PyTorch model.
results = model.predict(image_tensor)

# Visualize predictions from the PyTorch model.
visualize_segmentations(
    image_tensor, masks=results["masks"], segment_ids=results["segment_ids"]
)
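
The PyTorch output serves as the reference for the other backends. Besides visual inspection, agreement between two backends can be quantified per pixel. The sketch below is one possible check, not part of LightlyTrain. The helper name `pixel_agreement` is hypothetical. Segment ids may be assigned in a different order by different backends, so comparing a class-label channel (assumed here to be channel 0 of `masks`) is likely more meaningful than comparing raw segment ids.

```python
import numpy as np


def pixel_agreement(reference: np.ndarray, candidate: np.ndarray) -> float:
    """Fraction of pixels whose values match between two (H, W) label maps."""
    if reference.shape != candidate.shape:
        raise ValueError("Label maps must have the same shape.")
    return float((reference == candidate).mean())


# Toy label maps that differ in exactly one of six pixels.
ref = np.array([[1, 1, 2], [2, 2, -1]])
cand = np.array([[1, 1, 2], [2, 1, -1]])
print(pixel_agreement(ref, cand))
```

Small disagreements along segment boundaries are plausible after export because of numerical differences between runtimes.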

### Export the model to ONNX

In [None]:
# Export the PyTorch model to ONNX.
model.export_onnx(
    out="model.onnx",
    # batch_size=1, # Panoptic segmentation models only support batch size 1.
    # height=640, # Set custom height and width, default is model.image_size.
    # width=640,
    # opset_version=17, # Set custom ONNX opset version.
    # verify=False, # Disable ONNX model verification, default is True.
    # simplify=False, # Disable model simplification, default is True.
)

### Run inference with the ONNX model

In [None]:
import onnxruntime as ort
import torch.nn.functional as F

# Create an ONNX Runtime session.
sess = ort.InferenceSession("model.onnx")

# Run inference.
outputs = sess.run(
    output_names=None,
    input_feed={
        "images": image_tensor_transformed.numpy(),
    },
)

masks_onnx = torch.from_numpy(outputs[0])
segment_ids_onnx = torch.from_numpy(outputs[1][0])
scores_onnx = torch.from_numpy(outputs[2][0])

# Resize ONNX predictions to original image size if necessary.
masks_onnx = masks_onnx.permute(0, 3, 1, 2).float()  # (1, 2, H_model, W_model)
masks_onnx = F.interpolate(
    masks_onnx, size=(h, w), mode="nearest"
)  # (1, 2, H_orig, W_orig)
masks_onnx = masks_onnx.permute(0, 2, 3, 1).long()  # (1, H_orig, W_orig, 2)

# Visualize predictions from the ONNX model.
visualize_segmentations(image_tensor, masks=masks_onnx[0], segment_ids=segment_ids_onnx)
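
The resize above uses `mode="nearest"` because the masks hold integer labels and ids. Averaging neighboring values, as bilinear interpolation does, would produce ids that belong to no segment. The NumPy sketch below shows the idea on a toy segment map. The helper `resize_nearest` is hypothetical and only approximates what `F.interpolate` does internally.

```python
import numpy as np


def resize_nearest(labels: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbor resize of an (H, W) integer label map."""
    in_h, in_w = labels.shape
    # Map each output index to a source index (floor of the scaled coordinate).
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    return labels[rows[:, None], cols[None, :]]


segment_map = np.array([[0, 3], [-1, 3]])
resized = resize_nearest(segment_map, 4, 4)
print(resized)
```

The resized map contains only values that were already present in the input, so every pixel still refers to a valid segment id or to the unassigned id `-1`.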

## Export to TensorRT

### Requirements

TensorRT is not part of LightlyTrain's dependencies and must be installed separately. Installation depends on your OS, Python
version, GPU, and NVIDIA driver/CUDA setup. See the [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/latest/installing-tensorrt/installing.html) for more details.

On CUDA 12.x systems you can often install the Python package via:

In [None]:
!pip install tensorrt-cu12

In [None]:
# Export the PyTorch model to a TensorRT engine.
model.export_tensorrt(
    out="model.trt",
)

### Run inference with the TensorRT engine

In [None]:
import numpy as np
import tensorrt as trt


class TRT:
    def __init__(self, engine_path: str, device: str = "cuda:0", verbose: bool = False):
        self.device = torch.device(device)
        self.trt_logger = trt.Logger(trt.Logger.VERBOSE if verbose else trt.Logger.INFO)

        trt.init_libnvinfer_plugins(self.trt_logger, "")
        runtime = trt.Runtime(self.trt_logger)
        with open(engine_path, "rb") as f:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        if self.engine is None:
            raise RuntimeError("Failed to load TensorRT engine.")

        self.context = self.engine.create_execution_context()

        # IO names in engine order.
        self.io_names = [
            self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)
        ]
        self.in_names = [
            n
            for n in self.io_names
            if self.engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT
        ]
        self.out_names = [
            n
            for n in self.io_names
            if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT
        ]

        # Allocate buffers once. Static dimensions are taken from the engine.
        # For panoptic segmentation, the "masks", "segment_ids", and "scores"
        # outputs have dynamic dimensions (reported as -1), so their exact size
        # is not known before inference.
        # This simplified wrapper replaces every -1 with 100, which assumes
        # batch size 1 and at most 100 segments. This is brittle: robust handling
        # of dynamic outputs needs a more complex allocator or newer TensorRT APIs.
        self.buffers = {}
        self.bindings = []
        for name in self.io_names:
            shape = list(self.context.get_tensor_shape(name))
            np_dtype = trt.nptype(self.engine.get_tensor_dtype(name))
            torch_dtype = torch.from_numpy(np.empty((), dtype=np_dtype)).dtype

            # Replace dynamic dimensions with the assumed upper bound.
            processed_shape = [s if s != -1 else 100 for s in shape]

            t = torch.empty(
                tuple(processed_shape), device=self.device, dtype=torch_dtype
            )
            self.buffers[name] = t
            self.bindings.append(t.data_ptr())

    @torch.no_grad()
    def __call__(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Copy inputs into preallocated buffers.
        for name in self.in_names:
            self.buffers[name].copy_(inputs[name].to(self.device))

        # Set input shapes if dynamic.
        for name in self.in_names:
            self.context.set_input_shape(name, inputs[name].shape)

        ok = self.context.execute_v2(self.bindings)
        if not ok:
            raise RuntimeError("TensorRT execution failed")

        return {name: self.buffers[name] for name in self.out_names}

In [None]:
# Instantiate the TensorRT model.
trt_model = TRT("model.trt")

# Run inference with the TensorRT model.
outputs_trt = trt_model({"images": image_tensor_transformed.to(trt_model.device)})

masks_trt = outputs_trt["masks"][0]
segment_ids_trt = outputs_trt["segment_ids"][0]
scores_trt = outputs_trt["scores"][0]

# Filter out segments with scores below threshold. This is required because TensorRT
# returns scores and segment ids for all possible segments.
keep = scores_trt > 0.8
segment_ids_trt = segment_ids_trt[keep]
scores_trt = scores_trt[keep]

# Resize TensorRT predictions to original image size if necessary.
masks_trt = masks_trt.permute(2, 0, 1).float()
masks_trt = masks_trt.unsqueeze(0)
masks_trt = F.interpolate(masks_trt, size=(h, w), mode="nearest")
masks_trt = masks_trt.squeeze(0).permute(1, 2, 0).long()

# Visualize predictions from the TensorRT model.
visualize_segmentations(
    image_tensor, masks=masks_trt.cpu(), segment_ids=segment_ids_trt.cpu()
)
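
The score filtering step can be illustrated on toy data. The sketch below uses NumPy and made-up scores for five output slots, of which only two are real segments. The array values are assumptions for illustration. Only the threshold `0.8` is taken from the cell above.

```python
import numpy as np

# Toy outputs with 5 slots; only slots with high scores are real segments.
scores = np.array([0.97, 0.12, 0.91, 0.05, 0.30], dtype=np.float32)
segment_ids = np.array([0, 1, 2, 3, 4])

threshold = 0.8  # same value as in the cell above
keep = scores > threshold  # boolean mask over the slots

print(segment_ids[keep])  # [0 2]
print(int(keep.sum()))  # 2
```

The same boolean mask is applied to both `segment_ids` and `scores` so that the two arrays stay aligned after filtering.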