# Post-Training Sparsification and Quantization of PyTorch models with POT

The goal of this tutorial is to demonstrate how to use post-training sparsification and quantization with the Post-training Optimization Tool (POT) to optimize a PyTorch model for high-speed inference with the OpenVINO™ Toolkit. The optimization process contains the following steps:

1. Evaluate the original model.
2. Transform the original model to a quantized one.
3. Export optimized and original models to OpenVINO IR.

This tutorial uses a ResNet-50 model pre-trained on Tiny ImageNet, a dataset of 100,000 color images in 200 classes (500 per class), each downsized to 64×64 pixels. The tutorial demonstrates that post-training quantization needs only a tiny part of the dataset and does not require fine-tuning the model.


> **NOTE**: This notebook requires a C++ compiler to be accessible on the default binary search path of the OS on which you are running the notebook.

### Imports

In [None]:
import warnings 
warnings.filterwarnings('ignore')

import os
import sys
import time
import zipfile
from pathlib import Path
from typing import List, Tuple, Union

from openvino.runtime import Core, serialize
from openvino.tools import mo
from openvino.tools.pot import load_model, IEEngine, create_pipeline, compress_model_weights, save_model
from openvino.runtime.ie_api import CompiledModel

import torch
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50
import torchvision.transforms as transforms

sys.path.append("../utils")
from notebook_utils import download_file

### Settings

In [None]:
# Run PyTorch on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

MODEL_DIR = Path("model")
OUTPUT_DIR = Path("output")
BASE_MODEL_NAME = "resnet50"
QUANTIZED_MODEL_NAME = BASE_MODEL_NAME + "_int8"
# Spatial size the images are resized to; also the size of the dummy input used for ONNX export.
IMAGE_SIZE = [64, 64]

OUTPUT_DIR.mkdir(exist_ok=True)
MODEL_DIR.mkdir(exist_ok=True)

# Paths where PyTorch and OpenVINO IR models will be stored.
fp32_checkpoint_filename = Path(BASE_MODEL_NAME + "_fp32").with_suffix(".pth")
fp32_onnx_path = OUTPUT_DIR / Path(BASE_MODEL_NAME + "_fp32").with_suffix(".onnx")
fp32_ir_path = OUTPUT_DIR / Path(BASE_MODEL_NAME + "_fp32").with_suffix(".xml")
int8_onnx_path = OUTPUT_DIR / Path(BASE_MODEL_NAME + "_int8").with_suffix(".onnx")
sparse_int8_ir_folder = OUTPUT_DIR
sparse_int8_ir_filename = Path(BASE_MODEL_NAME + "_sparse_int8").with_suffix(".xml")
# Derived from the folder and file name above so that it always points at the
# file the sparse INT8 model is actually saved to.
sparse_int8_ir_path = sparse_int8_ir_folder / sparse_int8_ir_filename

# Fetch the pre-trained FP32 checkpoint into MODEL_DIR.
fp32_pth_url = "https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/304_resnet50_fp32.pth"
download_file(fp32_pth_url, directory=MODEL_DIR, filename=fp32_checkpoint_filename)

### Download and Prepare Tiny ImageNet dataset

* 100k images of shape 3x64x64,
* 200 different classes: snake, spider, cat, truck, grasshopper, gull, etc.

In [None]:
def download_tiny_imagenet_200(
    output_dir: Path,
    url: str = "http://cs231n.stanford.edu/tiny-imagenet-200.zip",
    tarname: str = "tiny-imagenet-200.zip",
):
    """Download the Tiny ImageNet zip archive and extract it into output_dir.

    :param output_dir: directory that receives both the archive and its extracted contents.
    :param url: address the archive is downloaded from.
    :param tarname: file name the archive is stored under inside output_dir.
    """
    archive_path = output_dir / tarname
    download_file(url, directory=output_dir, filename=tarname)
    # The context manager closes the archive even when extraction fails part-way.
    with zipfile.ZipFile(archive_path, "r") as zip_ref:
        zip_ref.extractall(path=output_dir)
    print(f"Successfully downloaded and extracted dataset to: {output_dir}")


def create_validation_dir(dataset_dir: Path):
    """Sort the validation images into one sub-folder per class label.

    Reads ``val/val_annotations.txt`` (tab-separated; the first field is the image
    file name, the second the class label) and moves every listed image from
    ``val/images`` into ``val/images/<label>``. Images that are no longer in
    ``val/images`` are skipped, so calling the function again is harmless.

    :param dataset_dir: root directory of the extracted dataset.
    """
    VALID_DIR = dataset_dir / "val"
    val_img_dir = VALID_DIR / "images"

    val_img_dict = {}
    with open(VALID_DIR / "val_annotations.txt", "r") as fp:
        for line in fp:
            # Drop the line terminator so that it cannot end up in a folder name.
            words = line.rstrip("\r\n").split("\t")
            # Blank or malformed lines carry no image/label pair.
            if len(words) < 2:
                continue
            val_img_dict[words[0]] = words[1]

    for img, folder in val_img_dict.items():
        newpath = val_img_dir / folder
        newpath.mkdir(parents=True, exist_ok=True)
        source = val_img_dir / img
        if source.exists():
            source.rename(newpath / img)

# Download and reorganize the dataset only when its directory is missing;
# an existing directory is taken to mean both steps were already done.
DATASET_DIR = OUTPUT_DIR / "tiny-imagenet-200"
if not DATASET_DIR.exists():
    download_tiny_imagenet_200(OUTPUT_DIR)
    create_validation_dir(DATASET_DIR)

### Helper classes and functions
The code below helps to compute accuracy and display the progress of the validation process.

In [None]:
class AverageMeter(object):
    """Tracks the most recent value of a metric together with its running mean"""

    def __init__(self, name: str, fmt: str = ":f"):
        self.name, self.fmt = name, fmt
        # Latest value, running mean, weighted total and number of samples seen.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val: float, n: int = 1):
        """Records val as the latest value, counted n times towards the mean"""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # Renders as "<name> <latest> (<mean>)", both numbers using the fmt spec.
        template = f"{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})"
        return template.format(**vars(self))


class ProgressMeter(object):
    """Prints a one-line status report for a batch of the validation loop"""

    def __init__(self, num_batches: int, meters: List[AverageMeter], prefix: str = ""):
        self.prefix = prefix
        self.meters = meters
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)

    def display(self, batch: int):
        """Prints the prefix, the batch counter and every meter, separated by tabs"""
        counter = self.prefix + self.batch_fmtstr.format(batch)
        print("\t".join([counter, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches: int):
        # Pad the running batch index to the width of the total, e.g. "[{:3d}/100]".
        width = len(str(num_batches // 1))
        slot = f"{{:{width}d}}"
        return f"[{slot}/{slot.format(num_batches)}]"


def accuracy(output: torch.Tensor, target: torch.Tensor, topk: Tuple[int] = (1,)):
    """Returns, for every k in topk, the percentage of samples whose target is among the k highest-scored predictions"""
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the best-scoring classes, transposed so that row r holds
        # the rank-r prediction of every sample.
        _, ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        percent_per_hit = 100.0 / num_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]

### Validation function

In [None]:
def validate(model: Union[torch.nn.Module, CompiledModel]):
    """Compute the metrics using data from val_loader for the model.

    The data comes from ``create_dataloader(batch_size=1)``. A progress line is
    printed every 1000 batches, followed by a summary with the average top-1 and
    top-5 accuracy and the total time. A PyTorch module is switched to evaluate
    mode before inference and is left in that mode.

    :param model: a PyTorch module or a compiled OpenVINO model.
    :return: the running average of the top-1 accuracy, in percent.
    """
    batch_time = AverageMeter("Time", ":3.3f")
    top1 = AverageMeter("Acc@1", ":2.2f")
    top5 = AverageMeter("Acc@5", ":2.2f")
    print_frequency = 1000

    # Only the loader is needed here; the dataset object itself is not used.
    _, val_loader = create_dataloader(batch_size=1)

    if isinstance(model, CompiledModel):
        # The output handle does not change between batches, so look it up once.
        output_layer = model.output(0)

        def forward_fun(images, target):
            output = model(images)[output_layer]
            return (torch.from_numpy(output), target)

    else:
        # Switch to evaluate mode.
        if isinstance(model, torch.nn.Module):
            model.eval()

        def forward_fun(images, target):
            return (model(images.to(device)), target.to(device))

    progress = ProgressMeter(len(val_loader), [batch_time, top1, top5], prefix="Test: ")
    start_time = time.time()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):

            # Compute the output.
            output, target = forward_fun(images, target)

            # Measure accuracy, weighting each batch by its number of images.
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # Measure elapsed time.
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_frequency == 0:
                progress.display(i)

        print(
            " * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} Total time: {total_time:.3f}".format(top1=top1, top5=top5, total_time=end - start_time)
        )
    return top1.avg

### Create and load original uncompressed model

ResNet-50 from the [torchvision repository](https://github.com/pytorch/vision) is pre-trained on ImageNet with more prediction classes than Tiny ImageNet, so the model is adjusted by swapping the last FC layer to one with fewer output values.

In [None]:
def create_model(model_path: Path):
    """Builds ResNet-50 with a 200-class head and restores the weights stored at model_path"""
    # Tiny ImageNet has fewer classes than the stock network, so the final
    # fully connected layer is replaced with a matching one.
    NUM_CLASSES = 200
    model = resnet50()
    model.fc = torch.nn.Linear(in_features=2048, out_features=NUM_CLASSES, bias=True)
    model.to(device)

    if not model_path.exists():
        raise RuntimeError("There is no checkpoint to load")

    # NOTE(review): torch.load unpickles the file; only use checkpoints from a trusted source.
    checkpoint = torch.load(str(model_path), map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"], strict=True)
    model.eval()
    return model

# Build the FP32 model from the checkpoint file located in MODEL_DIR.
model = create_model(MODEL_DIR / fp32_checkpoint_filename)

In [None]:
def create_dataloader(batch_size: int = 1):
    """Builds the validation dataset and a dataloader over it, and returns both as (dataset, dataloader)"""
    # Resize to the model input size, convert to a float tensor and normalize
    # each channel with the given mean and standard deviation.
    preprocessing = transforms.Compose(
        [
            transforms.Resize(IMAGE_SIZE),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    val_dataset = ImageFolder(DATASET_DIR / "val" / "images", preprocessing)

    # Fixed order, loaded in the main process.
    loader_options = dict(batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, **loader_options)
    return val_dataset, val_dataloader

## Model quantization and benchmarking
With the validation pipeline, model files, and data-loading procedures for model calibration now prepared, it's time to proceed with the actual post-training quantization using POT.

### I. Evaluate the loaded model

In [None]:
# Baseline: top-1 accuracy of the PyTorch FP32 model on the validation set.
acc1 = validate(model)
print(f"Test accuracy of FP32 model: {acc1:.3f}")

### II. Create and initialize quantization


#### Load quantization function

In [None]:
def generate_pot_quantized_model(input_filepath,
                                 output_folder,
                                 output_filename,
                                 sparsity_level=0.0,
                                 stat_subset_size=10000):
    """Sparsify (optionally) and quantize an OpenVINO IR model with POT, then save it.

    :param input_filepath: path to the ``.xml`` file of the source IR; the matching
        ``.bin`` file is expected next to it under the same base name.
    :param output_folder: directory the compressed model is saved to.
    :param output_filename: file name of the resulting model; a trailing ``.xml``
        extension is dropped to form the model name passed to ``save_model``.
    :param sparsity_level: when greater than 0, a ``WeightSparsity`` step with this
        level runs before quantization.
    :param stat_subset_size: value of the ``stat_subset_size`` parameter given to
        every algorithm in the pipeline.
    """
    print(f"Sparsifying and Quantizing {input_filepath}")

    # Swap only the file extension; a plain text replacement would also rewrite
    # any "xml" occurring in a directory or base name.
    weights_filepath = Path(input_filepath).with_suffix(".bin")

    model_config = {
        "model_name": "model",
        "model": str(input_filepath),
        "weights": str(weights_filepath)
    }

    # Engine config.
    engine_config = {"device": "CPU"}
    algorithms = []

    if sparsity_level > 0:
        # NOTE(review): confirm against the POT documentation that "ignored_scope"
        # and "use_fast_bias" are accepted by WeightSparsity in this form.
        algorithms.append(
            {
                "name": "WeightSparsity",
                "params": {
                    "target_device": "CPU",
                    "sparsity_level": sparsity_level,
                    "stat_subset_size": stat_subset_size,
                    "ignored_scope": "images",
                    "use_fast_bias": True
                }
            })

    algorithms.append({
        "name": "DefaultQuantization",
        "params": {
            "target_device": "CPU",
            "stat_subset_size": stat_subset_size,
            "preset": "performance",
            "use_fast_bias": True
        },
    })

    # Step 1: Implement and create a user data loader.
    # Only the dataset is handed to the engine; the torch dataloader is not needed.
    print("Step 1")
    val_dataset, _ = create_dataloader()

    # Step 2: Load a model.
    print("Step 2")
    model = load_model(model_config=model_config)

    # Step 3: Initialize the engine for metric calculation and statistics collection.
    print("Step 3")
    engine = IEEngine(config=engine_config, data_loader=val_dataset)

    # Step 4: Create a pipeline of compression algorithms and run it.
    print("Step 4")
    pipeline = create_pipeline(algorithms, engine)
    compressed_model = pipeline.run(model=model)

    # Step 5 (Optional): Compress model weights to quantized precision
    #                     to reduce the size of the final .bin file.
    print("Step 5")
    compress_model_weights(compressed_model)

    # Step 6: Save the compressed model to the desired path.
    # Set save_path to the directory where the model should be saved.
    print("Step 6")
    # Strip only a trailing ".xml" extension from the file name.
    name_path = Path(output_filename)
    if name_path.suffix == ".xml":
        model_name = str(name_path.with_suffix(""))
    else:
        model_name = str(output_filename)
    save_model(
        model=compressed_model,
        save_path=output_folder,
        model_name=model_name
    )

    print(f"Generated {output_folder}/{output_filename}")

### III. Convert the models to OpenVINO Intermediate Representation (OpenVINO IR)

Use the Model Optimizer Python API to convert the PyTorch models to OpenVINO IR. The models will be saved to the `output` directory for later benchmarking.

For more information about Model Optimizer, refer to the [Model Optimizer Developer Guide](https://docs.openvino.ai/2023.0/openvino_docs_MO_DG_Deep_Learning_Model_Optimizer_DevGuide.html).

Before converting the models, export them to ONNX. Executing the following command may take a while.

In [None]:
# Example input of shape (1, 3, *IMAGE_SIZE) handed to the ONNX exporter.
dummy_input = torch.randn(1, 3, *IMAGE_SIZE).to(device)

# PyTorch model -> ONNX file -> OpenVINO model.
torch.onnx.export(model, dummy_input, fp32_onnx_path)
model_ir = mo.convert_model(input_model=fp32_onnx_path)

# Write the converted model to disk as OpenVINO IR at fp32_ir_path.
serialize(model_ir, str(fp32_ir_path))

2. Create a sparse (50%), quantized model from the pre-trained `FP32` model and the calibration dataset.

In [None]:
# Sparsify (level 0.5) and quantize the FP32 IR; the result is saved into
# sparse_int8_ir_folder under sparse_int8_ir_filename.
generate_pot_quantized_model(input_filepath=fp32_ir_path, 
                             output_folder=sparse_int8_ir_folder, 
                             output_filename=sparse_int8_ir_filename, 
                             sparsity_level=0.5)

3. Evaluate the new model on the validation set after initialization of quantization. The accuracy should be close to the accuracy of the floating-point `FP32` model for a simple case like the one being demonstrated now.

In [None]:
core = Core()

# Read the FP32 IR and compile it for the CPU device.
original_model = core.read_model(fp32_ir_path)
original_model = core.compile_model(original_model, "CPU")

# Accuracy of the FP32 model when run through OpenVINO.
acc1 = validate(original_model)
print(f"Test accuracy of FP32 model: {acc1:.3f}")

In [None]:
# Read the sparse INT8 IR produced by POT and compile it for the CPU device.
sparse_quantized_model = core.read_model(f"{sparse_int8_ir_folder}/{sparse_int8_ir_filename}")
sparse_quantized_model = core.compile_model(sparse_quantized_model, "CPU")

# Accuracy of the sparse, quantized model on the same validation set.
acc1 = validate(sparse_quantized_model)
print(f"Accuracy of initialized INT8 model: {acc1:.3f}")

# Result
A small accuracy drop is observed in exchange for considerably faster inference.