In [None]:
!pip -qq install super-gradients==3.6.0 pytorch-quantization==2.1.2 --extra-index-url https://pypi.ngc.nvidia.com

In [None]:
from super_gradients import Trainer
import pprint

# Name passed to every Trainer created in this notebook.
experiment_name = "cifar100_ptq_qat_classification"

# Root directory handed to the Trainer as its checkpoint location.
CHECKPOINT_DIR = 'checkpoints'
# Trainer instance used by the training cell and for locating the saved checkpoints later on.
trainer = Trainer(experiment_name=experiment_name, ckpt_root_dir=CHECKPOINT_DIR)

In [None]:
from super_gradients.training import dataloaders
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Preprocessing pipelines: both splits are resized to 224x224 and converted to tensors;
# only the training split additionally receives a random horizontal flip.
train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)
valid_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# CIFAR-100 train / test splits, stored under ./data (download requested for both).
_cifar_kwargs = dict(root="./data", download=True)
train_dataset = datasets.CIFAR100(train=True, transform=train_transform, **_cifar_kwargs)
valid_dataset = datasets.CIFAR100(train=False, transform=valid_transform, **_cifar_kwargs)

# Loader settings shared by both splits; only the training loader shuffles its samples.
_loader_kwargs = dict(batch_size=64, num_workers=2, persistent_workers=True)
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
valid_dataloader = DataLoader(valid_dataset, shuffle=False, **_loader_kwargs)


Files already downloaded and verified
Files already downloaded and verified


In [None]:
from matplotlib import pyplot as plt

def show(images, labels, classes, rows=6, columns=5):
  """Draw a grid of images, each titled with its class name.

  At most ``rows * columns`` images are drawn. When the batch holds fewer images than
  the grid has cells, the remaining cells are simply left empty instead of raising an
  IndexError.

  :param images:  Indexable batch of image tensors. Each image is permuted with
                  (1, 2, 0) and clamped to [0, 1] before display
                  (presumably channel-first C x H x W input; confirm against the caller).
  :param labels:  Indexable batch of labels; ``labels[i]`` is used to index ``classes``.
  :param classes: Lookup from label to the name shown as the subplot title.
  :param rows:    Number of grid rows.
  :param columns: Number of grid columns.
  """
  fig = plt.figure(figsize=(10, 10))

  # Never index past the end of the batch, even if the grid has more cells than images.
  num_shown = min(len(images), rows * columns)

  for idx in range(num_shown):
      # Subplot positions are 1-based, image indices are 0-based.
      ax = fig.add_subplot(rows, columns, idx + 1)
      ax.imshow(images[idx].permute(1, 2, 0).clamp(0, 1))
      ax.set_xticks([])
      ax.set_yticks([])
      ax.set_title(f"{classes[labels[idx]]}")

In [None]:
# Pull a single batch from the training loader and preview it with show().
vis_images_train, vis_labels_train = next(iter(train_dataloader))
show(vis_images_train, vis_labels_train, classes=train_dataloader.dataset.classes)

# Print the shapes of the sampled image and label tensors.
print(vis_images_train.shape, vis_labels_train.shape)

In [None]:
from super_gradients.training import models
from super_gradients.common.object_names import Models

# Backbone to fine-tune: one of 'resnet', 'efficientnet' or 'mobilenet'.
model_type = 'mobilenet'

# Supported choices mapped to the architecture identifier passed to models.get().
_ARCHITECTURES = {
  'resnet': Models.RESNET50,
  'efficientnet': Models.EFFICIENTNET_B0,
  'mobilenet': Models.MOBILENET_V3_SMALL,
}

if model_type not in _ARCHITECTURES:
  raise ValueError(f"Unknown model type: {model_type}")

# Every variant is requested with "imagenet" pretrained weights and 100 output classes.
model = models.get(model_name=_ARCHITECTURES[model_type], num_classes=100, pretrained_weights="imagenet")

Downloading: "https://sghub.deci.ai/models/mobilenet_v3_small_imagenet.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_small_imagenet.pth
 77%|███████▋  | 22.4M/29.3M [00:03<00:00, 14.0MB/s]

In [None]:
from super_gradients.training import Trainer
from super_gradients.training import training_hyperparams

# you can see more recipes in super_gradients/recipes
# Start from the CIFAR-10 ResNet training recipe; individual entries are overridden in a later cell.
training_params =  training_hyperparams.get("training_hyperparams/cifar10_resnet_train_params")

In [None]:
# Plain print for the heading: pprint would emit the string's repr, quotes included.
print("Training parameters")
pprint.pprint(training_params)

In [None]:
# Overrides applied on top of the recipe loaded into training_params.
training_params["initial_lr"] = 0.1
# NOTE(review): with max_epochs = 1, the epoch-indexed values below (lr_updates at 10/25/45,
# a 5-epoch warmup, checkpoints at epochs 5 and 10) all reach past the end of training.
# Presumably max_epochs was lowered for a quick run; confirm the intended schedule.
training_params["max_epochs"] = 1
training_params["lr_updates"] = [10, 25, 45]
training_params["lr_warmup_epochs"] = 5
training_params["warmup_initial_lr"] = 0.01
training_params["save_ckpt_epoch_list"] = [1, 5, 10]

In [None]:
# Train the selected model with the overridden hyperparameters, using the CIFAR-100
# training loader for training and the validation loader for validation.
trainer.train(model=model,
              training_params=training_params,
              train_loader=train_dataloader,
              valid_loader=valid_dataloader)

In [None]:
import copy
import os
import sys
import torch
from torch import nn
from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer
from super_gradients.modules.repvgg_block import fuse_repvgg_blocks_residual_branches

def load_checkpoint(model, ckpt_file):
  """Load weights from a checkpoint file into ``model`` in place.

  The checkpoint is read onto the CPU. The weights stored under the "ema_net" key
  are used when that key is present; otherwise the "net" entry is used.

  :param model:     Module whose ``load_state_dict`` receives the stored weights.
  :param ckpt_file: Path (or file-like object) accepted by ``torch.load``.
  """
  ckpt = torch.load(ckpt_file, map_location="cpu")
  # Prefer the EMA entry when the checkpoint contains one.
  weights = ckpt["ema_net"] if "ema_net" in ckpt else ckpt["net"]
  model.load_state_dict(weights)

def validate_model(model, dataloader, training_hyperparams):
  """Evaluate ``model`` on ``dataloader`` and report the resulting metrics.

  A new Trainer is created for the evaluation (using the notebook-level
  ``experiment_name`` and ``CHECKPOINT_DIR``). The metrics to compute are taken from
  the "valid_metrics_list" entry of ``training_hyperparams``. A summary is printed
  to stderr.

  :param model:                Model passed to ``Trainer.test``.
  :param dataloader:           Loader passed as ``test_loader``.
  :param training_hyperparams: Mapping providing "valid_metrics_list" via ``.get``.
  :return: The metrics dictionary returned by ``Trainer.test``.
  """
  evaluator = Trainer(experiment_name=experiment_name, ckpt_root_dir=CHECKPOINT_DIR)
  metrics_to_compute = training_hyperparams.get("valid_metrics_list")
  valid_metrics_dict = evaluator.test(model=model, test_loader=dataloader, test_metrics_list=metrics_to_compute)

  # One heading line followed by one line per metric.
  report_lines = ["Validate Results"]
  for metric, value in valid_metrics_dict.items():
    report_lines.append(f"\t- {metric:4}: {value:.3f}")

  print("\r\n".join(report_lines), file=sys.stderr)

  return valid_metrics_dict

# Reload the best checkpoint saved under the trainer's checkpoint directory and
# measure the accuracy of the float model.
best_ckpt_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_best.pth")
print(best_ckpt_path)
load_checkpoint(model, best_ckpt_path)
validate_model(model, valid_dataloader, training_params)

checkpoints/cifar100_ptq_qat_classification/RUN_20240310_110023_818008/ckpt_best.pth


Testing:  96%|█████████▌| 150/157 [00:07<00:00, 21.55it/s]

{'Accuracy': 0.7960000038146973, 'Top5': 0.9625999927520752}

Testing:  97%|█████████▋| 153/157 [00:07<00:00, 20.70it/s]Testing:  99%|█████████▉| 156/157 [00:07<00:00, 21.82it/s]Testing: 100%|██████████| 157/157 [00:07<00:00, 21.08it/s]
Validate Results
	- Accuracy: 0.796
	- Top5: 0.963


POST-TRAINING QUANTIZATION (PTQ)

In [None]:
def quantize_and_calibrate(
    model: nn.Module,
    calibration_dataloader,
    num_calib_batches=16,
    method_w="max",
    method_i="histogram",
    calibration_method="percentile",
    percentile=99.99,
    per_channel=True,
    learn_amax=False,
    skip_modules=None,
    verbose=False,
):
    """Quantize ``model`` in place and calibrate its quantizers.

    The model is switched to eval mode, RepVGG residual branches are fused, the
    modules are replaced through a ``SelectiveQuantizer`` and the result is calibrated
    with a ``QuantizationCalibrator``. The model is put back into train mode before
    it is returned.

    :param model:                  Module to quantize; it is modified in place.
    :param calibration_dataloader: Loader providing the calibration batches.
    :param num_calib_batches:      Number of batches to use for calibration.
    :param method_w:               Calibrator type for weights, one of ["max", "histogram"].
    :param method_i:               Calibrator type for inputs, one of ["max", "histogram"].
    :param calibration_method:     Calibration method for all "histogram" calibrators, one of
                                   ["percentile", "entropy", "mse"]; "max" calibrators are not affected.
    :param percentile:             Percentile for all histogram calibrators with method "percentile";
                                   other calibrators are not affected.
    :param per_channel:            Per-channel quantization of weights; activations stay per-tensor
                                   by default.
    :param learn_amax:             Enable learnable amax in all TensorQuantizers using a
                                   straight-through estimator.
    :param skip_modules:           Optional list of module names (strings) to skip from quantization.
    :param verbose:                Whether the quantizer and calibrator should be verbose.
    :return: The same ``model`` object, quantized and calibrated.
    """
    model.eval()

    selective_quantizer = SelectiveQuantizer(
        default_quant_modules_calibrator_weights=method_w,
        default_quant_modules_calibrator_inputs=method_i,
        default_per_channel_quant_weights=per_channel,
        default_learn_amax=learn_amax,
        verbose=verbose,
    )
    # Skipped layers have to be registered before the modules are converted.
    if skip_modules is not None:
        selective_quantizer.register_skip_quantization(layer_names=set(skip_modules))

    # RepVGG and QARepVGG can be quantized only in the fused form
    fuse_repvgg_blocks_residual_branches(model)
    selective_quantizer.quantize_module(model)

    stats_calibrator = QuantizationCalibrator(verbose=verbose, torch_hist=True)
    stats_calibrator.calibrate_model(
        model,
        method=calibration_method,
        calib_data_loader=calibration_dataloader,
        num_calib_batches=num_calib_batches,
        percentile=percentile,
    )

    model.train()
    return model


In [None]:
# Restore the best float weights, then quantize and calibrate a deep copy so that
# `model` itself is not touched by the in-place quantization.
load_checkpoint(model, os.path.join(trainer.checkpoints_dir_path, "ckpt_best.pth"))

# PTQ settings for this run; calibration batches are drawn from the training loader.
ptq_options = dict(
    num_calib_batches=4,
    method_w="max",
    method_i="max",
    calibration_method="percentile",
    percentile=99.99,
    per_channel=True,
    learn_amax=False,
    skip_modules=None,
    verbose=False,
)
ptq_model = quantize_and_calibrate(copy.deepcopy(model), train_dataloader, **ptq_options)

100%|██████████| 4/4 [00:00<00:00, 10.26it/s]


In [None]:
# Evaluate the PTQ model on the validation loader with the same call used for the float model.
validate_model(ptq_model, valid_dataloader, training_params)

Testing:  97%|█████████▋| 153/157 [00:12<00:00, 12.55it/s]

{'Accuracy': 0.6363999843597412, 'Top5': 0.8851000070571899}


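Float vs Quantized Model Accuracy

Format: top1 (float / PTQ), top5 (float / PTQ). Entries marked TODO have not been measured yet; the MobileNet-V3 numbers are taken from the outputs recorded in this notebook.

1. ResNet50: top1 (TODO), top5 (TODO)
2. EfficientNet-B0: top1 (TODO), top5 (TODO)
3. MobileNet-V3: top1 (0.796 / 0.636), top5 (0.963 / 0.885)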