This colab demonstrates how the `AI Edge Quantizer` can be used to run various quantization experiments with `isnet` (http://arxiv.org/abs/2108.12382).

# Install Necessary Dependencies


In [None]:
# Nightly AI Edge Torch (PyTorch -> TFLite conversion) and AI Edge Quantizer,
# plus image/plotting helpers and Model Explorer for graph visualization.
!pip install ai-edge-torch-nightly
!pip install ai-edge-quantizer-nightly
!pip install pillow requests matplotlib
!pip install ai-edge-model-explorer

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import skimage
import tensorflow as tf
import ai_edge_quantizer
import model_explorer

In [None]:
# @title Preprocess/postprocess utilities (unrelated to quantization) { display-mode: "form" }
# Spatial input size of the model. tf.image.resize reads this as
# (height, width) while PIL's Image.resize reads it as (width, height); the two
# only agree because the size is square.
MODEL_INPUT_HW = (1024, 1024)

def make_channels_first(image):
  """Move the last (channel) axis to the front and prepend a batch axis.

  An (H, W, C) input becomes a NumPy array of shape (1, C, H, W).
  """
  chw = tf.transpose(image, perm=[2, 0, 1])
  batched = np.expand_dims(chw, 0)
  return batched

def preprocess_image(file_path):
  """Read an image file and return it as a float32 channels-first batch.

  The image is resized to MODEL_INPUT_HW (as height, width), divided by 255
  (assumes 8-bit pixel values — the result is then in [0, 1]) and reshaped to
  (1, C, H, W) via make_channels_first.
  """
  raw = skimage.io.imread(file_path)
  resized = tf.image.resize(raw, MODEL_INPUT_HW).numpy().astype(np.float32)
  return make_channels_first(resized / 255.0)

def preprocess_image_ai_edge_torch(test_image_path):
  """Load an image as a float32 NHWC batch for the converted IS-Net model.

  Pixels are left in [0, 255]: ImageSegmentationModelWrapper performs the
  layout permutation and rescaling inside the model.

  Args:
    test_image_path: Path of the image file to load.

  Returns:
    A float32 array of shape (1, H, W, 3), resized with bilinear filtering to
    MODEL_INPUT_HW (which PIL interprets as (width, height)).
  """
  # Context manager releases the file handle; convert("RGB") guarantees three
  # channels even for grayscale, RGBA or palette images.
  with Image.open(test_image_path) as image:
    resized = image.convert("RGB").resize(MODEL_INPUT_HW, Image.Resampling.BILINEAR)
  test_image = np.expand_dims(np.array(resized), axis=0).astype(np.float32)
  return test_image

def run_segmentation(image, tflite_model, output_index=0):
  """Get segmentation mask of the image.

  Args:
    image: Model input. It is passed unchanged to the interpreter's first input
      tensor, so its shape and dtype must match that tensor exactly.
    tflite_model: Path of the .tflite model to execute.
    output_index: Which model output to treat as the mask. Defaults to the
      first output.

  Returns:
    The selected output with size-1 dimensions squeezed away, min-max
    normalized and scaled to [0, 255]. A constant output (max == min) yields
    all zeros instead of the NaNs a 0/0 division would produce.
  """
  interpreter = tf.lite.Interpreter(model_path=tflite_model)
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()[0]
  interpreter.set_tensor(input_details["index"], image)
  interpreter.invoke()

  # Only fetch the output that is actually used.
  output_details = interpreter.get_output_details()[output_index]
  mask = tf.squeeze(interpreter.get_tensor(output_details["index"]))

  # Min-max normalization, guarded against a zero range.
  mask_min = np.min(mask)
  mask_range = np.max(mask) - mask_min
  if mask_range > 0:
    mask = (mask - mask_min) / mask_range
  else:
    mask = tf.zeros_like(mask)
  # Scale [0, 1] -> [0, 255].
  return mask * 255


def draw_segementation(image, float_mask, quant_mask, info):
  """Plot the input image, the float mask and the quantized mask side by side.

  Args:
    image: Input image; anything np.array() accepts (e.g. a PIL image).
    float_mask: Mask produced by the float model.
    quant_mask: Mask produced by the quantized model.
    info: Label appended to the quantized mask's title.
  """
  _, ax = plt.subplots(1, 3, figsize=(15, 10))

  ax[0].imshow(np.array(image))
  ax[1].imshow(np.array(float_mask), cmap="gray")
  ax[2].imshow(np.array(quant_mask), cmap="gray")

  # One title per panel; the first panel shows the input image.
  ax[0].set_title("Image")
  ax[1].set_title("Float Mask")
  ax[2].set_title(f"Quant Mask: {info}")

  plt.show()

def save_tfl_model(model_content, save_path):
  """Write serialized TFLite model bytes to `save_path` and print the path.

  Uses tf.io.gfile, so local paths as well as any filesystem scheme supported
  by TensorFlow can be used.

  Args:
    model_content: Serialized model (bytes / bytearray).
    save_path: Destination file path.
  """
  with tf.io.gfile.GFile(save_path, "wb") as f:
    f.write(model_content)
  print(f"model saved to: {save_path}")




In [None]:
!curl -H 'Accept: application/vnd.github.v3.raw'  -O   -L https://api.github.com/repos/google-ai-edge/ai-edge-quantizer/contents/colabs/test_data/input_image.jpg

IMAGE_PATH = '/content/input_image.jpg'

test_image = preprocess_image_ai_edge_torch(IMAGE_PATH)

# Getting the TFLite Model from PyTorch

Reference: https://github.com/google-ai-edge/ai-edge-torch/blob/main/test/image_segmentation/colab/isnet_tfl.ipynb

AI Edge Quantizer takes a float TFLite model and produces a quantized TFLite model, so our first step is to build the float TFLite model.

In [None]:
# @title Clone IS-Net DIS repo and download Pytorch model

# Start from a clean /content (removes any previous clone and Colab's sample data).
%cd /content
!rm -rf DIS sample_data

# The IS-Net sources provide the `models` package imported by the next cell.
!git clone https://github.com/xuebinqin/DIS.git
%cd DIS/IS-Net/

# Pretrained weights archive, extracted into the current directory (DIS/IS-Net);
# presumably it contains the isnet-general-use.pth loaded next — confirm.
!curl -o ./model.tar.gz -L https://www.kaggle.com/api/v1/models/paulruiz/dis/pyTorch/8-17-22/1/download
!tar -xvf 'model.tar.gz'

In [None]:
# @title Build torch model

import torch
from models import ISNetDIS


pytorch_model_filename = 'isnet-general-use.pth'
pt_model = ISNetDIS()
# The checkpoint is a state dict downloaded from the internet. weights_only=True
# restricts unpickling to tensors/containers, so a tampered archive cannot run
# arbitrary code (PyTorch only defaults to this from 2.6 on).
pt_model.load_state_dict(
    torch.load(
        pytorch_model_filename,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)

import torch
from torch import nn
from torchvision.transforms.functional import normalize


class ImageSegmentationModelWrapper(nn.Module):
  """Adapts IS-Net to take raw NHWC pixels and return an NHWC mask.

  The wrapped model receives a BCHW tensor that has been rescaled by
  RESCALING_FACTOR and normalized with MEAN/STD. Because this wrapper (not the
  bare model) is what gets converted, the pre/post-processing becomes part of
  the exported TFLite model.
  """

  RESCALING_FACTOR = 255.0  # Maps [0, 255] pixel values to [0, 1].
  MEAN = 0.5  # Subtracted after rescaling.
  STD = 1.0  # Divisor applied after mean subtraction.

  def __init__(self, pt_model):
    super().__init__()
    self.model = pt_model

  def forward(self, image: torch.Tensor):
    """Run the wrapped model on a BHWC batch.

    Args:
      image: 4-D tensor in BHWC layout, presumably with [0, 255] pixel values.

    Returns:
      `self.model(...)[0][0]` (assumed to be 4-D BCHW) permuted to BHWC.
    """
    # BHWC -> BCHW.
    image = image.permute(0, 3, 1, 2)

    # Rescale [0, 255] -> [0, 1].
    image = image / self.RESCALING_FACTOR

    # Normalize.
    image = (image - self.MEAN) / self.STD

    # Get result. The model output is indexed [0][0]; presumably the first
    # (finest) side output of the first group — confirm against ISNetDIS.
    result = self.model(image)[0][0]

    # BCHW -> BHWC.
    result = result.permute(0, 2, 3, 1)

    return result


# Wrap the loaded model and put it in eval mode (inference behavior for layers
# such as dropout/batch norm) before conversion.
wrapped_pt_model = ImageSegmentationModelWrapper(pt_model).eval()

In [None]:
# @title Convert torch model to TFlite using AI Edge Torch

import ai_edge_torch

import time

# Every float/quantized model and comparison report in this colab is written
# here. Use an absolute path because an earlier cell `%cd`s into DIS/IS-Net.
BASE_MODEL_PATH = '/content/models'
os.makedirs(BASE_MODEL_PATH, exist_ok=True)

start = time.perf_counter()

FLOAT_MODEL_PATH = os.path.join(BASE_MODEL_PATH, "isnet_float.tflite")
# Example input used to trace the model: one BHWC image at the model input size.
sample_args = (torch.rand((1, *MODEL_INPUT_HW, 3)),)
edge_model = ai_edge_torch.convert(wrapped_pt_model, sample_args)
edge_model.export(FLOAT_MODEL_PATH)

print(f"Conversion took {time.perf_counter() - start:.1f} s")


In [None]:
# @title Optional: visualize the model using model explorer
# Opens Model Explorer on the float model exported by the conversion cell.
model_explorer.visualize(FLOAT_MODEL_PATH)

# AI Edge Quantizer

To use the `Quantizer`, we need to provide
* the float .tflite model.
* quantization recipe (i.e., apply quantization algorithm X on Operator Y with configuration Z).






### Quantizing model with dynamic quantization

When doing calibration-free quantization, TensorFlow Lite by default quantizes the weights to int8 and uses integer execution by quantizing the float activations on the fly. This is known as dynamic range quantization in TensorFlow Lite: https://www.tensorflow.org/lite/performance/post_training_quantization#dynamic_range_quantization


The following example will showcase how AI Edge Quantizer can achieve the same behaviour.

In [None]:
from ai_edge_quantizer import recipe

# Initialize with float .tflite
dyn_qt = ai_edge_quantizer.Quantizer(float_model=FLOAT_MODEL_PATH)

# Prebaked recipe: int8 weights, float32 activations (dynamic quantization).
dyn_qt.load_quantization_recipe(recipe=recipe.dynamic_wi8_afp32())

# we will store the quantized model here
# (the "DYANMIC" spelling is kept because later cells reference this name)
DYANMIC_QUANTIZED_MODEL_PATH = os.path.join(BASE_MODEL_PATH, "isnet_dynamic_wi8_afp32.tflite")
# Remove a stale copy from a previous run; presumably save() will not
# overwrite an existing file — confirm against the Quantizer API.
if os.path.exists(DYANMIC_QUANTIZED_MODEL_PATH):
  os.remove(DYANMIC_QUANTIZED_MODEL_PATH)

# Quantization result contains the quantized model and a copy of the quantization recipe
quantization_result = dyn_qt.quantize()
# Expected to write <BASE_MODEL_PATH>/isnet_dynamic_wi8_afp32.tflite, the path
# later cells load.
quantization_result.save(BASE_MODEL_PATH, "isnet_dynamic_wi8_afp32")


So what's going on? Let's take a look step by step:

First, we have a prebaked quantization recipe. `dynamic_wi8_afp32` in the name means the recipe applies dynamic quantization: the weights are quantized to int8 while the activations are kept in float32.

Let's take a look at what's in this recipe.

In [None]:
# Display the contents of the prebaked dynamic recipe.
recipe.dynamic_wi8_afp32()

Here the recipe means: apply the naive min/max uniform algorithm (`min_max_uniform_quantize`) for all ops supported by the AI Edge Quantizer (indicated by `*`) under layers satisfying regex `.*` (i.e., all layers). We want the weights of these ops to be quantized as int8, symmetric, channel_wise, and we want to execute the ops in `Integer` mode without explicitly adding dequantize op.

Note: explicitly adding a dequantize op is one way to enable other quantization mechanisms, which we will cover later in this colab.


`quantization_result` has two components
* quantized tflite model (in bytearray) and
* the corresponding quantization recipe

When we save, we will always save the pair so users know how the model is quantized.

In [None]:
# The recipe that was used to produce `quantization_result`.
quantization_result.recipe

Now let's try running both the float model and the newly quantized model and see how they compare.

In [None]:
# Run the dynamic-quantized and float models on the same input and plot both.
quantized_mask = run_segmentation(test_image, DYANMIC_QUANTIZED_MODEL_PATH)
# `float_mask` is reused as the reference by the later comparison cells.
float_mask = run_segmentation(test_image, FLOAT_MODEL_PATH)
# NOTE(review): `image` (the display image) must be defined by an earlier cell.
draw_segementation(image, float_mask, quantized_mask, "AI Edge 8-bit Integer Execution")

### Weight only quantization

As shown above, the default solution used by TensorFlow Lite doesn't give us a good result compared to the original float model. This is because we lose precision when doing integer compute.

To increase accuracy, we can create a model that uses float compute with quantized constants, otherwise known as weight-only quantization. Let's see how much we can improve.

In [None]:
wo_qt = ai_edge_quantizer.Quantizer(float_model=FLOAT_MODEL_PATH)

# Create & Overwrite quantization recipe
# Weight tensors: 8-bit, symmetric, per-channel.
tensor_config_8bit = ai_edge_quantizer.qtyping.TensorQuantizationConfig(
    num_bits=8,
    symmetric=True,
    granularity=ai_edge_quantizer.qtyping.QuantGranularity.CHANNELWISE)

# Weight-only: int8 weights are explicitly dequantized and compute runs in float.
# Also reused by the dynamic/weight-only mix cell.
weight_only_op_config_8bit = ai_edge_quantizer.qtyping.OpQuantizationConfig(
    weight_tensor_config=tensor_config_8bit,
    compute_precision=ai_edge_quantizer.qtyping.ComputePrecision.FLOAT,
    explicit_dequantize=True)

# Apply to every op ("*") in every scope (".*") using min/max uniform quantization.
wo_qt.update_quantization_recipe(regex=".*",
                              operation_name="*",
                              op_config=weight_only_op_config_8bit,
                              algorithm_key=ai_edge_quantizer.algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT)


Here we build a new quantization recipe. For the tensor config, we still use int8, symmetric, channel-wise quantization. But this time, we add explicit dequantize ops to the model and require the compute to be done in float.



In [None]:
# Quantize.
quantization_result = wo_qt.quantize()

# Save models
# Remove a stale copy first; presumably save() will not overwrite an existing
# file — confirm against the Quantizer API.
WEIGHT_ONLY_MODEL_PATH = os.path.join(BASE_MODEL_PATH, "isnet_weight_only_wi8_afp32.tflite")
if os.path.exists(WEIGHT_ONLY_MODEL_PATH):
  os.remove(WEIGHT_ONLY_MODEL_PATH)
quantization_result.save(BASE_MODEL_PATH, "isnet_weight_only_wi8_afp32")

# Side by side comparison with the float model.
quantized_mask = run_segmentation(test_image, WEIGHT_ONLY_MODEL_PATH)
draw_segementation(image, float_mask, quantized_mask, "AI Edge 8-bit Float Execution")

With weight-only quantization, we achieve the same quality as the float execution.

# Debug through Model Explorer (visualization)

Now we know that float execution gives us better-quality results but is slower, while integer execution runs faster but its quality is not adequate.

Can we try getting the best of both worlds?

Let's first figure out where dynamic execution loses precision.

The following code will generate a tensor-by-tensor comparison result between the dynamic quantized model and original float model.



In [None]:
# Tensor-by-tensor comparison (float vs. quantized) using median_diff_ratio
# as the metric (i.e., mdr = abs(float_tensor - dequantized_tensor)/float_tensor)
# and save the results in .json format to visualize through Model Explorer.
# Keep a handle to the comparison result itself, rather than to whatever
# `.save()` returns, so it can still be inspected in Python.
comparison_result = dyn_qt.validate(
    signature_test_data=[{'args_0': test_image}],
    error_metrics='median_diff_ratio',
    use_reference_kernel=True,
)
comparison_result.save(BASE_MODEL_PATH, "dynamic")

Load `dynamic_comparison_result_me_input.json` (written by the cell above) on top of the quantized `.tflite` using `Model Explorer` (https://github.com/google-ai-edge/model-explorer) to see how errors propagate through the model.

In [None]:
config = model_explorer.config()

# Per-node comparison data. Expected to be the file written by
# `.save(BASE_MODEL_PATH, "dynamic")` in the validation cell — confirm the name.
DYNAMIC_NODE_DATA_PATH = os.path.join(BASE_MODEL_PATH, "dynamic_comparison_result_me_input.json")

# Overlay the node data on the dynamic-quantized model graph.
(config
 .add_model_from_path(DYANMIC_QUANTIZED_MODEL_PATH)
 .add_node_data_from_path(DYNAMIC_NODE_DATA_PATH))

model_explorer.visualize_from_config(config)

Using Model Explorer, we find that the errors come from the last few layers ('RSU6_stage2d', 'RSU7_stage1d', 'Conv2d_side1'). Let's try not quantizing them.

## Selective Dynamic Quantization

Here we'll override the original `dynamic_wi8_afp32` recipe to skip the three scopes that produce inaccurate results. Note that for each scope, the newly added rule always takes precedence.

In [None]:
# Scopes whose CONV_2D ops are left unquantized; also reused by the mix cell.
# NOTE(review): the prose names 'RSU6_stage2d' / 'RSU7_stage1d', but the bare
# 'RSU6' / 'RSU7' patterns may match more layers than those, depending on how
# `regex` is matched — confirm this breadth is intended.
scopes = ['RSU6','RSU7','Conv2d_side1']
for scope in scopes:
  # The newly added rule takes precedence over the loaded dynamic recipe.
  dyn_qt.update_quantization_recipe(
      regex=scope,
      operation_name="CONV_2D",
      algorithm_key='no_quantize',
  )
dyn_qt.get_quantization_recipe()

In [None]:
SELECTIVE_DYNAMIC_MODEL_PATH = os.path.join(BASE_MODEL_PATH, "isnet_selective_dynamic_wi8_afp32.tflite")
# Remove a stale copy from a previous run before saving.
if os.path.exists(SELECTIVE_DYNAMIC_MODEL_PATH):
  os.remove(SELECTIVE_DYNAMIC_MODEL_PATH)

# Re-quantize with the updated recipe (skipped scopes) and compare with float.
dyn_qt.quantize().save(BASE_MODEL_PATH, "isnet_selective_dynamic_wi8_afp32")
quantized_mask = run_segmentation(test_image, SELECTIVE_DYNAMIC_MODEL_PATH)
draw_segementation(image, float_mask, quantized_mask, "Selective Dynamic")


Can we do better? Let's try mixing `weight-only` with `dynamic` quantization. This way, we improve model quality (compared to the fully `dynamic quantized` model) while keeping the model size small.

In [None]:
# @title Dynamic Weight-only Mix

# Start from the full dynamic recipe, then switch the sensitive scopes to
# weight-only quantization (int8 weights, explicit dequantize, float compute).
qt = ai_edge_quantizer.Quantizer(
    float_model=FLOAT_MODEL_PATH, quantization_recipe=recipe.dynamic_wi8_afp32()
)
for scope in scopes:
  qt.update_quantization_recipe(
      regex=scope,
      operation_name="CONV_2D",
      # Identical to the config built in the weight-only section; reuse it
      # rather than constructing the same OpQuantizationConfig again.
      op_config=weight_only_op_config_8bit,
  )
# Uncomment to inspect the merged recipe:
# qt.get_quantization_recipe()

In [None]:
quantization_result = qt.quantize()

MIX_DYNAMIC_WEIGHT_ONLY_MODEL_PATH = os.path.join(BASE_MODEL_PATH, "isnet_dynamic_weight_only_mix_wi8_afp32.tflite")
# Remove a stale copy from a previous run before saving.
if os.path.exists(MIX_DYNAMIC_WEIGHT_ONLY_MODEL_PATH):
  os.remove(MIX_DYNAMIC_WEIGHT_ONLY_MODEL_PATH)

quantization_result.save(
    save_folder=BASE_MODEL_PATH, model_name='isnet_dynamic_weight_only_mix_wi8_afp32'
)

# Side by side comparison with the float model.
quantized_mask = run_segmentation(test_image, MIX_DYNAMIC_WEIGHT_ONLY_MODEL_PATH)
draw_segementation(image, float_mask, quantized_mask, "dynamic weight-only mix")

## Can we do even better?

We've seen that int8 weight-only quantization gives us essentially the same quality as the float model. We can push the boundary further and try int4.

In [None]:
# @title INT4 weight-only
qt = ai_edge_quantizer.Quantizer(float_model=FLOAT_MODEL_PATH)

# Weight tensors: 4-bit, asymmetric, per-channel.
# NOTE(review): confirm asymmetric int4 weights are supported by the installed
# quantizer version.
tensor_config_4bit = ai_edge_quantizer.qtyping.TensorQuantizationConfig(
    num_bits=4,
    symmetric=False,
    granularity=ai_edge_quantizer.qtyping.QuantGranularity.CHANNELWISE)

# Must reference the 4-bit tensor config defined just above; using
# tensor_config_8bit here would silently produce an int8 model.
weight_only_op_config_4bit = ai_edge_quantizer.qtyping.OpQuantizationConfig(
    weight_tensor_config=tensor_config_4bit,
    compute_precision=ai_edge_quantizer.qtyping.ComputePrecision.FLOAT,
    explicit_dequantize=True)

qt.update_quantization_recipe(
    regex=".*",
    operation_name="*",
    op_config=weight_only_op_config_4bit,
    algorithm_key=ai_edge_quantizer.algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT
)
quantization_result = qt.quantize()

INT_WEIGHT_ONLY_MODEL_PATH = os.path.join(BASE_MODEL_PATH, "isnet_weight_only_wi4_afp32.tflite")
# Remove a stale copy from a previous run before saving.
if os.path.exists(INT_WEIGHT_ONLY_MODEL_PATH):
  os.remove(INT_WEIGHT_ONLY_MODEL_PATH)

quantization_result.save(
    save_folder=BASE_MODEL_PATH, model_name='isnet_weight_only_wi4_afp32'
)

quantized_mask = run_segmentation(test_image, INT_WEIGHT_ONLY_MODEL_PATH)
draw_segementation(image, float_mask, quantized_mask, "int4 weight only")