# Image Segmentation using the Google AI Edge LiteRT API

## Preparation

Let's start by installing LiteRT and importing TensorFlow, LiteRT, and NumPy. You will download an off-the-shelf model in the next section. Check out [Kaggle Models](https://www.kaggle.com/models/tensorflow/deeplabv3) for more information about the DeepLab V3 model that you will be using in this tutorial.


In [None]:
!pip install ai-edge-litert-nightly

In [None]:
import tensorflow as tf
from ai_edge_litert.interpreter import Interpreter
import numpy as np

## Download the image segmenter model

Next, download the image segmentation model used in this demo. In this case you will use the DeepLab V3 model.

In [None]:
#@title Start downloading here.
import pathlib
import kagglehub

path = kagglehub.model_download("tensorflow/deeplabv3/tfLite/default")
print("Path to model files:", path)

MODEL_PATH = str(next(pathlib.Path(path).rglob('*.tflite')))
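
The `next(...)` call above raises `StopIteration` if the downloaded folder contains no `.tflite` file. Here is a minimal sketch of a more defensive lookup. `find_tflite` is a hypothetical helper name used only for illustration; it is not part of `kagglehub` or LiteRT. It also sorts the matches so the choice is deterministic, which the original cell does not do.

```python
import pathlib


def find_tflite(root):
  """Returns the first .tflite file under root (hypothetical helper)."""
  # Sort the matches so the same file is picked on every run.
  matches = sorted(pathlib.Path(root).rglob('*.tflite'))
  if not matches:
    raise FileNotFoundError(f'No .tflite file found under {root}')
  return str(matches[0])
```

With this helper, the assignment above could be written as `MODEL_PATH = find_tflite(path)`.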

Optionally, you can upload your own model (.tflite). If you want to do so, uncomment and run the cell below.


In [None]:
# from google.colab import files
# uploaded = files.upload()

# for filename in uploaded:
#   content = uploaded[filename]
#   with open(filename, 'wb') as f:
#     f.write(content)

# MODEL_PATH = list(uploaded.keys())[0]

# print('Uploaded model:', MODEL_PATH)

In [None]:
#@markdown We implemented some functions to visualize the image segmentation results. <br/> Run the following cell to activate the functions.
# The visualization utilities here are mostly taken from the DeepLabV3 Demo Colab notebook
# https://colab.research.google.com/github/tensorflow/models/blob/master/research/deeplab/deeplab_demo.ipynb

from matplotlib import gridspec
from matplotlib import pyplot as plt

def create_pascal_label_colormap():
  """Creates a label colormap used in PASCAL VOC segmentation benchmark.

  Returns:
    A Colormap for visualizing segmentation results.
  """
  colormap = np.zeros((256, 3), dtype=int)
  ind = np.arange(256, dtype=int)

  for shift in reversed(range(8)):
    for channel in range(3):
      colormap[:, channel] |= ((ind >> channel) & 1) << shift
    ind >>= 3

  return colormap


def label_to_color_image(label):
  """Adds color defined by the dataset colormap to the label.

  Args:
    label: A 2D array with integer type, storing the segmentation label.

  Returns:
    result: A 3D array of integer type with shape (height, width, 3). Each
      element is the color indexed by the corresponding element in the input
      label to the PASCAL color map.

  Raises:
    ValueError: If label is not of rank 2 or its value is larger than color
      map maximum entry.
  """
  if label.ndim != 2:
    raise ValueError('Expect 2-D input label')

  colormap = create_pascal_label_colormap()

  if np.max(label) >= len(colormap):
    raise ValueError('label value too large.')

  return colormap[label]


def visualize_segmentation(image, seg_map):
  """Visualizes input image, segmentation map and overlay view."""
  plt.figure(figsize=(15, 5))
  grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

  plt.subplot(grid_spec[0])
  plt.imshow(image)
  plt.axis('off')
  plt.title('Input Image')

  plt.subplot(grid_spec[1])
  seg_image = label_to_color_image(seg_map).astype(np.uint8)
  plt.imshow(seg_image)
  plt.axis('off')
  plt.title('Segmentation Map')

  plt.subplot(grid_spec[2])
  plt.imshow(image)
  plt.imshow(seg_image, alpha=0.7)
  plt.axis('off')
  plt.title('Segmentation Overlay')

  unique_labels = np.unique(seg_map)
  ax = plt.subplot(grid_spec[3])
  plt.imshow(
      FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
  ax.yaxis.tick_right()
  plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
  plt.xticks([], [])
  ax.tick_params(width=0.0)
  plt.grid(False)
  plt.show()

# Labels for the PASCAL VOC dataset
LABEL_NAMES = np.asarray([
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv'
])

FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
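
To see what the colormap produces, here is a minimal sketch that applies the same bit interleaving as `create_pascal_label_colormap` to a single label in plain Python. `pascal_color` is an illustrative helper name, not part of the notebook's utilities.

```python
def pascal_color(label):
  """Returns the (R, G, B) color for one label (illustrative helper)."""
  color = [0, 0, 0]
  for shift in reversed(range(8)):
    for channel in range(3):
      # Bit `channel` of the label becomes bit `shift` of that color channel.
      color[channel] |= ((label >> channel) & 1) << shift
    # Move on to the next group of three label bits.
    label >>= 3
  return tuple(color)


print(pascal_color(0))   # (0, 0, 0): background
print(pascal_color(1))   # (128, 0, 0): aeroplane
print(pascal_color(15))  # (192, 128, 128): person
```

Each label's three lowest bits set the highest bit of R, G, and B, the next three bits set the second-highest bit, and so on, which keeps neighboring labels visually distinct.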

## Download a test image

After downloading the model, it's time to grab an image that you can use for testing! Although this tutorial works with a single image, you can add more filenames to the `IMAGE_FILENAMES` list to process a collection of images.

In [None]:
import urllib.request

IMAGE_FILENAMES = ['segmentation_input_rotation0.jpg']

for name in IMAGE_FILENAMES:
  url = f'https://storage.googleapis.com/mediapipe-assets/{name}'
  urllib.request.urlretrieve(url, name)

Optionally, you can upload your own image. If you want to do so, uncomment and run the cell below.


In [None]:
# from google.colab import files
# uploaded = files.upload()

# for filename in uploaded:
#   content = uploaded[filename]
#   with open(filename, 'wb') as f:
#     f.write(content)
# IMAGE_FILENAMES = list(uploaded.keys())

# print('Uploaded files:', IMAGE_FILENAMES)

## Preview the downloaded image

With the test image downloaded, go ahead and display it.

In [None]:
import cv2
from google.colab.patches import cv2_imshow
import math

# Target height and width for the preview (the model input size is read from the model later)
DESIRED_HEIGHT = 480
DESIRED_WIDTH = 480

# Resizes the image while preserving its aspect ratio, then displays it
def resize_and_show(image):
  h, w = image.shape[:2]
  if h < w:
    img = cv2.resize(image, (DESIRED_WIDTH, math.floor(h/(w/DESIRED_WIDTH))))
  else:
    img = cv2.resize(image, (math.floor(w/(h/DESIRED_HEIGHT)), DESIRED_HEIGHT))
  cv2_imshow(img)


# Preview the image(s)
images = {name: cv2.imread(name) for name in IMAGE_FILENAMES}
for name, image in images.items():
  print(name)
  resize_and_show(image)
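
The size arithmetic in `resize_and_show` can be checked in isolation. Below is a minimal sketch under one assumption: it uses integer division instead of `math.floor` on a floating-point quotient, which gives the same floor for positive integer sizes while avoiding rounding surprises. `preview_size` is a hypothetical helper name.

```python
def preview_size(height, width, desired_height=480, desired_width=480):
  """Returns the (width, height) tuple in the order cv2.resize expects."""
  if height < width:
    # Landscape: pin the width and scale the height down proportionally.
    return desired_width, height * desired_width // width
  # Portrait or square: pin the height and scale the width proportionally.
  return width * desired_height // height, desired_height


print(preview_size(600, 800))  # (480, 360)
print(preview_size(800, 600))  # (360, 480)
```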

## Running inference and visualizing the results
To run inference using the Interpreter API, you need to load the model, read the model's input details to get the expected input size, and then perform image segmentation by running the input image through the model. This example will separate the background and foreground of the image and apply separate colors for them to highlight where each distinctive area exists.

Check out the [Interpreter documentation](https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_python) to learn more about configuration options for the Interpreter API.

Note: If you use a custom model, your input data and output handling must match that model's input and output tensor shapes.
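
One way to catch a shape mismatch early is to compare the array against the details reported by the interpreter before calling `set_tensor`. The sketch below assumes a hypothetical helper named `check_input` and uses a stand-in dictionary in place of a real entry from `interpreter.get_input_details()`. The 257x257 size in the stand-in is an assumed value for illustration, not one read from the model.

```python
import numpy as np


def check_input(array, input_detail):
  """Raises ValueError if array does not match the expected shape or dtype.

  input_detail is one entry of interpreter.get_input_details().
  """
  expected_shape = tuple(int(d) for d in input_detail['shape'])
  expected_dtype = np.dtype(input_detail['dtype'])
  if array.shape != expected_shape:
    raise ValueError(f'Expected shape {expected_shape}, got {array.shape}')
  if array.dtype != expected_dtype:
    raise ValueError(f'Expected dtype {expected_dtype}, got {array.dtype}')


# Stand-in for one entry of get_input_details() (assumed values).
fake_detail = {'shape': np.array([1, 257, 257, 3]), 'dtype': np.float32}
check_input(np.zeros((1, 257, 257, 3), dtype=np.float32), fake_detail)
```

In the inference cell, this could be called as `check_input(image_np, input_details[0])` just before `interpreter.set_tensor(...)`.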

In [None]:
# STEP 1: Import the necessary modules.
from PIL import Image
from PIL import ImageOps


# STEP 2: Load the TFLite model in LiteRT Interpreter and allocate tensors.
interpreter = Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()

# STEP 3: Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# STEP 4: Get image size - converting from BHWC to WH.
input_size = input_details[0]['shape'][2], input_details[0]['shape'][1]

# Loop through demo image(s)
for image_file_name in IMAGE_FILENAMES:
  # STEP 5: Load the input image.
  image = Image.open(image_file_name)

  # STEP 6: Scale the image to fit within the model input size while keeping the aspect ratio.
  cropped_image = ImageOps.contain(image, input_size)

  # STEP 7: Resize the image to the exact model input size.
  resized_image = cropped_image.convert('RGB').resize(input_size, Image.BILINEAR)

  # STEP 8: Convert to a NumPy array, add a batch dimension, and normalize the image to [-1, 1].
  image_np = np.asarray(resized_image).astype(np.float32)
  image_np = np.expand_dims(image_np, 0)
  image_np = image_np / 127.5 - 1

  # STEP 9: Set the input tensor and perform segmentation on the input image.
  interpreter.set_tensor(input_details[0]['index'], image_np)
  interpreter.invoke()
  output_tensor = interpreter.get_tensor(output_details[0]['index'])

  # STEP 10: Process the segmentation result. In this case, we visualize it.
  width, height = cropped_image.size
  segmentation_map = tf.argmax(tf.image.resize(output_tensor, (height, width)), axis=3)
  segmentation_map = tf.squeeze(segmentation_map).numpy().astype(np.int8)
  visualize_segmentation(cropped_image, segmentation_map)
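
The last step reduces the per-class scores to a single label per pixel by taking the argmax over the class axis. Here is a minimal sketch of that reduction using NumPy on a toy tensor. The scores and the 2x2 size are made up for illustration, and the sketch leaves out the resize that the cell above performs with `tf.image.resize`.

```python
import numpy as np

# Toy scores with shape (batch=1, height=2, width=2, classes=3).
scores = np.array([[[[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]],
                    [[0.2, 0.2, 0.6], [0.3, 0.4, 0.3]]]], dtype=np.float32)

# Pick the highest-scoring class per pixel, then drop the batch dimension.
label_map = np.argmax(scores, axis=3)[0].astype(np.uint8)
print(label_map)
# [[1 0]
#  [2 1]]
```

The resulting 2D array has the same form as the `segmentation_map` passed to `visualize_segmentation`.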