# Model Compression Toolkit (MCT) Wrapper API: Comprehensive Quantization Comparison (TensorFlow)

[Run this tutorial in Google Colab](https://colab.research.google.com/github/SonySemiconductorSolutions/mct-model-optimization/blob/main/tutorials/notebooks/mct_features_notebooks/keras/example_keras_mct_wrapper.ipynb)

## Overview 
This notebook demonstrates the MCT (Model Compression Toolkit) Wrapper API by applying four quantization methods to a MobileNetV2 model and comparing their implementation and accuracy trade-offs: PTQ (Post-Training Quantization), PTQ with Mixed Precision, GPTQ (Gradient-based PTQ), and GPTQ with Mixed Precision. Each method uses the unified MCTWrapper interface, so the four configurations can be compared side by side.

## Summary
1. **Environment Setup**: Import required libraries and configure MCT with MobileNetV2 model
2. **Dataset Preparation**: Load and prepare ImageNet validation dataset with representative data generation
3. **PTQ Implementation**: Execute basic Post-Training Quantization with 8-bit precision and bias correction
4. **PTQ + Mixed Precision**: Apply intelligent bit-width allocation based on layer sensitivity analysis (75% compression ratio)
5. **GPTQ Implementation**: Perform gradient-based optimization with 5-epoch fine-tuning for enhanced accuracy
6. **GPTQ + Mixed Precision**: Combine gradient optimization with mixed precision for optimal accuracy-compression trade-off
7. **Performance Evaluation**: Comprehensive accuracy assessment and comparison across all quantization methods
8. **Results Analysis**: Compare model sizes, inference accuracy, and quantization trade-offs

## Setup

In [None]:
TF_VER = '2.14.0'
!pip install -q tensorflow~={TF_VER}

In [None]:
import importlib.util  # the util submodule must be imported explicitly
if not importlib.util.find_spec('model_compression_toolkit'):
    !pip install model_compression_toolkit

import model_compression_toolkit as mct
from model_compression_toolkit.core import QuantizationErrorMethod

In [None]:
import keras
import tensorflow as tf
from typing import Tuple, Callable, Generator, List, Any

Load a pre-trained MobileNetV2 model from Keras in 32-bit floating-point precision.

In [None]:
from keras.applications.mobilenet_v2 import MobileNetV2

float_model = MobileNetV2()

## Dataset preparation
Download only the validation split of the ImageNet dataset.

**Note** that for demonstration purposes we use the validation set for the model quantization routines. Usually, a subset of the training dataset is used, but loading it is a heavy procedure that is unnecessary for the sake of this demonstration.

This step may take several minutes...

In [None]:
import os
 
if not os.path.isdir('imagenet'):
    !mkdir imagenet
    !wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz
    !wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar
    
    !cd imagenet && tar -xzf ILSVRC2012_devkit_t12.tar.gz && \
     mkdir ILSVRC2012_img_val && tar -xf ILSVRC2012_img_val.tar -C ILSVRC2012_img_val

The following code organizes the extracted data into separate folders for each label, making it compatible with Keras dataset loaders.

In [None]:
from pathlib import Path
import shutil

root = Path('./imagenet')
imgs_dir = root / 'ILSVRC2012_img_val'
target_dir = root /'val'

def extract_labels():
    !pip install -q scipy
    import scipy.io  # the io submodule must be imported explicitly
    mat = scipy.io.loadmat(root / 'ILSVRC2012_devkit_t12/data/meta.mat', squeeze_me=True)
    # Map class index -> WordNet ID, keeping only synsets without children (the leaf classes)
    cls_to_nid = {s[0]: s[1] for s in mat['synsets'] if s[4] == 0}
    with open(root / 'ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt', 'r') as f:
        return [cls_to_nid[int(cls)] for cls in f.readlines()]

if not target_dir.exists():
    labels = extract_labels()
    for lbl in set(labels):
        os.makedirs(target_dir / lbl)
    
    for img_file, lbl in zip(sorted(os.listdir(imgs_dir)), labels):
        shutil.move(imgs_dir / img_file, target_dir / lbl)
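
To make the resulting folder layout concrete, here is a minimal, self-contained sketch of the same reorganization applied to a handful of empty dummy files. The file names, the label strings, and the helper name `organize_by_label` are illustrative assumptions and are not part of the ImageNet devkit.

```python
import shutil
import tempfile
from pathlib import Path
from typing import List


def organize_by_label(imgs_dir: Path, target_dir: Path, labels: List[str]) -> None:
    """Move the i-th file (in sorted order) into target_dir/<labels[i]>/."""
    files = sorted(p for p in imgs_dir.iterdir() if p.is_file())
    if len(files) != len(labels):
        raise ValueError("expected exactly one label per image file")
    for img_file, lbl in zip(files, labels):
        (target_dir / lbl).mkdir(parents=True, exist_ok=True)
        shutil.move(str(img_file), str(target_dir / lbl / img_file.name))


# Toy demonstration with dummy files and hypothetical labels.
with tempfile.TemporaryDirectory() as tmp:
    src, dst = Path(tmp) / 'imgs', Path(tmp) / 'val'
    src.mkdir()
    for name in ['img_0001.JPEG', 'img_0002.JPEG', 'img_0003.JPEG']:
        (src / name).write_bytes(b'')
    organize_by_label(src, dst, ['n001', 'n002', 'n001'])
    # Record the layout as {label: [file names]} before the temp dir is removed.
    layout = {d.name: sorted(f.name for f in d.iterdir()) for d in sorted(dst.iterdir())}

print(layout)
```

Each label ends up as one sub-directory holding its images, which is the structure that directory-based dataset loaders expect.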

These functions generate a `tf.data.Dataset` from image files in a directory.

In [None]:
def imagenet_preprocess_input(images: tf.Tensor, labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    return tf.keras.applications.mobilenet_v2.preprocess_input(images), labels

def get_dataset(batch_size: int, shuffle: bool) -> tf.data.Dataset:
    dataset = tf.keras.utils.image_dataset_from_directory(
        directory='./imagenet/val',
        batch_size=batch_size,
        image_size=[224, 224],
        shuffle=shuffle,
        crop_to_aspect_ratio=True,
        interpolation='bilinear')
    dataset = dataset.map(lambda x, y: (imagenet_preprocess_input(x, y)), num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

## Representative Dataset
MCT requires a representative dataset to calibrate the quantization parameters. It is supplied as a generator that yields a list of image batches, with one list entry per model input:

In [None]:
batch_size = 16
n_iter = 5

dataset = get_dataset(batch_size, shuffle=True)

def representative_dataset_gen() -> Generator[List[Any], None, None]:
    for _ in range(n_iter):
        yield [dataset.take(1).get_single_element()[0].numpy()]
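
The same generator contract can be written against any iterable of `(images, labels)` batches. The sketch below uses plain Python lists as stand-ins for image batches so that it runs without TensorFlow; `make_representative_dataset_gen` is a hypothetical helper name, not part of MCT.

```python
from itertools import islice
from typing import Any, Callable, Iterable, Iterator, List, Tuple


def make_representative_dataset_gen(
        batches: Iterable[Tuple[Any, Any]], n_iter: int) -> Callable[[], Iterator[List[Any]]]:
    """Return a generator function that yields [images] for the first n_iter batches."""
    def representative_dataset_gen() -> Iterator[List[Any]]:
        for images, _labels in islice(batches, n_iter):
            yield [images]  # one list entry per model input; labels are dropped
    return representative_dataset_gen


# Stand-in data: each "batch" is an (images, labels) pair with made-up values.
toy_batches = [([[0.1, 0.2]], [3]), ([[0.3, 0.4]], [7]), ([[0.5, 0.6]], [1])]
gen = make_representative_dataset_gen(toy_batches, n_iter=2)
calibration_inputs = list(gen())
print(len(calibration_inputs), calibration_inputs[0])
```

Passing the generator function itself (not its result) lets the consumer restart the iteration whenever it needs another pass over the calibration data.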

## Model Post-Training Quantization using MCTWrapper

In [None]:
# Decorator to provide consistent logging and error handling for quantization functions
def decorator(func: Callable[[keras.Model], Tuple[bool, keras.Model]]) -> Callable[[keras.Model], Tuple[bool, keras.Model]]:
    """
    Wrapper decorator that provides standardized execution logging and error handling.
    
    This decorator enhances quantization functions by:
    - Providing clear start/end execution markers for debugging
    - Handling success/failure status from quantization operations
    - Implementing fail-fast behavior on quantization errors
    - Ensuring consistent logging format across all quantization methods
    
    Usage:
        @decorator
        def quantization_function(model):
            # quantization implementation
            return flag, quantized_model
    
    Args:
        func: Function to be decorated (typically a quantization function)
    
    Returns:
        Wrapped function with enhanced logging and error handling capabilities
    """
    def wrapper(*args, **kwargs):
        # Log function execution start with clear delimiter
        print(f"----------------- {func.__name__} Start ---------------")
        
        # Execute the quantization function and capture return values
        # Expected return format: (success_flag, quantized_model)
        flag, result = func(*args, **kwargs)
        
        # Log function execution completion
        print(f"----------------- {func.__name__} End -----------------")
        
        # Implement fail-fast behavior: exit immediately on quantization failure
        # This ensures early detection of quantization issues
        if not flag:
            exit()
        
        # Return original function results if successful
        return flag, result
    
    return wrapper
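
One possible variation of this pattern, sketched below, keeps the wrapped function's name and docstring via `functools.wraps` and raises an exception instead of calling `exit()`, so that a failure can be caught and inspected by the caller. The names `logged_quantization` and `fake_quantize` are illustrative only.

```python
import functools
from typing import Any, Callable, Tuple


def logged_quantization(func: Callable[..., Tuple[bool, Any]]) -> Callable[..., Tuple[bool, Any]]:
    """Log the start/end of a quantization function and raise if it reports failure."""
    @functools.wraps(func)  # keep func.__name__ and the docstring on the wrapper
    def wrapper(*args, **kwargs):
        print(f"----------------- {func.__name__} Start ---------------")
        flag, result = func(*args, **kwargs)
        print(f"----------------- {func.__name__} End -----------------")
        if not flag:
            raise RuntimeError(f"{func.__name__} reported a quantization failure")
        return flag, result
    return wrapper


@logged_quantization
def fake_quantize(model: str) -> Tuple[bool, str]:
    """Stand-in for a real quantization function; always succeeds."""
    return True, f"quantized-{model}"


flag, result = fake_quantize("toy-model")
print(flag, result)
```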

Run PTQ (Post-Training Quantization) with Keras

In [None]:
@decorator
def PTQ_Keras(float_model: keras.Model) -> Tuple[bool, keras.Model]:
    """
    Perform Post-Training Quantization (PTQ) using MCT on Keras model.
    
    PTQ is a quantization method that:
    - Does not require model retraining
    - Uses representative data for calibration
    - Provides good accuracy with minimal computational overhead
    
    Args:
        float_model: Original floating-point Keras model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for basic PTQ quantization
    method = 'PTQ'                    # Post-Training Quantization method
    framework = 'tensorflow'          # Target framework (Keras/TensorFlow)
    use_internal_tpc = True                # Use MCT's built-in Target Platform Capabilities
    use_mixed_precision = False                  # Disable mixed-precision quantization

    # Parameter configuration for PTQ
    param_items = [
        ['target_platform_version', 'v1'],  # The version of the TPC to use.
        ['activation_error_method', QuantizationErrorMethod.MSE],  # Error metric for activation.
        ['weights_bias_correction', True],  # Enable bias correction for weights.
        ['z_threshold', float('inf')],  # Z-score threshold for activation outlier removal (inf disables it).
        ['linear_collapsing', True],  # Enable linear layer collapsing optimization.
        ['residual_collapsing', True],  # Enable residual connection collapsing.
        ['save_model_path', './qmodel_PTQ_Keras.keras']  # Path to save the quantized model.
    ]

    # Execute quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model=float_model, 
        representative_dataset=representative_dataset_gen,
        method=method, 
        framework=framework, 
        use_internal_tpc=use_internal_tpc, 
        use_mixed_precision=use_mixed_precision, 
        param_items=param_items)
    return flag, quantized_model
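
To give some intuition for 8-bit quantization and an MSE-based error method, the sketch below fake-quantizes a list of values with a symmetric uniform quantizer and picks, from a few candidate clipping thresholds, the one with the lowest mean squared error. It is a simplified illustration in plain Python, not MCT's implementation; the function names, the sample values, and the candidate thresholds are assumptions.

```python
from typing import List, Sequence


def quantize_symmetric(values: Sequence[float], threshold: float, n_bits: int = 8) -> List[float]:
    """Clip to the signed n-bit range, round to the nearest grid point, map back to float."""
    n_levels = 2 ** (n_bits - 1)      # 128 levels per sign for 8 bits
    scale = threshold / n_levels      # spacing between neighbouring grid points
    quantized = []
    for v in values:
        level = round(v / scale)                            # nearest integer level
        level = max(-n_levels, min(n_levels - 1, level))    # clip to the signed range
        quantized.append(level * scale)
    return quantized


def mean_squared_error(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def select_threshold_mse(values: Sequence[float], candidates: Sequence[float],
                         n_bits: int = 8) -> float:
    """Return the candidate threshold whose quantization error (MSE) is smallest."""
    return min(candidates,
               key=lambda t: mean_squared_error(values, quantize_symmetric(values, t, n_bits)))


activations = [-0.9, -0.3, 0.0, 0.2, 0.45, 0.7, 0.95]   # made-up calibration values
candidates = [0.5, 1.0, 2.0, 4.0]                        # hypothetical thresholds
best_threshold = select_threshold_mse(activations, candidates)
print(best_threshold)
```

A threshold that is too small clips the largest values, while one that is too large spreads the available levels over an unused range; the MSE criterion balances the two effects.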

Run PTQ + Mixed Precision Quantization with Keras

In [None]:
@decorator
def PTQ_Keras_mixed_precision(float_model: keras.Model) -> Tuple[bool, keras.Model]:
    """
    Perform Post-Training Quantization with Mixed Precision (PTQ + mixed_precision) on Keras model.
    
    Mixed Precision Quantization:
    - Uses different bit-widths for different layers
    - Optimizes model size while maintaining accuracy
    - Automatically selects optimal precision for each layer
    - Uses resource constraints to guide precision allocation
    
    Args:
        float_model: Original floating-point Keras model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for PTQ with mixed precision
    method = 'PTQ'                    # Post-Training Quantization method
    framework = 'tensorflow'          # Target framework (Keras/TensorFlow)
    use_internal_tpc = True                # Use MCT's built-in Target Platform Capabilities
    use_mixed_precision = True                   # Enable mixed-precision quantization

    # Parameter configuration for PTQ with Mixed Precision
    param_items = [
        ['target_platform_version', 'v1'],  # The version of the TPC to use.
        ['num_of_images', 5],  # Number of images used to evaluate mixed-precision sensitivity.
        ['use_hessian_based_scores', False],  # Use Hessian-based sensitivity scores for layer importance.
        ['weights_compression_ratio', 0.75],  # Target compression ratio for model weights (75% of original size).
        ['save_model_path', './qmodel_PTQ_Keras_mixed_precision.keras']  # Path to save the quantized model.
    ]

    # Execute mixed precision quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model=float_model, 
        representative_dataset=representative_dataset_gen,
        method=method, 
        framework=framework, 
        use_internal_tpc=use_internal_tpc, 
        use_mixed_precision=use_mixed_precision, 
        param_items=param_items)
    return flag, quantized_model
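
The idea of allocating bit-widths under a weights-memory budget can be sketched in a few lines. The example below is a toy greedy heuristic and not MCT's search algorithm: the layer names, parameter counts, and sensitivity scores are made up, and the budget is assumed to be the compression ratio times the all-8-bit weights size.

```python
from typing import Dict, List, Sequence, Tuple

Layer = Tuple[str, int, float]  # (name, number of weights, sensitivity score)


def weights_bytes(layers: List[Layer], bits: Dict[str, int]) -> float:
    """Total weights memory in bytes for a given per-layer bit-width assignment."""
    return sum(n_weights * bits[name] / 8 for name, n_weights, _ in layers)


def allocate_bits(layers: List[Layer], ratio: float,
                  choices: Sequence[int] = (8, 4, 2)) -> Dict[str, int]:
    """Greedily lower the bit-width of the least sensitive layers until the budget is met."""
    bits = {name: choices[0] for name, _, _ in layers}
    budget = ratio * weights_bytes(layers, bits)   # assumed baseline: all layers at 8 bits
    for name, _, _ in sorted(layers, key=lambda layer: layer[2]):  # least sensitive first
        for candidate in choices[1:]:
            if weights_bytes(layers, bits) <= budget:
                return bits
            bits[name] = candidate
    if weights_bytes(layers, bits) > budget:
        raise ValueError("budget cannot be met with the available bit-widths")
    return bits


# Hypothetical three-layer model.
toy_layers = [("conv1", 1000, 0.9), ("conv2", 4000, 0.2), ("dense", 3000, 0.5)]
assignment = allocate_bits(toy_layers, ratio=0.75)
print(assignment, weights_bytes(toy_layers, assignment))
```

In this toy case only the least sensitive layer drops to 4 bits, which is already enough to reach 75% of the 8-bit weights size.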

Run GPTQ (Gradient-based PTQ) with Keras

In [None]:
@decorator
def GPTQ_Keras(float_model: keras.Model) -> Tuple[bool, keras.Model]:
    """
    Perform Gradient-based Post-Training Quantization (GPTQ) on Keras model.
    
    GPTQ is an advanced quantization method that:
    - Uses gradient information to optimize quantization parameters
    - Fine-tunes the model during quantization process
    - Generally provides better accuracy than standard PTQ
    - Requires slightly more computational resources than PTQ
    
    Args:
        float_model: Original floating-point Keras model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for GPTQ quantization
    method = 'GPTQ'                   # Gradient-based Post-Training Quantization
    framework = 'tensorflow'          # Target framework (Keras/TensorFlow)
    use_internal_tpc = True                # Use MCT's built-in Target Platform Capabilities
    use_mixed_precision = False                  # Disable mixed-precision quantization

    # Parameter configuration for GPTQ
    param_items = [
        # Platform configuration
        ['target_platform_version', 'v1'],  # The version of the TPC to use.
        ['n_epochs', 5],  # Number of epochs for gradient-based fine-tuning.
        ['optimizer', None],  # Optimizer to use during fine-tuning.
        ['save_model_path', './qmodel_GPTQ_Keras.keras']  # Path to save the quantized model.
    ]

    # Execute GPTQ quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model=float_model, 
        representative_dataset=representative_dataset_gen,
        method=method, 
        framework=framework, 
        use_internal_tpc=use_internal_tpc, 
        use_mixed_precision=use_mixed_precision, 
        param_items=param_items)
    return flag, quantized_model
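
The core idea of gradient-based quantization can be illustrated with a toy example: the weights of a single linear unit are fine-tuned so that its quantized output matches the float output on calibration samples, with gradients passed through the rounding step by a straight-through estimator. This is a conceptual sketch in plain Python under simplifying assumptions (integer weight grid, hand-written gradient, made-up data); it is not MCT's GPTQ implementation.

```python
from typing import List, Sequence, Tuple

Sample = Tuple[float, ...]


def quantize(weights: Sequence[float]) -> List[int]:
    """Round each weight to the nearest integer (a grid with step 1 for simplicity)."""
    return [round(w) for w in weights]


def dot(weights: Sequence[float], x: Sample) -> float:
    return sum(w * xi for w, xi in zip(weights, x))


def output_error(q_weights: Sequence[float], float_weights: Sequence[float],
                 samples: List[Sample]) -> float:
    """Sum of squared differences between quantized and float outputs."""
    return sum((dot(q_weights, x) - dot(float_weights, x)) ** 2 for x in samples)


def finetune_rounding(float_weights: Sequence[float], samples: List[Sample],
                      n_steps: int = 20, lr: float = 0.01) -> Tuple[List[int], float]:
    """Gradient descent on latent weights; keep the best quantized weights seen."""
    trainable = list(float_weights)
    best_q = quantize(trainable)
    best_err = output_error(best_q, float_weights, samples)
    for _ in range(n_steps):
        q = quantize(trainable)
        # Gradient of the loss w.r.t. each weight, treating round() as identity (STE).
        grads = [0.0] * len(trainable)
        for x in samples:
            diff = dot(q, x) - dot(float_weights, x)
            for i, xi in enumerate(x):
                grads[i] += 2.0 * diff * xi
        trainable = [w - lr * g for w, g in zip(trainable, grads)]
        q = quantize(trainable)
        err = output_error(q, float_weights, samples)
        if err < best_err:
            best_q, best_err = q, err
    return best_q, best_err


float_weights = [0.4, 0.3]
samples = [(1.0, 1.0), (2.0, 2.0), (-1.0, -1.0)]   # perfectly correlated inputs
nearest_err = output_error(quantize(float_weights), float_weights, samples)
tuned_q, tuned_err = finetune_rounding(float_weights, samples)
print(nearest_err, tuned_q, tuned_err)
```

Because the two inputs are correlated, rounding both weights down loses more than rounding one of them up; the fine-tuning loop finds the latter solution, which plain round-to-nearest cannot.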

Run GPTQ + Mixed Precision Quantization with Keras

In [None]:
@decorator
def GPTQ_Keras_mixed_precision(float_model: keras.Model) -> Tuple[bool, keras.Model]:
    """
    Perform Gradient-based Post-Training Quantization with Mixed Precision (GPTQ + mixed_precision).
    
    This combines the benefits of both techniques:
    - GPTQ: Gradient-based optimization for better quantization accuracy
    - Mixed Precision: Optimal bit-width allocation for size/accuracy trade-off
    
    This is the most involved configuration in this tutorial. It aims for:
    - Better accuracy preservation than mixed precision with plain PTQ
    - Model size reduction within the weights compression target
    - Automatic precision selection per layer
    
    Args:
        float_model: Original floating-point Keras model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for GPTQ with mixed precision
    method = 'GPTQ'                   # Gradient-based Post-Training Quantization
    framework = 'tensorflow'          # Target framework (Keras/TensorFlow)
    use_internal_tpc = True                # Use MCT's built-in Target Platform Capabilities
    use_mixed_precision = True                   # Enable mixed-precision quantization

    # Parameter configuration for GPTQ with Mixed Precision
    param_items = [
        # Platform configuration
        ['target_platform_version', 'v1'],  # The version of the TPC to use.
        ['n_epochs', 5],  # Number of epochs for gradient-based fine-tuning.
        ['optimizer', None],  # Optimizer to use during fine-tuning.
        ['num_of_images', 5],  # Number of images used to evaluate mixed-precision sensitivity.
        ['use_hessian_based_scores', False],  # Whether to use Hessian-based scores for layer importance.
        ['weights_compression_ratio', 0.75],  # Compression ratio for weights.
        ['save_model_path', './qmodel_GPTQ_Keras_mixed_precision.keras']  # Path to save the quantized model.
    ]

    # Execute advanced GPTQ with mixed precision using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model=float_model, 
        representative_dataset=representative_dataset_gen,
        method=method, 
        framework=framework, 
        use_internal_tpc=use_internal_tpc, 
        use_mixed_precision=use_mixed_precision, 
        param_items=param_items)
    return flag, quantized_model

### Run model Post-Training Quantization
Next, we quantize the model with each of the four methods using the MCTWrapper API.

In [None]:
# Execute comprehensive quantization method comparison using MCT Wrapper functionality
# Each method represents different trade-offs between accuracy, model size, and computation time
print("Starting quantization experiments with different methods...")

In [None]:
# Basic Post-Training Quantization (PTQ)
# - Standard 8-bit quantization without advanced optimization techniques
flag, quantized_model_ptq = PTQ_Keras(float_model)

In [None]:
# PTQ with Mixed Precision Quantization
# - Uses different bit-widths for different layers based on sensitivity analysis
flag, quantized_model_ptq_mixed_precision = PTQ_Keras_mixed_precision(float_model)

In [None]:
# Gradient-based Post-Training Quantization (GPTQ)
# - Uses gradient information to fine-tune quantization parameters during conversion
flag, quantized_model_gptq = GPTQ_Keras(float_model)

In [None]:
# GPTQ with Mixed Precision Quantization
# - Combines gradient-based optimization with mixed precision techniques
flag, quantized_model_gptq_mixed_precision = GPTQ_Keras_mixed_precision(float_model)

In [None]:
print("All quantization methods completed successfully!")

## Models evaluation
In order to evaluate our models, we first need to load the validation dataset. As before, please ensure that the dataset path has been set correctly.

In [None]:
# Model Evaluation and Accuracy Comparison
print("Starting model evaluation phase...")

# Prepare validation dataset for accuracy assessment
val_dataset = get_dataset(batch_size=50, shuffle=False)

In [None]:
# Evaluate original floating-point model accuracy
print("\n=== Original Model Evaluation ===")
float_model.compile(loss=keras.losses.SparseCategoricalCrossentropy(), metrics="accuracy")
float_accuracy = float_model.evaluate(val_dataset)
print(f"Float model's Top 1 accuracy on the Imagenet validation set: {(float_accuracy[1] * 100):.2f}%")
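
For reference, top-1 accuracy is the fraction of samples whose highest-scoring class matches the ground-truth label. The following sketch computes it in plain Python on made-up scores; `top1_accuracy` is an illustrative helper and not a Keras function.

```python
from typing import List, Sequence


def top1_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of rows whose arg-max class index equals the integer label."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    correct = 0
    for row, label in zip(scores, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # index of the largest score
        correct += int(predicted == label)
    return correct / len(labels)


# Made-up class scores for four samples and three classes.
toy_scores: List[List[float]] = [
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
    [0.3, 0.3, 0.4],
    [0.2, 0.5, 0.3],
]
toy_labels = [1, 0, 1, 1]
acc = top1_accuracy(toy_scores, toy_labels)
print(f"Top-1 accuracy: {acc * 100:.2f}%")
```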

In [None]:
# Evaluate PTQ quantized model accuracy
print("\n=== PTQ Model Evaluation ===")
quantized_model_ptq.compile(loss=keras.losses.SparseCategoricalCrossentropy(), metrics="accuracy")
quantized_accuracy = quantized_model_ptq.evaluate(val_dataset)
print(f"PTQ_Keras Quantized model's Top 1 accuracy on the Imagenet validation set: {(quantized_accuracy[1] * 100):.2f}%")

In [None]:
# Evaluate PTQ + Mixed Precision model accuracy
print("\n=== PTQ + Mixed Precision Model Evaluation ===")
quantized_model_ptq_mixed_precision.compile(loss=keras.losses.SparseCategoricalCrossentropy(), metrics="accuracy")
quantized_accuracy = quantized_model_ptq_mixed_precision.evaluate(val_dataset)
print(f"PTQ_Keras_mixed_precision Quantized model's Top 1 accuracy on the Imagenet validation set: {(quantized_accuracy[1] * 100):.2f}%")

In [None]:
# Evaluate GPTQ quantized model accuracy
print("\n=== GPTQ Model Evaluation ===")
quantized_model_gptq.compile(loss=keras.losses.SparseCategoricalCrossentropy(), metrics="accuracy")
quantized_accuracy = quantized_model_gptq.evaluate(val_dataset)
print(f"GPTQ_Keras Quantized model's Top 1 accuracy on the Imagenet validation set: {(quantized_accuracy[1] * 100):.2f}%")

In [None]:
# Evaluate GPTQ + Mixed Precision model accuracy
print("\n=== GPTQ + Mixed Precision Model Evaluation ===")
quantized_model_gptq_mixed_precision.compile(loss=keras.losses.SparseCategoricalCrossentropy(), metrics="accuracy")
quantized_accuracy = quantized_model_gptq_mixed_precision.evaluate(val_dataset)
print(f"GPTQ_Keras_mixed_precision Quantized model's Top 1 accuracy on the Imagenet validation set: {(quantized_accuracy[1] * 100):.2f}%")

In [None]:
print("finish")

## Conclusion

In this tutorial, we demonstrated how to quantize a pre-trained model using MCTWrapper with a few lines of code.

MCT can deliver competitive results across a wide range of tasks and network architectures. For more details, check out [the paper](https://arxiv.org/abs/2109.09113).

## Copyrights

Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
