# Quantize the pre-trained model.

In [1]:
import os

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import models  # quantization
from torchvision.transforms import transforms

In [2]:
# load pretrained model.
model = models.resnet50()
model.fc = nn.Linear(model.fc.in_features, 100)  # resnet

In [3]:
model.load_state_dict(torch.load('resnet_cifar100.pth'))

# Fuse the model layers
model.eval()

# Specify the quantization configuration
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')  # default_dynamic_qconfig # get_default_qconfig

# Prepare the model for quantization
torch.quantization.prepare(model, inplace=True)



ResNet(
  (conv1): Conv2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (bn1): BatchNorm2d(
    64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(
        64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False
        (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
      )
      (bn1): BatchNorm2d(
        64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
      )
      (conv2): Conv2d(
        64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (activation_po

In [4]:
my_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

In [5]:
# Perform calibration on a representative dataset
calibration_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=my_transforms)
calibration_dataloader = DataLoader(calibration_dataset, batch_size=128)

Files already downloaded and verified


In [6]:
# quant = torch.ao.quantization.QuantStub()

In [7]:
# Run calibration
i = 0
with torch.no_grad():
    for inputs, _ in calibration_dataloader:
        model(inputs)
        print(f'{i + 1}  / {len(calibration_dataloader)} done.', end='\r')
        i += 1

79  / 79 done.

In [8]:
# Load CIFAR-100 dataset
train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=my_transforms)
test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=my_transforms)

train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=512, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
def test(dataloader, type):
    # Evaluate the quantized model
    model.eval()

    with torch.no_grad():
        correct = 0
        total = 0
        for idx, batch in enumerate(dataloader):
            inputs, labels = batch
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            print(f'[{type}] {idx + 1}  / {len(dataloader)} done.', end='\r')

        accuracy = 100 * correct / total
        print(f'Quantized model in {type} accuracy: {accuracy:.2f}%')

In [10]:
test(train_dataloader, 'train dataset')
test(test_dataloader, 'test dataset')

Quantized model in train dataset accuracy: 92.17%
Quantized model in test dataset accuracy: 49.02%


In [11]:
# Convert the model to a quantized model
torch.quantization.convert(model.eval(), inplace=True)

# Save the quantized model
torch.save(model.state_dict(), 'quantized_resnet_cifar100.pth')

# Get the size of the saved model file
model_size = os.path.getsize('quantized_resnet_cifar100.pth') / (1024 * 1024)  # Size in MB
print(f"Pruned model size: {model_size:.2f} MB")