In [None]:
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

This demo will teach you how to convert a PyTorch [ConvNext V2](https://huggingface.co/docs/transformers/en/model_doc/convnextv2#overview) model to a LiteRT (formerly TensorFlow Lite) model using Google's AI Edge Torch library.

# Prerequisites

Before starting the conversion process, ensure that you have all the necessary dependencies installed and the required resources (like test images) available. Start by installing the libraries needed to convert the model, along with some additional utilities for displaying information as you progress through this sample.

In [None]:
# Install the nightly AI Edge Torch converter and the supporting libraries
# (Hugging Face transformers, Pillow, etc.) used by this notebook.
!pip install ai-edge-torch-nightly
!pip install transformers pillow requests matplotlib

You will also need to download an image to verify model functionality.

In [None]:
# `import urllib` alone does not load the `urllib.request` submodule, so import
# it explicitly before calling `urllib.request.urlretrieve`.
import urllib.request

IMAGE_FILENAMES = ['cat.jpg']

# Download each test image from the samples bucket into the working directory.
for name in IMAGE_FILENAMES:
  # TODO: Update path to the appropriate task subfolder in the GCS bucket
  url = f'https://storage.googleapis.com/ai-edge/models-samples/torch_converter/image_classification_mobile_vit/{name}'
  urllib.request.urlretrieve(url, name)

Optionally, you can upload your own image. If you want to do so, uncomment and run the cell below.

In [None]:
# from google.colab import files
# uploaded = files.upload()

# for filename in uploaded:
#   content = uploaded[filename]
#   with open(filename, 'wb') as f:
#     f.write(content)
# IMAGE_FILENAMES = list(uploaded.keys())

# print('Uploaded files:', IMAGE_FILENAMES)

Now go ahead and verify that the image was loaded successfully.

In [None]:
import cv2
from google.colab.patches import cv2_imshow
import math

# Pixel bounds that resize_and_show scales each preview's longer side to.
DESIRED_HEIGHT = DESIRED_WIDTH = 480

def resize_and_show(image):
  """Scales an image for preview and displays it with cv2_imshow.

  The longer side is set to DESIRED_WIDTH (landscape) or DESIRED_HEIGHT
  (otherwise). The other side keeps the aspect ratio and is floored to an
  integer pixel count.
  """
  height, width = image.shape[:2]
  if height < width:
    target_size = (DESIRED_WIDTH, math.floor(height / (width / DESIRED_WIDTH)))
  else:
    target_size = (math.floor(width / (height / DESIRED_HEIGHT)), DESIRED_HEIGHT)
  # cv2.resize takes the destination size as (width, height).
  cv2_imshow(cv2.resize(image, target_size))


# Preview the images. cv2.imread returns None (instead of raising) when a file
# is missing or cannot be decoded, so check for that and name the bad file.
images = {name: cv2.imread(name) for name in IMAGE_FILENAMES}

for name, image in images.items():
  if image is None:
    raise FileNotFoundError(f'Could not read image file: {name}')
  print(name)
  resize_and_show(image)

# PyTorch model validation

Now that you have your test images, it's time to validate the PyTorch model (in this case ConvNext V2) that will be converted to the LiteRT format.

Start by retrieving the PyTorch model and the appropriate corresponding processor.

In [None]:
from transformers import ConvNextImageProcessor, ConvNextV2ForImageClassification

# Hugging Face Hub checkpoint shared by the processor and the model.
hf_model_path = 'facebook/convnextv2-tiny-1k-224'

# Load the preprocessing configuration (resize, rescale, normalize) that
# accompanies the checkpoint.
processor = ConvNextImageProcessor.from_pretrained(hf_model_path)

In [None]:
# Show the per-channel normalization statistics the processor applies.
print(f"Image Mean: {processor.image_mean}")
print(f"Image Std: {processor.image_std}")

In [None]:
# Load the pretrained ConvNeXt V2 classifier from the same checkpoint as the processor.
pt_model = ConvNextV2ForImageClassification.from_pretrained(hf_model_path)

The `ConvNextImageProcessor` loaded above performs several steps on the input image to match the requirements of the `ConvNextV2` model:

1. Rescale the pixel values from the [0, 255] range to the range expected by the pretrained model.
2. Resize the input image to 224x224 pixels. This differs from the processor's default behavior (which center-crops the image) and makes it easier to validate the converted model with LiteRT, which receives plainly resized images (more details in the corresponding section).
3. Normalize each color channel with the mean and standard deviation printed above.

In [None]:
from PIL import Image

# Resize every image to exactly 224x224 with bilinear resampling, the same way
# `preprocess_image_lite` prepares LiteRT inputs later on, so both models see
# identical pixels. For shortest_edge < 384 the ConvNextImageProcessor resize
# step scales to shortest_edge / crop_pct and then center-crops, which would
# not match the LiteRT path. So resizing is done here and the processor only
# rescales and normalizes.
INPUT_SIZE = (224, 224)

images = []
for filename in IMAGE_FILENAMES:
  # Convert to RGB (as the LiteRT path does) and close the file once loaded.
  with Image.open(filename) as img:
    images.append(img.convert('RGB').resize(INPUT_SIZE, Image.Resampling.BILINEAR))

inputs = processor(
    images=images,
    return_tensors='pt',
    do_resize=False,
)

Now that you have your test data ready and the inputs processed, it's time to validate the classifications. In this step you will loop through your test image(s) and display the top 5 predicted classification categories. This model was trained with ImageNet-1000, so there are 1000 different potential classifications that may be applied to your test data.

In [None]:
import torch
from torch import nn

# Run the whole batch through the model once (not once per image). Gradients
# are not needed for inference, so skip building the autograd graph.
with torch.no_grad():
  logits = pt_model(**inputs).logits  # one row of class logits per image

for image_index, filename in enumerate(IMAGE_FILENAMES):
  # Row `image_index` is a 1-D logit vector, so softmax/topk act on classes.
  probs, indices = nn.functional.softmax(logits[image_index], dim=-1).topk(k=5)

  print(filename, 'predictions: ')
  for prob, class_index in zip(probs.tolist(), indices.tolist()):
    class_label = pt_model.config.id2label[class_index]
    print(f'{(prob * 100):4.1f}%  {class_label}')
  print('\n')

# Convert to the `tflite` Format

Before converting the PyTorch model to work with the tflite format, you will need to take an extra step to match it to the format expected by LiteRT. Here are the necessary adjustments:

1. **Channel Ordering**: Accept images in channel-last (BHWC) format and convert them internally to the channel-first (BCHW) format the PyTorch model expects.
2. **Softmax Layer**: Apply a softmax to the classification logits so the model outputs probabilities, as expected by MediaPipe image classification models.
3. **Preprocessing Wrapper**: Incorporate the preprocessing steps (scaling and normalization) into a wrapper class, mirroring what the processor did when you validated the PyTorch model in the previous section.


In [None]:
class HF2LiteRT_ImageClassificationModelWrapper(nn.Module):
  """Wraps a Hugging Face image classifier for LiteRT conversion.

  The wrapped model takes raw pixels in channel-last layout and returns class
  probabilities. The processor's rescaling and normalization and a final
  softmax are folded into the graph. Resizing is not, so inputs must already
  have the spatial size the model expects.

  Args:
    hf_image_classification_model: Module called as `model(pixel_values=...)`
      with a BCHW tensor, returning an object with a `.logits` attribute.
    hf_processor: Image processor whose `do_rescale`, `rescale_factor`,
      `do_normalize` (assumed True if absent), `image_mean` and `image_std`
      attributes configure the folded preprocessing.
  """

  def __init__(self, hf_image_classification_model, hf_processor):
    super().__init__()
    self.model = hf_image_classification_model
    self.rescale_factor = (
        hf_processor.rescale_factor if hf_processor.do_rescale else 1.0
    )
    # Honor the processor's normalization switch, like `do_rescale` above.
    self.do_normalize = getattr(hf_processor, 'do_normalize', True)
    # Registered as non-persistent buffers (not plain attributes) so they move
    # with the module on `.to(...)` while staying out of the state dict.
    # Shape: [1, C, 1, 1] to broadcast over BCHW images.
    self.register_buffer(
        'image_mean',
        torch.tensor(hf_processor.image_mean, dtype=torch.float32).view(1, -1, 1, 1),
        persistent=False,
    )
    self.register_buffer(
        'image_std',
        torch.tensor(hf_processor.image_std, dtype=torch.float32).view(1, -1, 1, 1),
        persistent=False,
    )

  def forward(self, image: torch.Tensor):
    """Classifies a batch of channel-last images.

    Args:
      image: Float tensor of shape [B, H, W, C] holding raw pixel values
        (e.g. in [0, 255] when `rescale_factor` is 1/255).

    Returns:
      Tensor of shape [B, num_classes] with softmax probabilities.
    """
    # BHWC -> BCHW, the layout the Hugging Face model expects.
    image = image.permute(0, 3, 1, 2)
    # Apply the processor's rescale factor (1.0 when rescaling is disabled).
    image = image * self.rescale_factor
    if self.do_normalize:
      image = (image - self.image_mean) / self.image_std
    logits = self.model(pixel_values=image).logits
    # Softmax is required for MediaPipe classification models.
    return torch.nn.functional.softmax(logits, dim=-1)


# Reuse the processor and model loaded earlier instead of loading the same
# checkpoint a second time. Calling .eval() on the wrapper also switches the
# wrapped model into eval mode.
hf_convnext_v2_processor = processor
hf_convnext_v2_model = pt_model
wrapped_pt_model = HF2LiteRT_ImageClassificationModelWrapper(
    hf_convnext_v2_model, hf_convnext_v2_processor
).eval()

## Convert to `TFLite`

Now it's time to perform the conversion! You need to provide sample arguments that define the expected input shape: in this case a batch of one 224x224 image with three color channels, in channel-last (BHWC) layout.

In [None]:
import ai_edge_torch

# Example input used to trace the wrapped model; it fixes the input shape to
# one 224x224 image with 3 channels in BHWC layout.
sample_args = (torch.rand(1, 224, 224, 3),)
edge_model = ai_edge_torch.convert(wrapped_pt_model, sample_args)

## Export the Converted Model

Running the following saves the converted model as a **FlatBuffer** file, which is compatible with **LiteRT**.


In [None]:
from pathlib import Path

# File name for the converted model (relative to the working directory).
TFLITE_MODEL_PATH = 'hf_convnext_v2_mp_image_classification_raw.tflite'
flatbuffer_file = Path(TFLITE_MODEL_PATH)
# Serialize the converted model as a .tflite FlatBuffer.
edge_model.export(flatbuffer_file)

# Validate converted model with LiteRT

Now it's time to test your newly converted model directly with the LiteRT Interpreter API. Before getting into that code, you can add the following utility functions to improve the output displayed.

In [None]:
#@markdown Functions to visualize the image classification results. <br/> Run this cell to activate the functions.

from matplotlib import pyplot as plt

# Hide every axis spine, tick mark and tick label so the figures show only the
# images and their titles.
plt.rc('axes.spines', top=False, right=False, left=False, bottom=False)
plt.rc('xtick', bottom=False, top=False, labelbottom=False, labeltop=False)
plt.rc('ytick', left=False, right=False, labelleft=False, labelright=False)


def display_one_image(image, title, subplot, titlesize=16):
    """Draws one image in a subplot slot, titled with its predicted category.

    Args:
      image: Image data passed to `plt.imshow`.
      title: Title text; no title is drawn when it is empty.
      subplot: (rows, cols, index) tuple selecting the subplot slot.
      titlesize: Title font size; the title padding is derived from it.

    Returns:
      The (rows, cols, index) tuple for the next slot.
    """
    rows, cols, index = subplot
    plt.subplot(rows, cols, index)
    plt.imshow(image)
    if title:
        title_font = int(titlesize)
        title_pad = int(titlesize / 1.5)
        plt.title(title, fontsize=title_font, color='black',
                  fontdict={'verticalalignment': 'center'}, pad=title_pad)
    return rows, cols, index + 1

def display_batch_of_images(images, predictions):
    """Displays a batch of images in a grid, each titled with its classification.

    The grid has floor(sqrt(n)) rows and enough columns to hold all n images,
    so no image is dropped when n does not form a perfect rectangle. Unused
    cells in the last row stay empty. Images and predictions are paired with
    `zip`, so extra items in the longer sequence are ignored.

    Args:
      images: Sequence of images accepted by `plt.imshow`.
      predictions: Sequence of title strings, one per image.
    """
    num_images = len(images)
    if num_images == 0:
        # Nothing to draw; also avoids a zero-row grid (division by zero).
        return

    rows = int(math.sqrt(num_images))
    cols = math.ceil(num_images / rows)

    # Size and spacing.
    FIGSIZE = 13.0
    SPACING = 0.1
    if rows < cols:
        plt.figure(figsize=(FIGSIZE, FIGSIZE / cols * rows))
    else:
        plt.figure(figsize=(FIGSIZE / rows * cols, FIGSIZE))

    # The title size depends only on the grid shape, so compute it once.
    titlesize = FIGSIZE * SPACING / max(rows, cols) * 40 + 3
    subplot = (rows, cols, 1)
    for image, prediction in zip(images, predictions):
        subplot = display_one_image(image, prediction, subplot, titlesize=titlesize)

    # Layout.
    plt.tight_layout()
    plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

## Inference with LiteRT Interpreter

Now it's time to move on to the actual inference code and display the highest confidence classification result. Let's now run inference using the converted LiteRT model and compare the results with the original PyTorch model.

In [None]:
import numpy as np

# Load the LiteRT model and allocate tensors.
from ai_edge_litert.interpreter import Interpreter

# Path to the converted LiteRT model written by the export step.
TFLITE_MODEL_PATH = 'hf_convnext_v2_mp_image_classification_raw.tflite'

# Create the interpreter and allocate its tensors before running inference.
interpreter = Interpreter(model_path=TFLITE_MODEL_PATH)
interpreter.allocate_tensors()

# Tensor metadata; their 'index' entries are used below to feed and read the model.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

print(f"LiteRT Model Input Details: {input_details}")
print(f"LiteRT Model Output Details: {output_details}")

## Define Preprocessing and Postprocessing Functions
Prepare functions to preprocess images for LiteRT and to extract top predictions.

In [None]:
def preprocess_image_lite(image_path, size=(224, 224)):
    """Loads an image file and turns it into a single-image LiteRT input batch.

    The image is converted to RGB and resized to `size` with bilinear
    resampling (the aspect ratio is not preserved). It is returned as float32
    pixel values without rescaling or normalization; the converted model
    applies those itself.

    Args:
      image_path: Path to the image file.
      size: Target (width, height), as taken by `PIL.Image.Image.resize`.

    Returns:
      float32 NumPy array of shape (1, height, width, 3).
    """
    # Use a context manager so the underlying file handle is closed.
    with Image.open(image_path) as image:
        image_resized = image.convert('RGB').resize(size, Image.Resampling.BILINEAR)
    image_array = np.array(image_resized, dtype=np.float32)
    # Expand dimensions to match model's expected input shape (1, H, W, C)
    return np.expand_dims(image_array, axis=0)

def get_top_k_predictions_lite(output, k=5):
    """Picks the `k` most likely classes from the model's probability output.

    The converted model already applies softmax, so `output` holds
    probabilities. Selection runs along the last axis, and both results are
    flattened to 1-D.

    Args:
      output: NumPy array of class probabilities, e.g. shape (1, num_classes).
      k: Number of top entries to return.

    Returns:
      Tuple (probabilities, class_indices) of 1-D NumPy arrays, ordered from
      most to least likely.
    """
    top_values, top_positions = torch.from_numpy(output).topk(k)
    flat_probs = top_values.numpy().flatten()
    flat_indices = top_positions.numpy().flatten()
    return flat_probs, flat_indices

## Run Inference and Visualize
Execute the inference process and visualize the predictions.



In [None]:
images = []
predictions = []

for image_name in IMAGE_FILENAMES:
    # STEP 1: Load the input image for display, closing the file afterwards.
    with Image.open(image_name) as pil_image:
        image = np.array(pil_image.convert('RGB'))

    # STEP 2: Load and preprocess the input image for the model.
    lite_input = preprocess_image_lite(image_name, size=(224, 224))

    # STEP 3: Classify the input image using the LiteRT model.
    interpreter.set_tensor(input_details[0]['index'], lite_input)
    interpreter.invoke()
    lite_output = interpreter.get_tensor(output_details[0]['index'])

    # STEP 4: Get the top 5 predictions and print them in the same format as
    # the PyTorch predictions above, so the two models can be compared.
    lite_probs, lite_indices = get_top_k_predictions_lite(lite_output, k=5)
    print(image_name, 'predictions: ')
    for prob, class_index in zip(lite_probs, lite_indices):
        class_label = pt_model.config.id2label[int(class_index)]
        print(f'{(prob * 100):4.1f}%  {class_label}')
    print('\n')

    # STEP 5: Get the top category (highest probability) and visualize.
    top_prob = lite_probs[0]
    # Convert the NumPy integer to a plain int for the id2label lookup.
    top_idx = int(lite_indices[0])
    top_category_name = pt_model.config.id2label[top_idx]
    prediction_text = f"{top_category_name} ({top_prob * 100:.2f}%)"

    images.append(image)
    predictions.append(prediction_text)

# Display the images with their predictions.
display_batch_of_images(images, predictions)

You should now see your loaded test images and their confidence scores/classifications that match the original PyTorch model results! If everything looks good, the final step should be downloading your newly converted `tflite` model file to your computer so you can use it elsewhere.

In [None]:
from google.colab import files

# Download the exported .tflite file from the Colab runtime to your machine.
files.download(TFLITE_MODEL_PATH)

# Next steps

Now that you have learned how to convert a PyTorch model to the LiteRT format, it's time to check out the [LiteRT Interpreter API](https://ai.google.dev/edge/litert) for running other custom solutions, and read more about the PyTorch to LiteRT framework with our [official documentation](https://ai.google.dev/edge/lite/models/convert_pytorch).