In [None]:
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

This demo will teach you how to convert a PyTorch [IS-Net](https://github.com/xuebinqin/DIS) model to a LiteRT model using Google's AI Edge Torch library. You will then run the newly converted `tflite` model locally using the LiteRT API, as well as learn where to find other tools for running your newly converted model on other edge hardware, including mobile devices and web browsers.

# Prerequisites

You can start by installing the AI Edge Torch library, which you will use to convert the model. The other dependencies used in this sample, including the utilities for displaying various information as you progress, are imported in the cells where they are first needed.

In [None]:
# Install the AI Edge Torch converter into the Colab runtime.
# NOTE(review): the package is unpinned; consider pinning a version so the
# notebook stays reproducible as the converter API evolves.
!pip install ai-edge-torch

You will also need to download an image to verify model functionality.

In [None]:
# `import urllib` alone does not load the `urllib.request` submodule; import it
# explicitly instead of relying on another library having imported it already.
import urllib.request

# Sample image(s) used throughout this colab.
IMAGE_FILENAMES = ['astrid_happy_hike.jpg']

# Download each sample image into the current working directory.
for name in IMAGE_FILENAMES:
  url = f'https://storage.googleapis.com/ai-edge/models-samples/torch_converter/image_segmentation_dis/{name}'
  urllib.request.urlretrieve(url, name)

Optionally, you can upload your own image. If you want to do so, uncomment and run the cell below. Additionally, this will allow you to select multiple images to upload and test at each step in this colab.

In [None]:
# Optional: uncomment the lines below to upload your own image(s) instead of
# the sample. The cell stays commented out so that "Run all" does not stop at
# an interactive upload prompt.
# from google.colab import files
# uploaded = files.upload()
#
# for filename, content in uploaded.items():
#   with open(filename, 'wb') as f:
#     f.write(content)
#
# # Keep the sample image list if nothing was uploaded (e.g. the dialog was
# # cancelled); otherwise every later cell would have no images to process.
# if uploaded:
#   IMAGE_FILENAMES = list(uploaded.keys())
#
# print('Uploaded files:', IMAGE_FILENAMES)

Now go ahead and verify that the image(s) loaded successfully.

In [None]:
import cv2
from google.colab.patches import cv2_imshow
import math

DESIRED_HEIGHT = 480
DESIRED_WIDTH = 480

def resize_and_show(image):
  h, w = image.shape[:2]
  if h < w:
    img = cv2.resize(image, (DESIRED_WIDTH, math.floor(h/(w/DESIRED_WIDTH))))
  else:
    img = cv2.resize(image, (math.floor(w/(h/DESIRED_HEIGHT)), DESIRED_HEIGHT))
  cv2_imshow(img)


# Preview the images. The dict gets its own name because `images` is rebuilt
# later as a list during PyTorch inference. Reusing one name for both would
# make the notebook state depend on which cell ran last.
preview_images = {name: cv2.imread(name) for name in IMAGE_FILENAMES}

for name, image in preview_images.items():
  print(name)
  resize_and_show(image)

Finally, we've written a few utility functions to help with visualizing each step in this process, as well as one function that post-processes the raw output of the converted models into a displayable mask image. Go ahead and run this cell now so that they're available.

In [None]:
  def display_two_column_images(title_1, title_2, image_1, image_2):
    """Plots two images in one row and shows the figure.

    `image_1` uses matplotlib's default colormap and `image_2` is drawn in
    grayscale. Typically this is an original image next to its mask. Axes
    ticks are hidden on both panels.
    """
    _, axes = plt.subplots(1, 2, figsize=(7, 7))
    panels = [(image_1, title_1, None), (image_2, title_2, 'gray')]
    for axis, (picture, title, colormap) in zip(axes, panels):
      axis.imshow(picture, cmap=colormap)
      axis.set_title(title)
      axis.axis('off')
    plt.tight_layout()
    plt.show()

  def display_three_column_images(title_1, title_2, title_3, image_1, image_2, image_3):
    """Plots three images in one row and shows the figure.

    The first image uses matplotlib's default colormap; the second and third
    are drawn in grayscale. The notebook uses this to compare the original
    image, the PyTorch mask and a LiteRT mask. Axes ticks are hidden.
    """
    _, axes = plt.subplots(1, 3, figsize=(10, 10))
    panels = [
        (image_1, title_1, None),    # Original image.
        (image_2, title_2, 'gray'),  # PT segmentation mask.
        (image_3, title_3, 'gray'),  # TFL segmentation mask.
    ]
    for axis, (picture, title, colormap) in zip(axes, panels):
      axis.imshow(picture, cmap=colormap)
      axis.set_title(title)
      axis.axis('off')
    plt.tight_layout()
    plt.show()

  def get_processed_isnet_result(model_output, original_image_hw):
    """Turns a raw IS-Net mask output into a displayable grayscale PIL image.

    Args:
      model_output: NumPy array holding the raw mask. Singleton dimensions are
        squeezed away before building the image (presumably a (1, H, W, 1)
        output from the wrapped model; confirm against the converted model).
      original_image_hw: Target size handed straight to `PIL.Image.resize`,
        which expects (width, height), e.g. a `PIL.Image.size` value. Despite
        its name, it is not (height, width).

    Returns:
      An "L"-mode PIL image with values in [0, 255], resized to
      `original_image_hw` with bilinear resampling.
    """
    # Min-max normalization to [0, 1]. It is done here rather than inside the
    # converted model. A constant output would divide by zero and produce NaNs,
    # so it maps to an all-zero mask instead.
    output_min = model_output.min()
    output_range = model_output.max() - output_min
    if output_range > 0:
      result = (model_output - output_min) / output_range
    else:
      result = np.zeros_like(model_output, dtype=np.float32)

    # Scale [0, 1] -> [0, 255].
    result = (result * 255).astype(np.uint8)

    # Restore original image size.
    result = Image.fromarray(result.squeeze(), "L")
    return result.resize(original_image_hw, Image.Resampling.BILINEAR)

# PyTorch model validation

Now that you have your test images and utility functions, it's time to test the original PyTorch model that will be converted to the `tflite` format. You can start by retrieving the PyTorch model from Kaggle, along with the original project from GitHub that will be used for building the model.

In [None]:
# Start from /content and remove any previous clone (and Colab's sample_data
# folder) so this cell can be re-run cleanly.
%cd /content
!rm -rf DIS sample_data

# Fetch the IS-Net model definition from the original DIS repository. Later
# cells run from DIS/IS-Net/ and read the test images via '../../', i.e.
# from /content.
!git clone https://github.com/xuebinqin/DIS.git
%cd DIS/IS-Net/

# Download the pretrained weights archive from Kaggle and unpack it here.
# NOTE(review): presumably the archive contains isnet-general-use.pth, which
# the next cell loads; confirm.
!curl -o ./model.tar.gz -L https://www.kaggle.com/api/v1/models/paulruiz/dis/pyTorch/8-17-22/1/download
!tar -xvf 'model.tar.gz'

Next, build the IS-Net model and load the downloaded weights into it so that you can run it locally.

In [None]:
import torch
from models import ISNetDIS

pytorch_model_filename = 'isnet-general-use.pth'
pt_model = ISNetDIS()
# The checkpoint is consumed as a state dict, so restrict unpickling to
# tensors and primitive containers (`weights_only=True`). This avoids running
# arbitrary pickled code from a downloaded file. `map_location` loads the
# weights onto the CPU.
pt_model.load_state_dict(
    torch.load(
        pytorch_model_filename,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)
# Switch to inference mode; the trailing ';' suppresses the module repr.
pt_model.eval();

And to finish validating the original model, you can use it to run inference on the test image(s) that you loaded earlier. In this step you will save the generated PyTorch segmentation mask images so they can be compared to your LiteRT segmentation mask images later in this colab.

In [None]:
from io import BytesIO
import numpy as np
from skimage import io

import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import normalize

from matplotlib import pyplot as plt

# (height, width) expected by the IS-Net model.
MODEL_INPUT_HW = (1024, 1024)
pt_result = []
images = []

# Inference only: `torch.no_grad()` stops autograd from recording the large
# forward pass, which saves memory and makes detaching the output unnecessary.
with torch.no_grad():
  for index, filename in enumerate(IMAGE_FILENAMES):
    image = io.imread('../../' + filename)
    # The model takes 3-channel input: replicate grayscale images and drop any
    # alpha channel.
    if image.ndim == 2:
      image = np.stack([image] * 3, axis=-1)
    images.append(image[:, :, :3])

    # HWC -> CHW.
    image_tensor = torch.tensor(images[index], dtype=torch.float32).permute(2, 0, 1)

    # Add a batch dimension (CHW -> BCHW) and resize to the model input size.
    # The uint8 cast (kept from the original pipeline) truncates the
    # interpolated values.
    image_tensor = F.interpolate(torch.unsqueeze(image_tensor, 0), MODEL_INPUT_HW,
                                 mode='bilinear', align_corners=False).type(torch.uint8)

    # Scale [0, 255] -> [0, 1].
    pt_image = torch.divide(image_tensor, 255.0)

    # Normalize.
    pt_image = normalize(pt_image, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])

    # Keep the first output of the first output group ([0][0]), which this
    # demo treats as the most accurate prediction.
    mask = pt_model(pt_image)[0][0]

    # Recover the prediction's spatial size to the original image size, then
    # drop the batch dimension (BCHW -> CHW).
    mask = F.interpolate(mask, images[index].shape[:2], mode='bilinear',
                         align_corners=False)
    mask = torch.squeeze(mask, 0)

    # Min-max normalization. A constant prediction maps to an all-zero mask
    # instead of dividing by zero.
    mask_min = torch.min(mask)
    mask_range = torch.max(mask) - mask_min
    if mask_range > 0:
      mask = (mask - mask_min) / mask_range
    else:
      mask = torch.zeros_like(mask)

    # Scale [0, 1] -> [0, 255] and reorder CHW -> HWC.
    mask = (mask * 255).permute(1, 2, 0)

    # Store the mask as a uint8 NumPy array for later comparison.
    pt_result.append(mask.cpu().numpy().astype(np.uint8))

    display_two_column_images('Original Image', 'Mask', images[index], pt_result[index])

# Convert to LiteRT

## Add model wrapper

The original IS-Net model generates 12 outputs, each corresponding to different stages in the segmentation process. While the official PyTorch model demo provides guidance on selecting the final (best) output, obtaining the desired output from the converted LiteRT model requires additional effort.

One of the methods you can use to get to this final output is to download the `tflite` file after the conversion step in this colab, open it with [Model Explorer](https://ai.google.dev/edge/model-explorer) and confirm which output in the graph has the expected output shape.

That's kind of a lot for this example, so to simplify the process and eliminate this effort, you can use a wrapper for the PyTorch model that narrows the scope to only the final output. This approach ensures that your new LiteRT model has only a single output after the conversion stage.

Additionally, this colab moves the image pre-processing steps into the wrapper. Min-max normalization is left out of the wrapped model and applied afterward to the model output, because `torch.min` and `torch.max` are not currently supported in the conversion process.

You can create the wrapper by running the following cell:

In [None]:
class ImageSegmentationModelWrapper(nn.Module):
  """Wraps IS-Net so the converted model has a single input and output.

  The PyTorch pre-processing (rescale to [0, 1], then normalize) is moved into
  the graph. The converted model therefore consumes raw pixel values laid out
  as (batch, height, width, channels). Only the first output of the first
  output group, `model(...)[0][0]`, is returned, in the same channels-last
  layout. Min-max normalization of the result is left to the caller.
  """

  # Divisor mapping pixel values from [0, 255] to [0, 1].
  RESCALING_FACTOR = 255.0
  # Normalization applied after rescaling: (x - MEAN) / STD.
  MEAN = 0.5
  STD = 1.0

  def __init__(self, pt_model):
    """Stores the IS-Net `nn.Module` to wrap as `self.model`."""
    super().__init__()
    self.model = pt_model

  def forward(self, image: torch.Tensor) -> torch.Tensor:
    """Pre-processes `image`, runs the wrapped model and returns one output.

    Args:
      image: Float tensor of raw pixel values in BHWC layout.

    Returns:
      The wrapped model's `[0][0]` output, permuted from BCHW to BHWC.
    """
    # BHWC -> BCHW.
    image = image.permute(0, 3, 1, 2)

    # Rescale [0, 255] -> [0, 1].
    image = image / self.RESCALING_FACTOR

    # Normalize.
    image = (image - self.MEAN) / self.STD

    # Keep only the first output of the first output group.
    result = self.model(image)[0][0]

    # BCHW -> BHWC.
    result = result.permute(0, 2, 3, 1)

    return result


# Wrap the loaded IS-Net model and put the wrapper in eval mode, matching how
# `pt_model` itself was prepared above.
wrapped_pt_model = ImageSegmentationModelWrapper(pt_model).eval()

## Convert to LiteRT

Provide sample arguments (the resulting LiteRT model will expect input of this size) and convert the model.

Now it's time to perform the conversion! You will need to provide a couple of arguments:
- the wrapper that you created in the last step;
- sample inputs with the expected input shape: batch size 1, model input height, model input width, and 3 for the RGB channels of an image.

In [None]:
import ai_edge_torch

# A single random BHWC tensor fixes the converted model's input shape:
# (1, height, width, 3). Only its shape matters, not its values.
sample_args = (torch.rand(1, MODEL_INPUT_HW[0], MODEL_INPUT_HW[1], 3),)
edge_model = ai_edge_torch.convert(
    wrapped_pt_model,
    sample_args,
)

# Validate converted model with LiteRT Interpreter

Now that you have a converted model stored in colab, it's time to test it. You can start by preparing the test image(s) that you loaded earlier. Since all of the preprocessing steps were built into the model wrapper earlier, you only need to resize the input image(s) and cast them to `float32` in this step. At the end of this stage you should see the original image, the PyTorch mask graphic, and the LiteRT mask graphic for your test input.

In [None]:
from PIL import Image

np_images = []
image_sizes = []
for index, filename in enumerate(IMAGE_FILENAMES):
  # Retrieve each image from the file system. The context manager closes the
  # file handle, and convert('RGB') guarantees the 3 channels the model expects
  # (grayscale is expanded, alpha is dropped).
  with Image.open('../../' + filename) as source_image:
    image = source_image.convert('RGB')
  # Track each image's size, (width, height) in PIL, to resize the mask later.
  image_sizes.append(image.size)
  # PIL's resize takes (width, height), so reverse the (height, width) constant.
  resized = image.resize(MODEL_INPUT_HW[::-1], Image.Resampling.BILINEAR)
  # Add the batch dimension, (H, W, 3) -> (1, H, W, 3), as float32.
  np_images.append(np.expand_dims(np.array(resized), axis=0).astype(np.float32))

  # Retrieve an output from the converted model
  edge_model_output = edge_model(np_images[index])

  # Use the visualization utility created earlier to get a displayable image
  lrt_result = get_processed_isnet_result(edge_model_output, image_sizes[index])

  display_three_column_images('Original Image', 'PT Mask', 'TFL Mask', images[index], pt_result[index], lrt_result)

# Post Training and Dynamic-Range Quantization with LiteRT

At this point you should have a working `tflite` model that you have converted from the original PyTorch format. Congratulations! But if you're working with edge devices, then you likely know that model size is an **important** consideration for things like mobile devices. Using a technique called *quantization*, you can reduce a model's size to roughly a quarter of the original size while maintaining a similar level of output quality. To do this with the Google AI Edge PyTorch Converter, you can pass in an optimization flag to the `convert` function to include a step for dynamic-range quantization.

If you'd like to know more about quantization and other optimizations, you can find our official documentation [here](https://www.tensorflow.org/lite/performance/post_training_quantization).

In [None]:
import tensorflow as tf


# Request the TFLite converter's default optimizations, which enable
# dynamic-range quantization during conversion.
tfl_converter_flags = dict(optimizations=[tf.lite.Optimize.DEFAULT])
tfl_drq_model = ai_edge_torch.convert(
    wrapped_pt_model, sample_args, _ai_edge_converter_flags=tfl_converter_flags)

After the conversion has finished, you can compare the newly converted and quantized model with the original image and PyTorch mask image from earlier.

In [None]:
# Run the dynamic-range-quantized model on every preprocessed test image.
# Show its mask next to the original image and the PyTorch mask.
for index in range(len(IMAGE_FILENAMES)):
  tfl_drq_model_output = tfl_drq_model(np_images[index])
  tfl_drq_result = get_processed_isnet_result(
      tfl_drq_model_output, image_sizes[index])
  display_three_column_images(
      'Original Image', 'PT Mask', 'TFLQ Mask',
      images[index], pt_result[index], tfl_drq_result)

# Post Training and Dynamic-Range Quantization with PT2E

Another available option for dynamic-range quantization is called PT2E, which is a framework-level quantization feature available in PyTorch 2.0. For more details see [PyTorch tutorial](https://pytorch.org/tutorials/prototype/quantization_in_pytorch_2_0_export_tutorial.html).

PT2EQuantizer is developed specifically for the AI Edge Torch framework and is configured to quantize models leveraging various operators and kernels offered by the LiteRT runtime.

You can see how to configure the PT2EQuantizer and use it as an additional parameter in the `convert` function below.

In [None]:
from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.quant_config import QuantConfig

from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch._export import capture_pre_autograd_graph


# Configure the AI Edge Torch PT2E quantizer globally with a symmetric,
# per-channel, dynamic quantization config.
pt2e_quantizer = PT2EQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)

# Following are the required steps recommended in the PT2E quantization
# workflow.
# 0. Capture the wrapped model as a graph, traced with the same sample args
# used for conversion.
# NOTE(review): `torch._export.capture_pre_autograd_graph` is deprecated in
# newer PyTorch releases in favor of
# `torch.export.export_for_training(model, args).module()`; confirm against the
# installed torch version.
autograd_torch_model = capture_pre_autograd_graph(wrapped_pt_model, sample_args)
# 1. Prepare for quantization.
pt2e_torch_model = prepare_pt2e(autograd_torch_model, pt2e_quantizer)
# 2. Run the prepared model with sample input data to ensure that internal
# observers are populated with correct values.
pt2e_torch_model(*sample_args)
# 3. Finally, convert (quantize) the prepared model.
pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)

# Convert the quantized graph to LiteRT, handing the same quantizer to the
# converter through QuantConfig.
pt2e_drq_model = ai_edge_torch.convert(
    pt2e_torch_model,
    sample_args,
    quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer)
)

Once the model has been converted again using the PT2E Quantizer, it's time to review the results so you can compare them to both the original image and the PyTorch inferred mask.

In [None]:
# Run the PT2E-quantized model on every preprocessed test image. Show its mask
# next to the original image and the PyTorch mask.
for index in range(len(IMAGE_FILENAMES)):
  pt2e_drq_output = pt2e_drq_model(np_images[index])
  pt2e_drq_result = get_processed_isnet_result(
      pt2e_drq_output, image_sizes[index])
  display_three_column_images(
      'Original Image', 'PT Mask', 'PT2E DRQ Mask',
      images[index], pt_result[index], pt2e_drq_result)

# Download converted models

Now that you've converted and optimized the DIS model for LiteRT, it's time to save those models. The following cells are set up to download three models: the newly converted `tflite` model without optimizations, the converted model using dynamic range quantization, and the model that uses PT2E quantization. When you've finished downloading these files, check out their finished sizes! You'll notice that the original converted model is about 175MB in size, whereas the quantized models are about 45MB - much more manageable for edge devices!

In [None]:
from google.colab import files

# Write the unquantized LiteRT model to a .tflite file in the current working
# directory, then download it through the Colab browser.
tfl_filename = "isnet.tflite"
edge_model.export(tfl_filename)

files.download(tfl_filename)

In [None]:
# Write the TFLite dynamic-range-quantized model to disk and download it.
tfl_drq_filename = 'isnet_tfl_drq.tflite'
tfl_drq_model.export(tfl_drq_filename)

files.download(tfl_drq_filename)

In [None]:
# Write the PT2E dynamic-range-quantized model to disk and download it.
pt2e_drq_filename = 'isnet_pt2e_drq.tflite'
pt2e_drq_model.export(pt2e_drq_filename)

files.download(pt2e_drq_filename)

# Next steps

Now that you have learned how to convert this segmentation model from the PyTorch to tflite format, it's time to do more with it! You can go over additional LiteRT API samples for multiple platforms, including Android, iOS, and Python, as well as learn more about on-device machine learning inference from the [Google AI Edge official documentation](https://ai.google.dev/edge/). You can find samples to run the output from this colab on Android [here](https://github.com/google-ai-edge/litert-samples/tree/main/examples/image_segmentation_DIS/android) and on iOS [here](https://github.com/google-ai-edge/litert-samples/tree/main/examples/image_segmentation_DIS/ios).