# Model Compression Toolkit (MCT) Wrapper API: Comprehensive Quantization Comparison (PyTorch)

[Run this tutorial in Google Colab](https://colab.research.google.com/github/SonySemiconductorSolutions/mct-model-optimization/blob/main/tutorials/notebooks/mct_features_notebooks/pytorch/example_pytorch_mct_wrapper.ipynb)

## Overview 
This notebook demonstrates the MCT (Model Compression Toolkit) Wrapper API by applying four quantization methods to a MobileNetV2 model: PTQ (Post-Training Quantization), PTQ with Mixed Precision, GPTQ (Gradient-based PTQ), and GPTQ with Mixed Precision. It compares how each method is configured and how much Top-1 accuracy it retains relative to the floating-point model. All four methods run through the same `MCTWrapper` interface, which keeps the implementations consistent and directly comparable.

## Summary
1. **Environment Setup**: Install and import the required libraries, and load a pre-trained MobileNetV2 model
2. **Dataset Preparation**: Load the ImageNet validation set and build a representative dataset generator
3. **PTQ Implementation**: Run basic Post-Training Quantization with 8-bit precision and bias correction
4. **PTQ + Mixed Precision**: Allocate bit-widths per layer based on a sensitivity analysis, targeting a weights compression ratio of 0.5
5. **GPTQ Implementation**: Run gradient-based optimization with 5 epochs of fine-tuning to improve accuracy
6. **GPTQ + Mixed Precision**: Combine gradient-based optimization with mixed precision to balance accuracy and compression
7. **Performance Evaluation**: Measure the Top-1 accuracy of the float model and of every quantized model on the validation set
8. **Results Analysis**: Compare the accuracy of each quantization method against the float baseline

## Setup

In [None]:
!pip install onnx==1.16.1
!pip install -q torch==2.6.0 torchvision==0.21.0
!pip install tqdm
from typing import Tuple, Callable
from tqdm import tqdm

In [None]:
import importlib.util
if not importlib.util.find_spec('model_compression_toolkit'):
    !pip install model_compression_toolkit

# Import Model Compression Toolkit (MCT) core functionality for PyTorch
import model_compression_toolkit as mct
from model_compression_toolkit.core import QuantizationErrorMethod

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
from torchvision.datasets import ImageNet

Load a pre-trained MobileNetV2 model from torchvision in 32-bit floating-point precision.

In [None]:
weights = MobileNet_V2_Weights.IMAGENET1K_V2

float_model = mobilenet_v2(weights=weights)
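
Because weight storage scales linearly with bit-width, a quick back-of-the-envelope calculation shows what moving from 32-bit floats to lower bit-widths can save. The sketch below uses a hypothetical helper, `weight_memory_mb`, and an assumed round parameter count of 3.5 million; it ignores quantization metadata such as per-channel scales, so real exported files will differ somewhat.

```python
def weight_memory_mb(num_params: int, bits: int) -> float:
    """Approximate weight storage in megabytes (1 MB = 10**6 bytes) at a given bit-width."""
    return num_params * bits / 8 / 1e6

# Assumed round parameter count for illustration; not measured from the model
num_params = 3_500_000
for bits in (32, 8, 4):
    print(f"{bits:>2}-bit weights: {weight_memory_mb(num_params, bits):.2f} MB")
```

To use the real parameter count of the loaded model, replace `num_params` with `sum(p.numel() for p in float_model.parameters())`.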

## Dataset preparation
### Download ImageNet validation set
Download only the validation split of the ImageNet dataset, together with the development kit (devkit) that torchvision uses to parse it.

**Note** that for demonstration purposes we use the validation set for the model quantization routines. Usually, a subset of the training dataset is used, but loading it is a heavy procedure that is unnecessary for the sake of this demonstration.

This step may take several minutes...

In [None]:
import os

if not os.path.isdir('imagenet'):
    !mkdir imagenet
    !wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz
    !wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar

Extract and load the ImageNet validation set with torchvision's `ImageNet` dataset class, applying the preprocessing transforms that come with the pre-trained weights.

In [None]:
dataset = ImageNet(root='./imagenet', split='val', transform=weights.transforms())

## Representative Dataset
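MCT needs a representative dataset to calibrate the quantization parameters for every method in this notebook (PTQ and GPTQ, with or without mixed precision). It is passed as a generator function: each call to `representative_dataset_gen` yields `n_iter` batches, and each batch is a list containing one tensor of `batch_size` images: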

In [None]:
batch_size = 16
n_iter = 10

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

def representative_dataset_gen():
    dataloader_iter = iter(dataloader)
    for _ in range(n_iter):
        yield [next(dataloader_iter)[0]]
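
`representative_dataset_gen` above relies on the global `dataloader` and `n_iter`. A variation is to build the generator function from a factory, which makes the batch count explicit and avoids a `RuntimeError` if the loader runs out of batches early (inside a generator, an uncaught `StopIteration` raised by `next()` is converted to `RuntimeError`). The sketch below uses a hypothetical factory, `make_representative_dataset`, and a plain list in place of a `DataLoader` so it runs without PyTorch. It also shows that every call restarts iteration, which matters if the quantization routine calls the generator function more than once.

```python
from itertools import islice
from typing import Any, Callable, Iterable, Iterator, List, Tuple

def make_representative_dataset(
    loader: Iterable[Tuple[Any, Any]], n_iter: int
) -> Callable[[], Iterator[List[Any]]]:
    """Return a zero-argument generator function that yields up to n_iter batches.

    Each yielded item is a list holding only the images of one (images, labels) batch.
    islice calls iter(loader) on every invocation, so each call starts a fresh pass,
    and it stops early instead of raising if the loader has fewer than n_iter batches.
    """
    def gen() -> Iterator[List[Any]]:
        for images, _labels in islice(loader, n_iter):
            yield [images]
    return gen

# Toy stand-in for a DataLoader: ten (images, labels) batches as plain lists
toy_loader = [([i, i + 1], [0, 1]) for i in range(0, 20, 2)]
rep_gen = make_representative_dataset(toy_loader, n_iter=3)
print(list(rep_gen()))
print(list(rep_gen()))  # a second call starts again from the first batch
```

With the notebook's objects, the equivalent would be `representative_dataset_gen = make_representative_dataset(dataloader, n_iter)`.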

## Model Evaluation Function
Define an evaluation function that computes the Top-1 accuracy of a PyTorch model on the validation set, running on a GPU when one is available.

In [None]:
def evaluate(model: torch.nn.Module, testloader: DataLoader, mode: str) -> float:
    """
    Evaluate PyTorch model accuracy using a DataLoader with GPU acceleration.
    
    This function performs complete accuracy evaluation by:
    - Moving model and data to available device (GPU/CPU)
    - Running inference in evaluation mode (no gradient computation)
    - Computing Top-1 accuracy across the entire validation set
    - Providing progress tracking during evaluation
    
    Args:
        model: PyTorch model to evaluate (float or quantized)
        testloader: DataLoader containing validation dataset
        mode: String identifier for logging (e.g., 'Float', 'PTQ_Pytorch')
    
    Returns:
        float: Top-1 accuracy percentage
    """
    # Determine best available device for inference
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    
    # Perform inference without gradient computation for efficiency
    with torch.no_grad():
        for data in tqdm(testloader):
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            
            # Forward pass to get predictions
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    # Calculate and display accuracy
    val_acc = (100 * correct / total)
    print(mode + ' Accuracy: %.2f%%' % val_acc)
    return val_acc

## Model Post-Training Quantization using MCTWrapper

In [None]:
import functools

# Decorator to provide consistent logging and error handling for quantization functions
def decorator(func: Callable[[torch.nn.Module], Tuple[bool, torch.nn.Module]]) -> Callable[[torch.nn.Module], Tuple[bool, torch.nn.Module]]:
    """
    Wrapper decorator that provides:
    - Consistent start/end logging for quantization operations
    - An exception on failure, so a failed run stops execution without killing the kernel
    - Success/failure status tracking for all quantization methods
    
    Args:
        func: Quantization function to be decorated
    
    Returns:
        Wrapped function with enhanced logging and error handling

    Raises:
        RuntimeError: If the wrapped function reports failure (flag is False)
    """
    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        print(f"----------------- {func.__name__} Start ---------------")
        flag, result = func(*args, **kwargs)
        print(f"----------------- {func.__name__} End -----------------")
        if not flag:
            # exit() would terminate the notebook kernel; raising keeps the session alive
            raise RuntimeError(f"{func.__name__} failed")
        return flag, result
    return wrapper

Run PTQ (Post-Training Quantization) with PyTorch
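
The PTQ configuration in the cell below sets `weights_bias_correction` to `True`. The general idea behind bias correction is that quantizing the weights shifts the mean of a layer's output, and that shift can be estimated from calibration data and subtracted from the bias: `b' = b - (W_q - W) E[x]`. The pure-Python sketch below shows this on a toy 2x2 linear layer. The helper names and values are hypothetical, and MCT's own implementation may compute the correction differently (for example, from collected activation statistics).

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]

def matvec(w: Matrix, x: Vector) -> Vector:
    """Compute W @ x for a small dense matrix stored as nested lists."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

def mean_vector(xs: List[Vector]) -> Vector:
    """Element-wise mean over a list of vectors."""
    return [sum(col) / len(xs) for col in zip(*xs)]

def quantize_weights(w: Matrix, step: float) -> Matrix:
    """Round every weight to the nearest multiple of `step` (uniform quantization grid)."""
    return [[round(v / step) * step for v in row] for row in w]

def corrected_bias(w: Matrix, w_q: Matrix, b: Vector, xs: List[Vector]) -> Vector:
    """b' = b - (W_q - W) E[x]: cancels the mean output shift caused by quantizing W."""
    delta = [[q - f for q, f in zip(row_q, row_f)] for row_q, row_f in zip(w_q, w)]
    shift = matvec(delta, mean_vector(xs))
    return [bi - si for bi, si in zip(b, shift)]

def mean_output(w: Matrix, b: Vector, xs: List[Vector]) -> Vector:
    """Mean of W x + b over the calibration inputs."""
    return mean_vector([[yi + bi for yi, bi in zip(matvec(w, x), b)] for x in xs])

# Toy layer and calibration inputs (illustrative values only)
w = [[0.31, -0.52], [0.77, 0.14]]
b = [0.1, -0.2]
xs = [[1.0, 2.0], [3.0, 0.5], [2.0, 1.5]]

w_q = quantize_weights(w, step=0.25)
b_corr = corrected_bias(w, w_q, b, xs)
print("float mean output:          ", mean_output(w, b, xs))
print("quantized, original bias:   ", mean_output(w_q, b, xs))
print("quantized, corrected bias:  ", mean_output(w_q, b_corr, xs))
```

The corrected bias restores the float model's mean output on the calibration inputs; it does not remove the per-sample error, which is what the more expensive methods later in this notebook target.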

In [None]:
@decorator
def PTQ_Pytorch(float_model: torch.nn.Module) -> Tuple[bool, torch.nn.Module]:
    """
    Perform Post-Training Quantization (PTQ) on PyTorch model.
    
    PTQ for PyTorch provides:
    - Fast quantization without model retraining
    - Standard 8-bit integer quantization
    - Efficient calibration using representative data
    - Direct ONNX export for deployment
    
    Args:
        float_model: Original floating-point PyTorch model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for PyTorch PTQ quantization
    framework = 'pytorch'             # Target framework (PyTorch)
    method = 'PTQ'                    # Post-Training Quantization method
    use_mixed_precision = False       # Disable mixed-precision quantization

    # Parameter configuration for PyTorch PTQ
    param_items = [
        ['sdsp_version', '3.14'],  # The version of the SDSP converter.
        ['activation_error_method', QuantizationErrorMethod.MSE],  # Error metric used to select activation quantization thresholds
        ['weights_bias_correction', True],  # Enable bias correction for weights
        ['z_threshold', float('inf')],  # Z-score threshold for outlier removal (inf disables it)
        ['linear_collapsing', True],  # Enable linear layer collapsing optimization
        ['residual_collapsing', True],  # Enable residual layer collapsing optimization
        ['save_model_path', './qmodel_PTQ_Pytorch.onnx']  # Path to save quantized model as ONNX.
    ]

    # Execute PyTorch PTQ quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model=float_model, 
        representative_dataset=representative_dataset_gen, 
        framework=framework, 
        method=method, 
        use_mixed_precision=use_mixed_precision, 
        param_items=param_items)
    return flag, quantized_model
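
The PTQ configuration above selects activation thresholds with `QuantizationErrorMethod.MSE` and sets `z_threshold` to infinity. As a rough illustration of those two settings, the sketch below removes outliers by z-score (a threshold of `inf` keeps every value) and then grid-searches the clipping threshold that minimizes the mean squared error of symmetric uniform fake quantization. It is a simplified stand-in written in plain Python, not MCT's implementation; the function names and the 100-point grid are assumptions.

```python
import math
from statistics import mean, pstdev
from typing import List, Sequence

def remove_outliers(values: Sequence[float], z_threshold: float) -> List[float]:
    """Drop values whose |z-score| exceeds z_threshold; z_threshold=inf keeps everything."""
    if math.isinf(z_threshold):
        return list(values)
    mu, sigma = mean(values), pstdev(values)
    if sigma == 0:
        return list(values)
    return [v for v in values if abs(v - mu) / sigma <= z_threshold]

def fake_quantize(v: float, threshold: float, n_bits: int = 8) -> float:
    """Symmetric signed uniform quantization followed by dequantization."""
    q_max = 2 ** (n_bits - 1) - 1  # e.g. 127 for 8 bits
    scale = threshold / q_max
    q = max(-q_max - 1, min(q_max, round(v / scale)))  # clip to the signed integer range
    return q * scale

def mse_threshold(values: Sequence[float], n_bits: int = 8, n_candidates: int = 100) -> float:
    """Grid-search the clipping threshold that minimizes the quantization MSE."""
    max_abs = max(abs(v) for v in values)
    if max_abs == 0:
        return 0.0
    best_t, best_err = max_abs, float("inf")
    for k in range(1, n_candidates + 1):
        t = max_abs * k / n_candidates
        err = mean((v - fake_quantize(v, t, n_bits)) ** 2 for v in values)
        if err < best_err:
            best_t, best_err = t, err
    return best_t

# Toy activations: a dense range around zero plus one large value
acts = [0.05 * i for i in range(-100, 101)] + [30.0]
kept = remove_outliers(acts, z_threshold=float("inf"))  # inf, as in the PTQ config: nothing removed
t = mse_threshold(kept, n_bits=4)
print(f"max |x| = {max(abs(v) for v in kept):.2f}, MSE-optimal threshold = {t:.2f}")
```

Clipping below the maximum trades a large error on a few outliers for a finer quantization step on the bulk of the values; the MSE criterion decides where that balance lies.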

Run PTQ + Mixed Precision Quantization with PyTorch

In [None]:
@decorator
def PTQ_Pytorch_mixed_precision(float_model: torch.nn.Module) -> Tuple[bool, torch.nn.Module]:
    """
    Perform Post-Training Quantization with Mixed Precision (PTQ + mixed_precision) on PyTorch model.
    
    Mixed Precision PTQ for PyTorch offers:
    - Automatic bit-width selection per layer
    - A size/accuracy trade-off driven by a target weights compression ratio
    - Sensitivity analysis that estimates how strongly each layer's bit-width affects the output
    
    Args:
        float_model: Original floating-point PyTorch model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for PyTorch PTQ with mixed precision
    framework = 'pytorch'             # Target framework (PyTorch)
    method = 'PTQ'                    # Post-Training Quantization method
    use_mixed_precision = True        # Enable mixed-precision quantization

    # Parameter configuration for PyTorch PTQ with Mixed Precision
    param_items = [
        ['sdsp_version', '3.14'],  # The version of the SDSP converter.
        ['num_of_images', 5],  # Number of images used to evaluate layer sensitivity
        ['use_hessian_based_scores', False],  # Whether to use Hessian-based sensitivity scores (disabled here)
        ['weights_compression_ratio', 0.5],  # Target compression ratio for weights memory
        ['save_model_path', './qmodel_PTQ_Pytorch_mixed_precision.onnx']  # Path to save quantized model as ONNX.
    ]

    # Execute PyTorch mixed precision PTQ using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model=float_model, 
        representative_dataset=representative_dataset_gen, 
        framework=framework, 
        method=method, 
        use_mixed_precision=use_mixed_precision, 
        param_items=param_items)
    return flag, quantized_model
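
Mixed precision trades accuracy for size by giving less sensitive layers fewer bits. The greedy sketch below illustrates that trade-off under stated assumptions: the per-layer sensitivity values and parameter counts are made-up toy numbers, `compression_ratio` is interpreted here as a fraction of the memory needed at the highest candidate bit-width, and the function name `allocate_bits` is hypothetical. MCT's actual mixed-precision search and its exact definition of `weights_compression_ratio` may differ.

```python
from typing import Dict, List

def allocate_bits(
    num_params: Dict[str, int],
    sensitivity: Dict[str, Dict[int, float]],
    candidate_bits: List[int],
    compression_ratio: float,
) -> Dict[str, int]:
    """Greedy sketch: start every layer at the highest bit-width, then repeatedly lower the
    layer whose next step down adds the least sensitivity per bit saved, until total weight
    memory fits within compression_ratio * (memory at the highest bit-width)."""
    bits_desc = sorted(candidate_bits, reverse=True)
    alloc = {name: bits_desc[0] for name in num_params}
    budget = compression_ratio * sum(n * bits_desc[0] for n in num_params.values())

    def memory() -> float:
        return sum(num_params[name] * alloc[name] for name in alloc)

    while memory() > budget:
        best_name, best_cost = None, float("inf")
        for name, bits in alloc.items():
            idx = bits_desc.index(bits)
            if idx + 1 == len(bits_desc):
                continue  # already at the lowest candidate bit-width
            lower = bits_desc[idx + 1]
            saved = num_params[name] * (bits - lower)
            cost = (sensitivity[name][lower] - sensitivity[name][bits]) / saved
            if cost < best_cost:
                best_name, best_cost = name, cost
        if best_name is None:
            raise ValueError("budget cannot be met with the given candidate bit-widths")
        alloc[best_name] = bits_desc[bits_desc.index(alloc[best_name]) + 1]
    return alloc

# Toy layers: parameter counts and sensitivity (higher = more accuracy loss) per bit-width
num_params = {"conv1": 1000, "conv2": 4000, "fc": 5000}
sensitivity = {
    "conv1": {8: 0.0, 4: 0.50, 2: 2.0},
    "conv2": {8: 0.0, 4: 0.05, 2: 0.8},
    "fc": {8: 0.0, 4: 0.02, 2: 0.3},
}
print(allocate_bits(num_params, sensitivity, candidate_bits=[8, 4, 2], compression_ratio=0.5))
```

The small, sensitive layer keeps 8 bits while the large, robust layers absorb most of the compression, which is the behavior the sensitivity analysis is meant to enable.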

Run GPTQ (Gradient-based PTQ) with PyTorch

In [None]:
@decorator
def GPTQ_Pytorch(float_model: torch.nn.Module) -> Tuple[bool, torch.nn.Module]:
    """
    Perform Gradient-based Post-Training Quantization (GPTQ) on PyTorch model.
    
    GPTQ for PyTorch provides:
    - Gradient-based optimization of the quantized weights
    - Fine-tuning on representative data during quantization
    - Typically better accuracy preservation than PTQ, at the cost of longer quantization time
    
    Args:
        float_model: Original floating-point PyTorch model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for PyTorch GPTQ quantization
    framework = 'pytorch'             # Target framework (PyTorch)
    method = 'GPTQ'                   # Gradient-based Post-Training Quantization
    use_mixed_precision = False       # Disable mixed-precision quantization

    # Parameter configuration for PyTorch GPTQ
    param_items = [
        ['sdsp_version', '3.14'],  # The version of the SDSP converter.
        ['n_epochs', 5],  # Number of epochs for gradient fine-tuning
        ['optimizer', None],  # Optimizer (None = use default)
        ['save_model_path', './qmodel_GPTQ_Pytorch.onnx']  # Path to save quantized model as ONNX.
    ]

    # Execute PyTorch GPTQ quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model=float_model, 
        representative_dataset=representative_dataset_gen, 
        framework=framework, 
        method=method, 
        use_mixed_precision=use_mixed_precision, 
        param_items=param_items)
    return flag, quantized_model
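
GPTQ improves on plain rounding by choosing quantized weights that reproduce the layer's *output* on representative data, rather than rounding each weight independently to its nearest grid point. MCT's GPTQ pursues a similar goal through gradient-based fine-tuning over `n_epochs`; the toy sketch below instead brute-forces the round-down/round-up choice for a three-weight layer. That is only feasible for a handful of weights, but it makes the objective visible. The function names, weights, and inputs are hypothetical.

```python
import itertools
import math
from typing import List, Sequence, Tuple

def output_mse(w: Sequence[float], w_ref: Sequence[float], xs: Sequence[Sequence[float]]) -> float:
    """Mean squared difference between the outputs w.x and w_ref.x over the inputs xs."""
    diffs = [sum((a - r) * x for a, r, x in zip(w, w_ref, xi)) for xi in xs]
    return sum(d * d for d in diffs) / len(diffs)

def round_to_nearest(w: Sequence[float], step: float) -> List[float]:
    """Baseline: round each weight independently to the nearest grid point."""
    return [round(v / step) * step for v in w]

def best_rounding(
    w: Sequence[float], xs: Sequence[Sequence[float]], step: float
) -> Tuple[List[float], float]:
    """Try every floor/ceil combination and keep the one with the lowest output error."""
    floors = [math.floor(v / step) for v in w]
    best: List[float] = []
    best_err = float("inf")
    for ups in itertools.product((0, 1), repeat=len(w)):
        cand = [(f + u) * step for f, u in zip(floors, ups)]
        err = output_mse(cand, w, xs)
        if err < best_err:
            best, best_err = cand, err
    return best, best_err

# Toy layer weights, quantization step, and representative inputs
w = [0.37, -0.62, 0.18]
step = 0.25
xs = [[1.0, 0.5, 2.0], [0.2, 1.5, -1.0], [2.0, 1.0, 0.5], [1.2, -0.3, 0.8]]

w_rtn = round_to_nearest(w, step)
w_opt, err_opt = best_rounding(w, xs, step)
print("nearest rounding:     ", w_rtn, "output MSE:", output_mse(w_rtn, w, xs))
print("output-aware rounding:", w_opt, "output MSE:", err_opt)
```

Because nearest rounding is one of the candidates searched, the output-aware choice can never be worse on the calibration inputs; the errors of individual weights are allowed to grow if they cancel out at the layer output.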

Run GPTQ + Mixed Precision Quantization with PyTorch

In [None]:
@decorator
def GPTQ_Pytorch_mixed_precision(float_model: torch.nn.Module) -> Tuple[bool, torch.nn.Module]:
    """
    Perform Gradient-based Post-Training Quantization with Mixed Precision (GPTQ + mixed_precision).
    
    This method combines:
    - GPTQ: Gradient-based optimization of the quantization parameters
    - Mixed Precision: Automatic bit-width selection for each layer
    
    Combining the two aims at a strong accuracy/size trade-off for PyTorch models:
    - Accuracy recovery through gradient-based fine-tuning
    - Model size reduction through lower bit-widths on less sensitive layers
    
    Args:
        float_model: Original floating-point PyTorch model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for PyTorch GPTQ with mixed precision
    framework = 'pytorch'             # Target framework (PyTorch)
    method = 'GPTQ'                   # Gradient-based Post-Training Quantization
    use_mixed_precision = True        # Enable mixed-precision quantization

    # Parameter configuration for PyTorch GPTQ with Mixed Precision
    param_items = [
        ['sdsp_version', '3.14'],  # The version of the SDSP converter.
        ['n_epochs', 5],  # Number of epochs for gradient fine-tuning
        ['optimizer', None],  # Optimizer (None = use default)
        ['num_of_images', 5],  # Number of images used to evaluate layer sensitivity
        ['use_hessian_based_scores', False],  # Whether to use Hessian-based sensitivity scores (disabled here)
        ['weights_compression_ratio', 0.5],  # Target compression ratio for weights memory
        ['save_model_path', './qmodel_GPTQ_Pytorch_mixed_precision.onnx']  # Path to save quantized model as ONNX.
    ]

    # Execute advanced PyTorch GPTQ+mixed_precision quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model=float_model, 
        representative_dataset=representative_dataset_gen, 
        framework=framework, 
        method=method, 
        use_mixed_precision=use_mixed_precision, 
        param_items=param_items)
    return flag, quantized_model

### Run model Post-Training Quantization
Next, we quantize the float model with each of the four methods using the MCTWrapper API.

In [None]:
# Create DataLoader for validation/evaluation with larger batch size for efficiency
val_dataloader = DataLoader(dataset, batch_size=50, shuffle=False)

In [None]:
# Execute all PyTorch quantization methods on the same base model for comparison
print("Starting PyTorch quantization experiments with different methods...")

In [None]:
# 1. Basic Post-Training Quantization for PyTorch
flag, quantized_model_ptq = PTQ_Pytorch(float_model)

In [None]:
# 2. PTQ with Mixed Precision (size/accuracy trade-off via per-layer bit-widths)
flag, quantized_model_ptq_mixed_precision = PTQ_Pytorch_mixed_precision(float_model)

In [None]:
# 3. Gradient-based PTQ (improved accuracy through fine-tuning for PyTorch)
flag, quantized_model_gptq = GPTQ_Pytorch(float_model)

In [None]:
# 4. GPTQ with Mixed Precision (gradient-based fine-tuning combined with per-layer bit-widths)
flag, quantized_model_gptq_mixed_precision = GPTQ_Pytorch_mixed_precision(float_model)

In [None]:
print("All PyTorch quantization methods completed successfully!")

## Model evaluation
We now evaluate the float model and each quantized model on the ImageNet validation set, using the `val_dataloader` defined above.

In [None]:
# PyTorch Model Evaluation and Accuracy Comparison
print("Starting PyTorch model evaluation phase...")
print("This evaluation will test the float model and all quantized models against the validation dataset")

# Evaluate original floating-point PyTorch model accuracy
print("\n=== Original PyTorch Model Evaluation ===")
evaluate(float_model, val_dataloader, 'Float')

In [None]:
# Evaluate PTQ quantized PyTorch model accuracy
print("\n=== PyTorch PTQ Model Evaluation ===")
evaluate(quantized_model_ptq, val_dataloader, 'PTQ_Pytorch')

In [None]:
# Evaluate PTQ + Mixed Precision PyTorch model accuracy
print("\n=== PyTorch PTQ + Mixed Precision Model Evaluation ===")
evaluate(quantized_model_ptq_mixed_precision, val_dataloader, 'PTQ_Pytorch_mixed_precision')

In [None]:
# Evaluate GPTQ quantized PyTorch model accuracy
print("\n=== PyTorch GPTQ Model Evaluation ===")
evaluate(quantized_model_gptq, val_dataloader, 'GPTQ_Pytorch')

In [None]:
# Evaluate GPTQ + Mixed Precision PyTorch model accuracy
print("\n=== PyTorch GPTQ + Mixed Precision Model Evaluation ===")
evaluate(quantized_model_gptq_mixed_precision, val_dataloader, 'GPTQ_Pytorch_mixed_precision')

In [None]:
print("finish")
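
The evaluation cells above print each accuracy but do not keep it. Because `evaluate` returns the Top-1 accuracy as a float, the results can be collected into a dictionary and compared against the float baseline. The helper below, `accuracy_report`, is a hypothetical addition for this comparison and is not part of MCT.

```python
from typing import Dict

def accuracy_report(results: Dict[str, float], baseline: str = "Float") -> str:
    """Format Top-1 accuracies (%) and each model's drop from the baseline in percentage points."""
    base = results[baseline]
    lines = [f"{'Model':<32}{'Top-1 (%)':>10}{'Drop (pp)':>11}"]
    for name, acc in results.items():
        lines.append(f"{name:<32}{acc:>10.2f}{base - acc:>11.2f}")
    return "\n".join(lines)
```

For example, `results = {'Float': evaluate(float_model, val_dataloader, 'Float'), 'GPTQ_Pytorch': evaluate(quantized_model_gptq, val_dataloader, 'GPTQ_Pytorch')}` followed by `print(accuracy_report(results))` prints one row per model, with the drop reported in percentage points.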

## Conclusion

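In this tutorial, we demonstrated how to quantize a pre-trained MobileNetV2 model with four methods (PTQ, PTQ with mixed precision, GPTQ, and GPTQ with mixed precision) using the MCTWrapper API, with only a few lines of code per method.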

MCT can deliver competitive results across a wide range of tasks and network architectures. For more details, [check out the paper](https://arxiv.org/abs/2109.09113).

## Copyrights

Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
