In [1]:
import torch
import torch.utils.data
from torch import nn

from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor

In [2]:
# Define a simple CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [3]:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [4]:
model = CNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

state_dict = torch.load("mnist_cnn.pth")
model.load_state_dict(state_dict)
model.cuda()

  state_dict = torch.load("mnist_cnn.pth")


CNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=3136, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
)

In [5]:
from pytorch_quantization import quant_modules
quant_modules.initialize()

In [6]:
quant_desc_input = QuantDescriptor(calib_method='histogram')
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

In [8]:
model_qat = CNN().to(device)

state_dict = torch.load("mnist_cnn.pth")
model_qat.load_state_dict(state_dict)
model_qat.cuda()

  state_dict = torch.load("mnist_cnn.pth")


CNN(
  (conv1): QuantConv2d(
    1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
  (conv2): QuantConv2d(
    32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): QuantLinear(
    in_features=3136, out_features=128, bias=True
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dyna

In [12]:
from tqdm import tqdm
def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect statistic"""

    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):
        model(image.cuda())
        if i >= num_batches:
            break

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()

def compute_amax(model, **kwargs):
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
            print(F"{name:40}: {module}")
    model.cuda()

# It is a bit slow since we collect histograms on CPU
with torch.no_grad():
    collect_stats(model_qat, train_loader, num_batches=2)
    compute_amax(model_qat, method="percentile", percentile=99.99)

  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 2/2 [00:06<00:00,  3.28s/it]
W0203 06:16:29.298797 139736912962816 tensor_quantizer.py:174] Disable HistogramCalibrator
W0203 06:16:29.299856 139736912962816 tensor_quantizer.py:174] Disable MaxCalibrator
W0203 06:16:29.300999 139736912962816 tensor_quantizer.py:174] Disable HistogramCalibrator
W0203 06:16:29.301458 139736912962816 tensor_quantizer.py:174] Disable MaxCalibrator
W0203 06:16:29.302242 139736912962816 tensor_quantizer.py:174] Disable HistogramCalibrator
W0203 06:16:29.302683 139736912962816 tensor_quantizer.py:174] Disable MaxCalibrator
W0203 06:16:29.303240 139736912962816 tensor_quantizer.py:174] Disable HistogramCalibrator
W0203 06:16:29.303801 139736912962816 tensor_quantizer.py:174] Disable MaxCalibrator
W0203 06:16:29.317455 139736912962816 tensor_quantizer.py:238] Load calibrated amax, shape=torch.Size([]).
W0203 06:16:29.318041 139736912962816 tensor_quantizer.py:239] Call .cuda() if running on GPU after loading calibrated amax.
W0203 06:16:29.318

conv1._input_quantizer                  : TensorQuantizer(8bit fake per-tensor amax=2.8201 calibrator=HistogramCalibrator scale=1.0 quant)
conv1._weight_quantizer                 : TensorQuantizer(8bit fake axis=0 amax=[0.3197, 0.5787](32) calibrator=MaxCalibrator scale=1.0 quant)
conv2._input_quantizer                  : TensorQuantizer(8bit fake per-tensor amax=4.4820 calibrator=HistogramCalibrator scale=1.0 quant)
conv2._weight_quantizer                 : TensorQuantizer(8bit fake axis=0 amax=[0.1594, 0.4042](64) calibrator=MaxCalibrator scale=1.0 quant)
fc1._input_quantizer                    : TensorQuantizer(8bit fake per-tensor amax=8.7409 calibrator=HistogramCalibrator scale=1.0 quant)
fc1._weight_quantizer                   : TensorQuantizer(8bit fake axis=0 amax=[0.0230, 0.2383](128) calibrator=MaxCalibrator scale=1.0 quant)
fc2._input_quantizer                    : TensorQuantizer(8bit fake per-tensor amax=44.7653 calibrator=HistogramCalibrator scale=1.0 quant)
fc2._weight_q

In [13]:
# Evaluation function
def evaluate(model, device, test_loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            total_loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(1) == labels).sum().item()

    print(f"Test Loss: {total_loss / len(test_loader):.4f}, Accuracy: {correct / len(test_loader.dataset) * 100:.2f}%")

with torch.no_grad():
    evaluate(model_qat, device, test_loader, criterion)

# Save the model
torch.save(model_qat.state_dict(), "mnist_quant-calibrated.pth")

Test Loss: 0.0229, Accuracy: 99.33%


In [14]:
import pytorch_quantization
dummy_input = torch.randn(1, 1, 28, 28, device='cuda')

input_names = [ "actual_input_1" ]
output_names = [ "output1" ]

with pytorch_quantization.enable_onnx_export():
     # enable_onnx_checker needs to be disabled. See notes below.
     torch.onnx.export(
         model_qat, dummy_input, "mnist_quant.onnx", verbose=True, opset_version=10, enable_onnx_checker=False, input_names = [ "actual_input_1" ], output_names = [ "output1" ]
         )

AttributeError: module 'pytorch_quantization' has no attribute 'enable_onnx_export'