In [None]:
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

In [None]:
!pip install ai-edge-torch-nightly
!pip install ai-edge-quantizer-nightly
!pip install ai-edge-model-explorer

## Install and import the required packages

In [None]:
import logging

import tensorflow as tf
import numpy as np
print("TensorFlow version: ", tf.__version__)

import matplotlib.pylab as plt
import pathlib
import random

import ai_edge_torch

import torch
import torchvision

import numpy as np
import model_explorer

from ai_edge_quantizer import quantizer
from ai_edge_quantizer import recipe
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.utils import tfl_flatbuffer_utils

## This Colab shows how to take a PyTorch model, convert it with AI Edge Torch, and then quantize it with AI Edge Quantizer. More details on converting PyTorch models are available at https://ai.google.dev/edge/litert/models/convert_pytorch

In [None]:
# Convert a pretrained torchvision ResNet-18 with AI Edge Torch, check that the
# converted model agrees with PyTorch on a sample input, export it as a
# flatbuffer, and then quantize that flatbuffer with AI Edge Quantizer.

# Seed the global PyTorch RNG so the random sample input (and therefore the
# outcome of the tolerance check below) is the same on every fresh run.
torch.manual_seed(0)

# `weights` is passed by keyword: torchvision declares it keyword-only and
# warns that positional use is deprecated.
resnet18 = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)
torch_output = resnet18(*sample_inputs)

# Conversion
edge_model = ai_edge_torch.convert(resnet18, sample_inputs)

# Inference
edge_output = edge_model(*sample_inputs)

# Validation
# The forward pass above ran with autograd enabled, so the tensor has to be
# detached before it can be viewed as a NumPy array.
if np.allclose(torch_output.detach().numpy(), edge_output, atol=1e-5):
    print("Inference result with Pytorch and LiteRT was within tolerance")
else:
    print("Something wrong with Pytorch --> LiteRT")

# Serialization
edge_model.export('resnet.tflite')

# Model Explorer Visualization
model_explorer.visualize('resnet.tflite')

# Quantization (API will quantize and save a flatbuffer as *.tflite)
# NOTE(review): the recipe name suggests dynamic-range int8 weights with fp32
# activations - confirm against the AI Edge Quantizer recipe documentation.
qt = quantizer.Quantizer('resnet.tflite', recipe.dynamic_wi8_afp32())
# Keep the quantization result itself in `quant_result`; chaining `.save(...)`
# onto the assignment would bind whatever `save` returns instead of the result.
quant_result = qt.quantize()
quant_result.save("", "aeq_resnet")

## Compare the sizes of the flatbuffers


In [None]:
# List every .tflite file in the current working directory with
# human-readable sizes so the flatbuffers can be compared side by side.
!ls -lh *.tflite