In [None]:
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

This demo will teach you how to convert a PyTorch [MobileViT](https://huggingface.co/docs/transformers/en/model_doc/mobilevit#overview) model to a TensorFlow Lite model intended to run with [MediaPipe](https://developers.google.com/mediapipe/solutions) Tasks using Google's AI Edge Torch library. You will then run the newly converted `tflite` model locally using the MediaPipe Tasks on-device inference tool, as well as learn where to find other tools for running your newly converted model on other edge hardware, including mobile devices and web browsers.

# Prerequisites

You can start by installing the necessary dependencies for converting the model, as well as some additional utilities for displaying various information as you progress through this sample.

In [None]:
!pip install -r https://github.com/google-ai-edge/ai-edge-torch/releases/download/v0.1.1/requirements.txt
!pip install mediapipe
!pip install ai-edge-torch==0.1.1

You will also need to download an image to verify model functionality.

In [None]:
# `import urllib` alone does not guarantee that the `urllib.request` submodule
# is loaded; import it explicitly so this cell also works on a fresh kernel.
import urllib.request

IMAGE_FILENAMES = ['cat.jpg']

# Download each sample image into the current working directory under its own name.
for name in IMAGE_FILENAMES:
  url = f'https://storage.googleapis.com/ai-edge/models-samples/torch_converter/image_classification_mobile_vit/{name}'
  urllib.request.urlretrieve(url, name)

Optionally, you can upload your own image. If you want to do so, uncomment and run the cell below.

In [None]:
# from google.colab import files
# uploaded = files.upload()

# for filename in uploaded:
#   content = uploaded[filename]
#   with open(filename, 'wb') as f:
#     f.write(content)
# IMAGE_FILENAMES = list(uploaded.keys())

# print('Uploaded files:', IMAGE_FILENAMES)

Now go ahead and verify that the image was loaded successfully.

In [None]:
import cv2
from google.colab.patches import cv2_imshow
import math

# Target size (in pixels) of the longer side of each on-screen preview.
DESIRED_HEIGHT = 480
DESIRED_WIDTH = 480

def resize_and_show(image):
  """Shows `image` scaled to the preview size, keeping its aspect ratio.

  Landscape images (wider than tall) are scaled to DESIRED_WIDTH pixels wide;
  portrait and square images are scaled to DESIRED_HEIGHT pixels tall. The
  other side is scaled proportionally and rounded down.

  Args:
    image: image array as returned by `cv2.imread` (`.shape` starts with
      height, then width).
  """
  height, width = image.shape[:2]
  if height >= width:
    new_size = (math.floor(width / (height / DESIRED_HEIGHT)), DESIRED_HEIGHT)
  else:
    new_size = (DESIRED_WIDTH, math.floor(height / (width / DESIRED_WIDTH)))
  # cv2.resize takes the target size as (width, height).
  cv2_imshow(cv2.resize(image, new_size))


# Preview the images.
images = {name: cv2.imread(name) for name in IMAGE_FILENAMES}

for name, image in images.items():
  # cv2.imread returns None (instead of raising) when a file is missing or
  # cannot be decoded; fail with a clear message rather than an
  # AttributeError deep inside resize_and_show.
  if image is None:
    raise FileNotFoundError(f'Could not read image file: {name}')
  print(name)
  resize_and_show(image)

# PyTorch model validation

Now that you have your test images, it's time to validate the PyTorch model (in this case MobileViT) that will be converted to the TensorFlow Lite format.

Start by retrieving the PyTorch model and the appropriate corresponding processor.

In [None]:
from transformers import MobileViTImageProcessor, MobileViTForImageClassification

hf_model_path = 'apple/mobilevit-small'
# Load the pretrained classifier and the image processor that goes with it;
# both come from the same Hugging Face checkpoint.
pt_model = MobileViTForImageClassification.from_pretrained(hf_model_path)
processor = MobileViTImageProcessor.from_pretrained(hf_model_path)

The MobileViTImageProcessor used below will perform multiple steps on the input image to match the requirements of the MobileViT model:

1. Convert the image from RGB to BGR.
2. Rescale the image from the [0, 255] range to the [0, 1] range.
3. Resize the input image to 256x256 pixels. This differs from the default behavior of the processor (which includes padding and center cropping) and makes it easier to validate the converted model with MediaPipe Tasks (more details in the corresponding section).

In [None]:
from PIL import Image

images = []
for filename in IMAGE_FILENAMES:
  # Decode eagerly and force 3-channel RGB: the model expects three channels,
  # but uploaded images may be RGBA, palette or grayscale. The `with` block
  # closes the file handle that `Image.open` would otherwise leave open;
  # `convert` returns an in-memory copy, so it stays valid after closing.
  with Image.open(filename) as pil_image:
    images.append(pil_image.convert('RGB'))

# Plain 256x256 resize without center cropping, so the preprocessing can be
# reproduced exactly when validating with MediaPipe Tasks later on.
inputs = processor(
    images=images,
    return_tensors='pt',
    size={'height': 256, 'width': 256},
    do_center_crop=False,
)

Now that you have your test data ready and the inputs processed, it's time to validate the classifications. In this step you will loop through your test image(s) and display the top 5 predicted classification categories. This model was trained with ImageNet-1000, so there are 1000 different potential classifications that may be applied to your test data.

In [None]:
import torch
from torch import nn

# Run the whole batch through the model once (rather than once per image) and
# without autograd bookkeeping, since this is inference only.
with torch.no_grad():
  outputs = pt_model(**inputs)
logits = outputs.logits  # One row of class scores per input image.

for image_index, filename in enumerate(IMAGE_FILENAMES):
  probs, indices = nn.functional.softmax(logits[image_index], dim=-1).topk(k=5)

  print(filename, 'predictions: ')
  for prob, class_index in zip(probs, indices):
    class_label = pt_model.config.id2label[class_index.item()]
    print(f'{(prob.item() * 100):4.1f}%  {class_label}')
  print('\n')

# Convert to TFLite

Before converting the PyTorch model to TFLite, you will need to take an extra step to match it to the format expected by MediaPipe (MP) Tasks. Here are the necessary adjustments:

1. MediaPipe Tasks require channel-last images (BHWC) while PyTorch uses channel-first (BCHW).

2. For the Image Classification Task, MediaPipe requires an additional softmax layer applied to the classification logits, so that the model outputs per-class probabilities.

You can also include preprocessing steps into a wrapper, such as converting from RGB to BGR and scaling, similar to what you did when validating the PyTorch model in the previous section.

In [None]:
class HF2MP_ImageClassificationModelWrapper(nn.Module):
  """Adapts a Hugging Face image classifier to MediaPipe Tasks' I/O format.

  MediaPipe feeds channel-last (BHWC) RGB pixels and expects per-class scores,
  while the Hugging Face model expects channel-first pixel values that have
  already been preprocessed by its image processor. This wrapper reproduces
  the processor's rescaling and (optional) RGB -> BGR flip inside the graph,
  then applies a softmax to the logits.

  Args:
    hf_image_classification_model: model whose forward accepts
      `pixel_values=` and returns an object with a `.logits` attribute.
    hf_processor: the matching Hugging Face image processor. Only its
      `do_rescale`, `rescale_factor` and (if present) `do_flip_channel_order`
      settings are read.
  """

  def __init__(self, hf_image_classification_model, hf_processor):
    super().__init__()
    self.model = hf_image_classification_model
    if hf_processor.do_rescale:
      self.rescale_factor = hf_processor.rescale_factor
    else:
      self.rescale_factor = 1.0
    # Follow the processor's channel-order setting instead of always flipping.
    # Processors that do not expose the option keep the original behavior
    # (always flip).
    self.flip_channel_order = bool(
        getattr(hf_processor, 'do_flip_channel_order', True))

  def forward(self, image: torch.Tensor) -> torch.Tensor:
    """Maps a BHWC RGB image batch to per-class softmax scores.

    Args:
      image: float tensor laid out as (batch, height, width, channels).

    Returns:
      Tensor of softmax probabilities over the model's classes, one row per
      image.
    """
    # BHWC -> BCHW.
    image = image.permute(0, 3, 1, 2)
    if self.flip_channel_order:
      # RGB -> BGR.
      image = image.flip(dims=(1,))
    # Scale pixel values, e.g. [0, 255] -> [0, 1] when rescale_factor is 1/255.
    image = image * self.rescale_factor
    logits = self.model(pixel_values=image).logits
    # MediaPipe's classifier consumes class probabilities, so normalize the
    # logits with a softmax over the class dimension.
    return torch.nn.functional.softmax(logits, dim=-1)


# Load the checkpoint for conversion and wrap it so that its inputs and
# outputs follow the MediaPipe Tasks conventions; `.eval()` switches the
# wrapper (and the wrapped model) to evaluation mode before tracing.
hf_model_path = 'apple/mobilevit-small'
hf_mobile_vit_processor = MobileViTImageProcessor.from_pretrained(hf_model_path)
hf_mobile_vit_model = MobileViTForImageClassification.from_pretrained(hf_model_path)
wrapped_pt_model = HF2MP_ImageClassificationModelWrapper(
    hf_image_classification_model=hf_mobile_vit_model,
    hf_processor=hf_mobile_vit_processor,
).eval()

## Convert to TFLite

Now it's time to perform the conversion! You will need to provide sample arguments that define the expected input shape (in this case, a batch of one 256x256 image with three color channels, in channel-last layout).

In [None]:
import ai_edge_torch

# One random channel-last (BHWC) image. Its values don't matter for
# conversion; its shape defines the converted model's expected input.
sample_args = (torch.rand(1, 256, 256, 3),)
edge_model = ai_edge_torch.convert(wrapped_pt_model, sample_args)

Once the conversion is finished and you have a new `tflite` model file, you will convert the raw tflite file into a *model buffer* so that you can do some additional processing to prepare the file for working with MediaPipe Tasks. This includes attaching the labels for the model to the new `tflite` model so that it can be used with MediaPipe Tasks Image Classification.

In [None]:
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.vision.image_classifier import ImageClassifier
from pathlib import Path

flatbuffer_file = Path('hf_mobile_vit_mp_image_classification_raw.tflite')
edge_model.export(flatbuffer_file)
tflite_model_buffer = flatbuffer_file.read_bytes()

# The label list is positional (entry i labels output score i). Build it by
# class index rather than trusting the insertion order of `id2label`; a
# KeyError here means the class ids are not the contiguous range 0..N-1.
id2label = hf_mobile_vit_model.config.id2label
labels = [id2label[class_index] for class_index in range(len(id2label))]

writer = image_classifier.MetadataWriter.create(
    tflite_model_buffer,
    # The wrapped model rescales raw pixels itself, so the MediaPipe-side
    # normalization is the identity (mean 0, std 1).
    input_norm_mean=[0.0],
    input_norm_std=[1.0],
    labels=metadata_writer.Labels().add(labels),
)
# populate() returns the model buffer with metadata attached, plus a second
# value that is not needed here.
tflite_model_buffer, _ = writer.populate()

After attaching the metadata to the intermediate model buffer object, you can convert the buffer back into a `tflite` file.

In [None]:
tflite_filename = 'hf_mobile_vit_mp_image_classification.tflite'
# Write the metadata-populated model to the current working directory (Colab's
# local file system) so later cells can load and download it.
with open(tflite_filename, mode='wb') as tflite_file:
  tflite_file.write(tflite_model_buffer)

Before moving on to *using* the converted model, it's always a good idea to make sure the model was successfully saved.

In [None]:
!ls -l /content/hf_mobile_vit_mp_image_classification.tflite

# Validate converted model with MediaPipe Tasks

Now it's time to test your newly converted model directly with the MediaPipe Image Classification Task. Before getting into that code, you can add the following utility functions to improve the output displayed.

In [None]:
from matplotlib import pyplot as plt
# Globally hide every axis spine, tick mark and tick label, so the display
# helpers below draw bare images with only their titles.
plt.rcParams.update(dict.fromkeys(
    (
        'axes.spines.top',
        'axes.spines.right',
        'axes.spines.left',
        'axes.spines.bottom',
        'xtick.labelbottom',
        'xtick.bottom',
        'ytick.labelleft',
        'ytick.left',
        'xtick.labeltop',
        'xtick.top',
        'ytick.labelright',
        'ytick.right',
    ),
    False,
))


def display_one_image(image, title, subplot, titlesize=16):
    """Draws `image` into one subplot slot, optionally titled.

    Args:
      image: array-like accepted by `plt.imshow`.
      title: title text (e.g. the predicted category and score); no title is
        drawn when it is empty.
      subplot: (rows, cols, index) triple passed to `plt.subplot`.
      titlesize: title font size; the title padding is derived from it too.

    Returns:
      The (rows, cols, index) triple for the next slot (index advanced by one).
    """
    plt.subplot(*subplot)
    plt.imshow(image)
    if len(title) > 0:
        font_size = int(titlesize)
        title_pad = int(titlesize / 1.5)
        plt.title(
            title,
            fontsize=font_size,
            color='black',
            fontdict={'verticalalignment': 'center'},
            pad=title_pad,
        )
    rows, cols, index = subplot[0], subplot[1], subplot[2]
    return (rows, cols, index + 1)

def display_batch_of_images(images, predictions):
    """Displays a batch of images in a near-square grid, titled with predictions.

    Args:
      images: sequence of objects exposing `numpy_view()` (e.g. `mp.Image`).
      predictions: sequence of title strings, paired with `images` by position.

    Nothing is drawn when `images` is empty. Images without a matching
    prediction (or vice versa) are skipped, since the two are paired with `zip`.
    """
    # Imported here so this helper does not depend on another cell's import.
    import math

    # Images and predictions.
    images = [image.numpy_view() for image in images]
    if not images:
        return

    # Near-square grid: floor(sqrt(n)) rows, with the column count rounded up
    # so that every image gets a cell (the last row may be partially empty).
    rows = int(math.sqrt(len(images)))
    cols = math.ceil(len(images) / rows)

    # Size and spacing.
    FIGSIZE = 13.0
    SPACING = 0.1
    if rows < cols:
        plt.figure(figsize=(FIGSIZE, FIGSIZE / cols * rows))
    else:
        plt.figure(figsize=(FIGSIZE / rows * cols, FIGSIZE))

    # The title size depends only on the grid shape, so compute it once.
    dynamic_titlesize = FIGSIZE * SPACING / max(rows, cols) * 40 + 3

    # Display.
    subplot = (rows, cols, 1)
    for image, prediction in zip(images, predictions):
        subplot = display_one_image(image, prediction, subplot, titlesize=dynamic_titlesize)

    # Layout.
    plt.tight_layout()
    plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

Now it's time to move on to the actual inference code and display the highest confidence classification result.

While the converted model expects a square input image with a height of 256 pixels and a width of 256 pixels, the MediaPipe Image Classification Task automatically resizes and adds padding to the input image to meet the model's input requirements.

During this validation step, you will ensure that the converted model produces roughly the same output as the original PyTorch model for the same input. One thing worth noting is that, since the resizing and padding in MediaPipe differ from those performed by MobileViTImageProcessor, there will likely be some minor differences in prediction confidences. To account for this, we will bypass the padding and automatic resizing step by resizing the input image manually before feeding it to the image classifier.

In [None]:
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python.components import processors
from mediapipe.tasks.python import vision

import numpy as np
from PIL import Image

# The converted model takes a fixed 256x256 input (see `sample_args` above).
MODEL_INPUT_SIZE = (256, 256)  # (width, height), as PIL's resize expects.

# STEP 1: Create an ImageClassifier object. The model file was written to the
# working directory, so reference it by its relative name rather than a
# Colab-specific absolute path.
base_options = python.BaseOptions(model_asset_path=tflite_filename)

options = vision.ImageClassifierOptions(
    base_options=base_options,
    max_results=5)

images = []
predictions = []
# The classifier holds native resources; the context manager releases them.
with vision.ImageClassifier.create_from_options(options) as classifier:
  for image_name in IMAGE_FILENAMES:
    # STEP 2: Load the input image(s) as 3-channel RGB.
    with Image.open(image_name) as pil_image:
      rgb_image = pil_image.convert('RGB')

    # Resize to the model's input size ourselves, using the same plain
    # (non-aspect-preserving) 256x256 resize as the PyTorch validation
    # (bilinear, intended to approximate MobileViTImageProcessor). This leaves
    # MediaPipe nothing to resize or pad.
    model_input = rgb_image.resize(MODEL_INPUT_SIZE, Image.BILINEAR)
    mp_input = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.array(model_input))

    # STEP 3: Classify the input image(s).
    classification_result = classifier.classify(mp_input)

    # STEP 4: Process the classification result. In this case, visualize it
    # on the full-resolution image.
    images.append(mp.Image(image_format=mp.ImageFormat.SRGB, data=np.array(rgb_image)))
    top_category = classification_result.classifications[0].categories[0]
    predictions.append(f"{top_category.category_name} ({top_category.score:.2f})")

display_batch_of_images(images, predictions)

You should now see your loaded test images and their confidence scores/classifications, which should closely match the original PyTorch model results! If everything looks good, the final step is to download your newly converted `tflite` model file to your computer so you can use it elsewhere.

In [None]:
from google.colab import files

# Download the metadata-populated model from the Colab runtime to your computer.
files.download(tflite_filename)

# Next steps

Now that you have learned how to convert a PyTorch model to the TFLite format, it's time to do more with it! You can go over additional [MediaPipe](https://github.com/google-ai-edge/mediapipe-samples) samples for Android, iOS, web, and Python (including the Raspberry Pi!) to try your new model on multiple platforms, check out the [TFLite Interpreter API](https://ai.google.dev/edge/lite/) for running custom solutions, and read more about the PyTorch to TFLite framework with our [official documentation](https://ai.google.dev/edge/lite/models/convert_pytorch).