# Convert a SegNeXt PyTorch Model to LiteRT

This notebook demonstrates how to convert a **SegNeXt** model (originally trained and published in PyTorch) into a LiteRT model using [AI Edge Torch](https://ai.google.dev/edge). The sample also shows how to optimize the resulting model with dynamic-range quantization using [AI Edge Quantizer](https://github.com/google-ai-edge/ai-edge-quantizer).

## What you'll learn

- **Setup:** Installing necessary libraries and tools to download and load the SegNeXt model.
- **Inference Validation:** Running the PyTorch model for segmentation.
- **Model Conversion:** Converting a SegNeXt model to LiteRT using AI Edge Torch.
- **Verifying Results:** Comparing outputs between PyTorch and LiteRT models.
- **Quantization:** Applying post-training quantization techniques to reduce model size.
- **Export and Download:** Downloading your newly created and optimized LiteRT models.

## Install and Import Dependencies

Start by installing and importing the dependencies needed to convert the model, along with a few tweaks that make the `mmsegmentation` library work as expected with the AI Edge Torch converter.

Make sure to run the following cells to set up the environment with the required libraries:

In [None]:
# Install MMCV and its dependencies.
!pip install openmim -q
!mim install mmengine -q
!mim install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.4/index.html
!pip install ftfy

# Install AI Edge Torch and AI Edge Quantizer.
!pip install ai-edge-torch-nightly -q
!pip install ai-edge-quantizer-nightly -q

In [None]:
# Clone the MMSegmentation GitHub repository.
!git clone -b v1.2.2 https://github.com/open-mmlab/mmsegmentation.git

# Patch the version constraints in mmseg/__init__.py
!sed -i "s/MMCV_MAX = '2.2.0'/MMCV_MAX = '6.5.0'/g" mmsegmentation/mmseg/__init__.py

# Install MMSegmentation
%cd mmsegmentation
!pip install -e .

In [None]:
import urllib.request
import cv2
import math
import sys

# PyTorch, Vision, and AI Edge Torch.
import torch
import torchvision.transforms as T
import ai_edge_torch

# PIL, NumPy, IPython display.
import numpy as np
from PIL import Image
from IPython import display

# Google Colab utilities.
from google.colab import files
from google.colab.patches import cv2_imshow

# Matplotlib for visualization.
from matplotlib import gridspec
from matplotlib import pyplot as plt

# LiteRT interpreter and AI Edge Quantizer utilities.
import ai_edge_litert.interpreter
from ai_edge_quantizer import quantizer, recipe

### Patch the `MMEngine` registry
We'll also patch the `Registry` to address potential naming collisions in the mmseg registry, then import our classes and create an inference object.

In [None]:
%%writefile patch_registry.py
# @markdown We implemented some functions to patch the mmengine registry. <br/> Run the following cell to activate the functions.
import logging

from mmengine.registry import Registry
from mmengine.logging import print_log
from typing import Type, Optional, Union, List

def _register_module(self,
                     module: Type,
                     module_name: Optional[Union[str, List[str]]] = None,
                     force: bool = False) -> None:
    """Register a module.

    Args:
        module (type): Module to be registered. Typically a class or a
            function, but generally all ``Callable`` are acceptable.
        module_name (str or list of str, optional): The module name to be
            registered. If not specified, the class name will be used.
            Defaults to None.
        force (bool): Whether to override an existing class with the same
            name. Defaults to False.
    """
    if not callable(module):
        raise TypeError(f'module must be Callable, but got {type(module)}')

    if module_name is None:
        module_name = module.__name__
    if isinstance(module_name, str):
        module_name = [module_name]
    for name in module_name:
        if not force and name in self._module_dict:
            # Log the name collision rather than failing, then fall through
            # and overwrite the existing entry.
            existed_module = self.module_dict[name]
            print_log(
                f'{name} is already registered in {self.name} '
                f'at {existed_module.__module__}. Overwriting the existing entry.',
                logger='current',
                level=logging.INFO
            )
        self._module_dict[name] = module

Registry._register_module = _register_module


In [None]:
# Patch the MMEngine registry.
import patch_registry

# Check MMSegmentation installation.
import mmseg
print(mmseg.__version__)

# Import the `apis` and `datasets` modules.
from mmseg import apis, datasets
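
The effect of the registry patch can be illustrated without MMEngine. The sketch below uses a hypothetical `TinyRegistry` class (not the MMEngine `Registry`) and assumes, for illustration, that the unpatched method rejects duplicate names with an error. Replacing the method on the class makes a repeated registration log a message instead of failing, which is the same monkey-patching pattern used in `patch_registry.py`.

```python
import logging


class TinyRegistry:
    """Hypothetical stand-in for a registry; not the MMEngine class."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register(self, module, force=False):
        name = module.__name__
        if not force and name in self._module_dict:
            # Strict behavior: a duplicate name is an error.
            raise KeyError(f'{name} is already registered in {self.name}')
        self._module_dict[name] = module


def tolerant_register(self, module, force=False):
    """Replacement method: log duplicates, then overwrite the entry."""
    name = module.__name__
    if not force and name in self._module_dict:
        logging.info('%s is already registered in %s', name, self.name)
    self._module_dict[name] = module


# Monkey-patch the class so every instance picks up the tolerant method.
TinyRegistry.register = tolerant_register


class DemoModel:
    pass


registry = TinyRegistry('demo')
registry.register(DemoModel)
registry.register(DemoModel)  # No KeyError after the patch.
print(sorted(registry._module_dict))
```

Because the patch is applied to the class rather than to an instance, it has to run before any module triggers a duplicate registration, which is why the notebook imports `patch_registry` before `mmseg`.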

### Download a Sample Image
We'll retrieve an image that we'll use for our segmentation demo. Feel free to upload your own image(s) if desired.

In [None]:
import urllib.request

IMAGE_FILENAMES = ['test.jpg']

for name in IMAGE_FILENAMES:
    url = 'https://upload.wikimedia.org/wikipedia/commons/9/9c/Bruce_car.JPG'
    urllib.request.urlretrieve(url, name)

If you want to upload additional images, uncomment and run the cell below. Then update `IMAGE_FILENAMES` to match your uploaded file(s).

In [None]:
# from google.colab import files
# uploaded = files.upload()
#
# for filename in uploaded:
#   content = uploaded[filename]
#   with open(filename, 'wb') as f:
#     f.write(content)
#
# IMAGE_FILENAMES = list(uploaded.keys())
# print('Uploaded files:', IMAGE_FILENAMES)

Quickly display the loaded image(s) to confirm.

In [None]:
DESIRED_HEIGHT = 480
DESIRED_WIDTH = 480

def resize_and_show(image):
    h, w = image.shape[:2]
    if h < w:
        img = cv2.resize(image, (DESIRED_WIDTH, math.floor(h/(w/DESIRED_WIDTH))))
    else:
        img = cv2.resize(image, (math.floor(w/(h/DESIRED_HEIGHT)), DESIRED_HEIGHT))
    cv2_imshow(img)

# Preview the images.
images = {name: cv2.imread(name) for name in IMAGE_FILENAMES}

for name, image in images.items():
    print(name)
    resize_and_show(image)
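
The aspect-ratio arithmetic inside `resize_and_show` can be checked in isolation. The following is a minimal sketch in which `fit_size` is a hypothetical helper (it is not part of the notebook) that returns the `(width, height)` pair that would be handed to `cv2.resize`.

```python
import math

DESIRED_HEIGHT = 480
DESIRED_WIDTH = 480


def fit_size(h, w, max_h=DESIRED_HEIGHT, max_w=DESIRED_WIDTH):
    """Return (width, height), the order cv2.resize expects for its size."""
    if h < w:
        # Landscape: pin the width and scale the height by the same factor.
        return max_w, math.floor(h / (w / max_w))
    # Portrait or square: pin the height and scale the width.
    return math.floor(w / (h / max_h)), max_h


print(fit_size(480, 960))  # landscape input
print(fit_size(960, 480))  # portrait input
```

Note that the image array shape is `(height, width)`, while the size argument is `(width, height)`; mixing the two orders is a common source of distorted previews.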

## Load SegNeXt
With the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) repo cloned and installed in the setup cells above, we can now load a SegNeXt model trained on [ADE20K](https://ade20k.csail.mit.edu/). In this example, we're using the [SegNeXt mscan-b ADE20K model](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/segnext).

In [None]:
# Load the SegNeXt PyTorch model while setting the device to CPU.
inferencer = apis.MMSegInferencer(model='segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512', device='cpu')

# Retrieve the actual PyTorch model.
pt_model = inferencer.model
pt_model.eval()

## The MIT ADE20K scene parsing dataset
`ADE20K` is composed of more than 27K images from the [SUN](https://groups.csail.mit.edu/vision/SUN/hierarchy.html) and [Places](https://www.csail.mit.edu/research/places-database-scene-recognition) databases. Images are fully annotated with objects, spanning over 3K object categories.

In [None]:
classes = datasets.ADE20KDataset.METAINFO['classes']
palette = datasets.ADE20KDataset.METAINFO['palette']

The cell above extracts the class labels and the color palette from the dataset metadata. Next, we retrieve the mean and standard deviation values from the data preprocessor configuration (via `inferencer.cfg`); the model wrapper needs them later to normalize inputs before conversion.

In [None]:
data_preprocessor_dict = inferencer.cfg.to_dict()['data_preprocessor']
data_preprocessor_dict['mean'], data_preprocessor_dict['std']

## Inference using the PyTorch Model
Let's verify the PyTorch model by doing a quick inference and visualizing the segmentation output.


In [None]:
# The output mask is saved under 'outputs/vis/test.jpg'
results = inferencer(IMAGE_FILENAMES[0], out_dir='outputs', img_out_dir='vis', return_vis=True)

Visualize the segmentation result.

In [None]:
display.Image('outputs/vis/test.jpg')

We have now confirmed that the original PyTorch model can generate valid segmentation predictions and that it runs properly in Python.


## Create a Model Wrapper
To simplify the model output and ensure a single output node during conversion, we'll create a wrapper module. We'll also handle the typical mean/std normalization manually within this wrapper (since some methods, like `torch.min` or `torch.max`, might not be fully supported in the LiteRT conversion).

In [None]:
class ImageSegmentationModelWrapper(torch.nn.Module):
    def __init__(self, pt_model, mmseg_cfg):
        super().__init__()
        self.model = pt_model
        data_preprocessor_dict = mmseg_cfg.to_dict()['data_preprocessor']
        # Convert the mean and std from shape (3,) to (1, 3, 1, 1)
        self.image_mean = torch.tensor(data_preprocessor_dict['mean']).view(1, -1, 1, 1)
        self.image_std = torch.tensor(data_preprocessor_dict['std']).view(1, -1, 1, 1)

    def forward(self, image: torch.Tensor):
        # Input shape: (N, H, W, C).
        # Convert NHWC to NCHW.
        image = image.permute(0, 3, 1, 2)

        # Normalize.
        image = (image - self.image_mean) / self.image_std

        # Model output is typically (N, C, H, W).
        result = self.model(image)

        # Convert from NCHW to NHWC.
        result = result.permute(0, 2, 3, 1)

        return result

# Create the wrapped model.
wrapped_pt_model = ImageSegmentationModelWrapper(pt_model, inferencer.cfg).eval()
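
To make the layout changes and the broadcasting in the wrapper concrete, here is a NumPy sketch of the same steps on a tiny array. The mean and standard deviation values are assumptions chosen for illustration (read the real ones from `inferencer.cfg`), and the zero-filled "logits" array is a stand-in for the model output, not something SegNeXt produces.

```python
import numpy as np

# Assumed per-channel statistics, for illustration only.
mean = np.array([123.675, 116.28, 103.53], dtype=np.float32).reshape(1, -1, 1, 1)
std = np.array([58.395, 57.12, 57.375], dtype=np.float32).reshape(1, -1, 1, 1)

# A tiny NHWC "image" batch: 1 image, 4 rows, 6 columns, 3 channels.
rng = np.random.default_rng(0)
nhwc = rng.uniform(0, 255, size=(1, 4, 6, 3)).astype(np.float32)

# Same axis order as image.permute(0, 3, 1, 2) in the wrapper.
nchw = nhwc.transpose(0, 3, 1, 2)

# The (1, 3, 1, 1) statistics broadcast over the batch, height and width axes.
normalized = (nchw - mean) / std

# Stand-in for the model output with 5 hypothetical classes, in NCHW.
logits_nchw = np.zeros((1, 5, 4, 6), dtype=np.float32)

# Same axis order as result.permute(0, 2, 3, 1) in the wrapper.
logits_nhwc = logits_nchw.transpose(0, 2, 3, 1)

print(nchw.shape, normalized.shape, logits_nhwc.shape)
```

Keeping the class axis last in the output is what lets the later cells call `np.argmax(..., axis=-1)` directly on the LiteRT result.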

## Convert to LiteRT

Without a wrapper, the converted model may expose more than one output. One way to find the output you need is to download the `tflite` file after the conversion step in this colab, open it with [Model Explorer](https://ai.google.dev/edge/model-explorer), and confirm which output in the graph has the expected shape.

That is a lot of effort for this example. The wrapper defined in the previous section avoids it by narrowing the PyTorch model to only the final output, which ensures that your new LiteRT model has a single output after the conversion stage.

We'll use AI Edge Torch to convert our PyTorch model. We pass in a sample input of appropriate shape to guide the conversion. (This shape also becomes your expected inference shape in the resulting LiteRT model.)

In [None]:
MODEL_INPUT_HW = (512, 512)
sample_args = (torch.rand((1, *MODEL_INPUT_HW, 3)),)

edge_model = ai_edge_torch.convert(wrapped_pt_model, sample_args)

## Validate Converted Model with LiteRT Interpreter
We can test the converted LiteRT model's output. Since our pre-processing is embedded within the wrapper, we'll only resize and cast the input image.


In [None]:
# @markdown We implemented some functions to visualize the segmentation results. <br/> Run the following cell to activate the functions.

# Visualization utilities
def label_to_color_image(label, palette):
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')
    colormap = np.asarray(palette)
    if np.max(label) >= len(colormap):
        raise ValueError('label value too large.')
    return colormap[label]

def vis_segmentation(image, seg_map, palette, label_names):
    plt.figure(figsize=(15, 5))
    grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')

    plt.subplot(grid_spec[1])
    seg_image = label_to_color_image(seg_map, palette).astype(np.uint8)
    plt.imshow(seg_image)
    plt.axis('off')
    plt.title('segmentation map')

    H, W = image.shape[:2]
    plt.subplot(grid_spec[2])
    plt.imshow(image, extent=(0, W, H, 0))
    plt.imshow(seg_image, alpha=0.7, extent=(0, W, H, 0))
    plt.axis('off')
    plt.title('segmentation overlay')

    unique_labels = np.unique(seg_map)
    ax = plt.subplot(grid_spec[3])
    full_color_map = label_to_color_image(
        np.arange(len(label_names)).reshape(len(label_names), 1),
        palette
    )
    plt.imshow(full_color_map[unique_labels].astype(np.uint8), interpolation='nearest')
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), label_names[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0)
    plt.grid(False)
    plt.show()

LABEL_NAMES = np.asarray(classes)
PALETTE = palette
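
`label_to_color_image` relies on NumPy integer-array indexing: indexing a `(num_classes, 3)` colormap with an `(H, W)` array of class ids yields an `(H, W, 3)` color image. A minimal sketch with a hypothetical three-class palette (not the ADE20K palette):

```python
import numpy as np

# Hypothetical palette: one RGB triple per class id.
toy_palette = [[0, 0, 0], [255, 0, 0], [0, 255, 0]]

# A 2x2 label map holding class ids.
toy_labels = np.array([[0, 1],
                       [2, 1]])

colormap = np.asarray(toy_palette)   # shape (3, 3): (num_classes, RGB)
color_image = colormap[toy_labels]   # shape (2, 2, 3): one RGB triple per pixel

print(color_image.shape)
print(color_image[1, 0])             # color of the pixel labeled with class 2
```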

In [None]:
np_images = []
image_sizes = []

for index in range(len(IMAGE_FILENAMES)):
    # Retrieve each image from the file system.
    image = Image.open(IMAGE_FILENAMES[index])
    # Save the size for reference.
    image_sizes.append(image.size)
    # Resize to the model input size (PIL expects (width, height)) and convert
    # each image into a NumPy array with shape (1, H, W, 3).
    np_image = np.array(image.convert('RGB').resize(MODEL_INPUT_HW[::-1], Image.Resampling.BILINEAR))
    np_image = np.expand_dims(np_image, axis=0).astype(np.float32)
    np_images.append(np_image)

    # Retrieve an output from the converted model.
    edge_model_output = edge_model(np_image)
    segmentation_map = edge_model_output.squeeze()

    # Visualize.
    vis_segmentation(
        np_images[index][0].astype(np.uint8),
        np.argmax(segmentation_map, axis=-1),
        PALETTE,
        LABEL_NAMES
    )
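
A numeric check can complement the visual comparison. The sketch below assumes you have two logit arrays of the same NHWC shape, one from the PyTorch wrapper and one from the LiteRT model. `compare_logits` is a hypothetical helper, and the random arrays only stand in for real model outputs.

```python
import numpy as np


def compare_logits(reference, candidate):
    """Return (max absolute difference, fraction of pixels with equal argmax)."""
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    if reference.shape != candidate.shape:
        raise ValueError(f'shape mismatch: {reference.shape} vs {candidate.shape}')
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    # Compare the predicted class per pixel (class axis is last in NHWC).
    same_class = reference.argmax(axis=-1) == candidate.argmax(axis=-1)
    return max_abs_diff, float(np.mean(same_class))


# Synthetic stand-ins: a reference and a slightly perturbed copy.
rng = np.random.default_rng(0)
ref = rng.normal(size=(1, 8, 8, 5)).astype(np.float32)
cand = ref + rng.normal(scale=1e-4, size=ref.shape).astype(np.float32)

max_diff, agreement = compare_logits(ref, cand)
print(f'max abs diff: {max_diff:.2e}, argmax agreement: {agreement:.3f}')
```

In this notebook, the reference could plausibly be obtained with `wrapped_pt_model(torch.from_numpy(np_image)).detach().numpy()` and the candidate with `edge_model(np_image)`. Small floating-point differences are expected; the argmax agreement is usually the more meaningful number for segmentation.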

In [None]:
# Serialize the LiteRT model.
edge_model.export('segnext.tflite')

## Apply Quantization
Model size matters on edge devices. Post-training quantization can significantly reduce the size of your `tflite` model. This section demonstrates how to use **dynamic-range quantization** through AI Edge Quantizer.

### Quantizing the model with dynamic quantization (AI Edge Quantizer)

To use the `Quantizer`, we need to:
* Instantiate a Quantizer class. This is the entry point to the quantizer's functionalities.
* Load a desired quantization recipe.
* Quantize (and save) the model. This is where most of the quantizer's internal logic works.

In [None]:
# Quantization (API will quantize and save a flatbuffer as *.tflite).
# Use a distinct variable name so the imported `quantizer` module is not
# shadowed and the cell can be re-run.
qt = quantizer.Quantizer(float_model='segnext.tflite')
qt.load_quantization_recipe(recipe=recipe.dynamic_wi8_afp32())

quantization_result = qt.quantize()
quantization_result.export_model('segnext_dynamic_wi8_afp32.tflite')

`quantization_result` has two components:

* the quantized LiteRT model (as a bytearray), and
* the corresponding quantization recipe.

Let's compare the sizes of the two flatbuffers.


In [None]:
!ls -lh *.tflite
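
If you prefer numbers you can use in later cells over a shell listing, the same comparison can be done in Python. This is a small sketch using only the standard library; `size_report` and `compression_ratio` are hypothetical helper names, and the file names are the ones exported earlier in this notebook.

```python
import os


def size_report(paths):
    """Return {path: size in MiB} for the files that exist."""
    return {p: os.path.getsize(p) / (1024 ** 2) for p in paths if os.path.exists(p)}


def compression_ratio(float_path, quantized_path):
    """Return the float model size divided by the quantized model size."""
    return os.path.getsize(float_path) / os.path.getsize(quantized_path)


model_paths = ['segnext.tflite', 'segnext_dynamic_wi8_afp32.tflite']
for path, mib in size_report(model_paths).items():
    print(f'{path}: {mib:.1f} MiB')

# Only compute the ratio when both files are present.
if all(os.path.exists(p) for p in model_paths):
    print(f'ratio: {compression_ratio(*model_paths):.2f}x')
```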

Let's take a look at what is in this recipe.



In [None]:
quantization_result.recipe

Here the recipe means: apply the naive min/max uniform algorithm (`min_max_uniform_quantize`) for all ops supported by the AI Edge Quantizer (indicated by `*`) under layers satisfying regex `.*` (i.e., all layers). We want the weights of these ops to be quantized as int8, symmetric, channel_wise, and we want to execute the ops in `Integer` mode.
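
To give a feel for what "int8, symmetric, channel_wise" min/max quantization means, here is a NumPy sketch of the general technique. It is an illustration under simplifying assumptions and not AI Edge Quantizer's implementation; `quantize_per_channel_symmetric` is a hypothetical helper, and the random weight matrix stands in for a real layer.

```python
import numpy as np


def quantize_per_channel_symmetric(weights, axis=0, num_bits=8):
    """Quantize to signed integers with one scale per channel (num_bits <= 8)."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8
    # Reduce over every axis except the channel axis.
    reduce_axes = tuple(i for i in range(weights.ndim) if i != axis)
    max_abs = np.max(np.abs(weights), axis=reduce_axes, keepdims=True)
    # Symmetric: the zero point is 0, so only a scale is needed per channel.
    scale = np.where(max_abs > 0, max_abs / qmax, 1.0).astype(np.float32)
    q = np.clip(np.round(weights / scale), -qmax, qmax).astype(np.int8)
    return q, scale


rng = np.random.default_rng(0)
w = rng.normal(size=(4, 16)).astype(np.float32)  # 4 output channels

q, scale = quantize_per_channel_symmetric(w, axis=0)
w_hat = q.astype(np.float32) * scale             # dequantized weights

print(q.dtype, scale.shape)
print('max reconstruction error:', float(np.max(np.abs(w - w_hat))))
```

Each weight is stored in one byte instead of four, which is where most of the size reduction comes from; the per-channel scales add only a small overhead. The rounding error per weight is bounded by half of that channel's scale.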


Next, you'll create a function using LiteRT to run the newly generated quantized model.


In [None]:
# @markdown We implemented some functions to run segmentation on the quantized model. <br/> Run the following cell to activate the functions.
def run_segmentation(image, model_path):
    """Run a LiteRT model on one (H, W, 3) float32 image and return its logits."""
    # Add the batch dimension expected by the model: (1, H, W, 3).
    image = np.expand_dims(image, axis=0)
    interpreter = ai_edge_litert.interpreter.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()[0]
    interpreter.set_tensor(input_details['index'], image)
    interpreter.invoke()

    # The wrapped model has a single output; drop the batch dimension.
    output_details = interpreter.get_output_details()[0]
    return np.squeeze(interpreter.get_tensor(output_details['index']))

Now let's try running the newly quantized model and see how its output compares with the float model's.

In [None]:
# Validate the model.
for index in range(len(IMAGE_FILENAMES)):
    quantized_model_output = run_segmentation(np_images[index][0],
                                             'segnext_dynamic_wi8_afp32.tflite')
    vis_segmentation(
        np_images[index][0].astype(np.uint8),
        np.argmax(quantized_model_output, axis=-1),
        PALETTE,
        LABEL_NAMES
    )

## Export and Download Models
Let's save and download the converted `tflite` model, along with the dynamic-range quantized version.

In [None]:
files.download('segnext.tflite')

In [None]:
files.download('segnext_dynamic_wi8_afp32.tflite')

## Next Steps
Now you've got a fully converted (and optionally quantized!) `tflite` model. Here are some ideas on what to do next:

- Explore [AI Edge Torch documentation](https://ai.google.dev/edge) for additional use cases or advanced topics.
- Try out your new model on mobile or web using the [LiteRT API samples](https://ai.google.dev/edge/docs/litert).
- Further tune or calibrate your quantization techniques to achieve the desired balance between model size and accuracy.

Have fun deploying your model to the edge!