# Quantize the pre-trained model.

In [1]:
import os
import time

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import models
from torchvision.transforms import transforms

In [2]:
# Build the ResNet-50 architecture without pretrained weights; the fine-tuned
# CIFAR-100 checkpoint is loaded into it in the next cell.
#
# The quantizable variant from torchvision is used instead of the plain float
# ResNet-50: it wraps forward() in QuantStub/DeQuantStub and performs residual
# additions through quantization-aware functional modules, which eager-mode
# static quantization needs for the converted model to accept float inputs.
# NOTE(review): its parameter names are expected to match the float ResNet-50,
# so the existing checkpoint should load unchanged - confirm on first run.
model = models.quantization.resnet50(weights=None, quantize=False)
model.fc = nn.Linear(model.fc.in_features, 100)  # 100 output classes (CIFAR-100)

In [3]:
# Load the fine-tuned float weights. map_location='cpu' lets a checkpoint that
# was saved from a GPU run load on a CPU-only host; the 'fbgemm' quantization
# backend configured below targets CPU inference.
model.load_state_dict(torch.load('resnet_cifar100.pth', map_location='cpu'))

# Switch to inference mode before fusing / inserting observers.
model.eval()

# Fuse the model layers when the architecture provides a fusion helper
# (torchvision's quantizable models expose ``fuse_model``). A plain float
# ResNet has no such method, in which case fusion is skipped.
if hasattr(model, 'fuse_model'):
    model.fuse_model()

# Specify the quantization configuration
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')  # default_dynamic_qconfig # get_default_qconfig

# Prepare the model for quantization (inserts observers, in place)
model = torch.quantization.prepare(model, inplace=True)



In [4]:
# my_transforms = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
# ])

In [5]:
# Perform calibration on a representative dataset.
# Calibration images are drawn from the training split rather than the test
# split, so the calibration step itself does not consume held-out data.
# The subset size keeps the calibration cost the same as before
# (10,000 images -> 20 batches of 512).
# NOTE(review): the transform should match the preprocessing used when the
# checkpoint was trained - confirm whether a Normalize step is required.
CALIBRATION_SAMPLES = 10_000
calibration_source = datasets.CIFAR100(root='./data', train=True, download=True, transform=transforms.ToTensor())
calibration_dataset = torch.utils.data.Subset(calibration_source, range(CALIBRATION_SAMPLES))
calibration_dataloader = DataLoader(calibration_dataset, batch_size=512)

Files already downloaded and verified


In [6]:
# quant = torch.ao.quantization.QuantStub()

In [7]:
# Run calibration: push every calibration batch through the model once with
# gradient tracking disabled, reporting progress on a single console line.
num_calibration_batches = len(calibration_dataloader)
with torch.no_grad():
    for batch_number, (images, _) in enumerate(calibration_dataloader, start=1):
        model(images)
        print(f'{batch_number}  / {num_calibration_batches} done.', end='\r')

20  / 20 done.

In [8]:
# Load the CIFAR-100 train and test splits with identical settings,
# then wrap each one in a DataLoader (only the training loader is shuffled).
cifar_kwargs = dict(root='./data', download=True, transform=transforms.ToTensor())
train_dataset = datasets.CIFAR100(train=True, **cifar_kwargs)
test_dataset = datasets.CIFAR100(train=False, **cifar_kwargs)

eval_batch_size = 512
train_dataloader = DataLoader(train_dataset, batch_size=eval_batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
def evaluate(dataloader, type, net=None):
    """Measure top-1 accuracy and forward-pass latency over a dataloader.

    Args:
        dataloader: Sized iterable yielding ``(inputs, labels)`` batches.
        type: Label for the evaluated split; used only in the printed
            messages. (The name shadows the ``type`` builtin but is kept so
            existing callers are unaffected.)
        net: Module to evaluate. When omitted, the notebook-level ``model``
            is used, which preserves the original two-argument behaviour.

    Returns:
        Top-1 accuracy in percent, or ``nan`` when the dataloader yields no
        samples.
    """
    if net is None:
        net = model
    net.eval()

    total_time = 0.  # accumulated forward-pass time, in milliseconds
    num_batches = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(dataloader):
            # perf_counter is monotonic, so the measurement cannot be skewed
            # by wall-clock adjustments.
            tic = time.perf_counter()
            outputs = net(inputs)
            total_time += (time.perf_counter() - tic) * 1000

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            num_batches += 1
            # Per-sample latency uses the number of samples actually seen,
            # so a short final batch (or a different batch size) is handled.
            print(f'[{type}] {idx + 1} / {len(dataloader)} done. ({total_time / max(total, 1):.2f} ms)', end='\r')

    if total == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        print(f'[{type}] no samples to evaluate.')
        return float('nan')

    accuracy = 100 * correct / total
    print(
        f'Quantized model in {type} accuracy: {accuracy:.2f}%, '
        f'Average Inference Time: {total_time / num_batches:.2f} ms.'
    )
    return accuracy

In [10]:
# Report accuracy and latency on both splits, training split first.
# NOTE(review): convert() is only called in a later cell, so at this point
# `model` is the prepared (not yet converted) model - confirm that the
# "Quantized model" wording printed by evaluate() is what is intended here.
for split_loader, split_label in (
    (train_dataloader, 'train dataset'),
    (test_dataloader, 'test dataset'),
):
    evaluate(split_loader, split_label)

Quantized model in train dataset accuracy: 11.43%, Average Inference Time: 1181.38 ms.
Quantized model in test dataset accuracy: 10.93%, Average Inference Time: 1155.17 ms.


In [11]:
# Convert the calibrated model to a quantized model (in place)
torch.quantization.convert(model.eval(), inplace=True)

# Save the quantized model
quantized_path = 'quantized_resnet_cifar100.pth'
torch.save(model.state_dict(), quantized_path)

# Report the size of the saved file. The model was quantized, not pruned,
# so the message is labelled accordingly.
model_size = os.path.getsize(quantized_path) / (1024 * 1024)  # Size in MB
print(f"Quantized model size: {model_size:.2f} MB")

Pruned model size: 23.57 MB
