In [2]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.quantization
import torchvision
import torchvision.transforms as transforms
from torch.quantization import QuantStub, DeQuantStub
from torch.quantization.observer import MovingAverageMinMaxObserver
from torch.utils.data import DataLoader
from torchvision import transforms

In [3]:
# Normalize MNIST pixels from [0, 1] (ToTensor) to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# Shared DataLoader settings. Cap the worker count at what the machine reports
# (PyTorch warns when num_workers exceeds it) and only pin host memory when a
# CUDA device exists to copy batches to.
NUM_WORKERS = min(16, os.cpu_count() or 1)
loader_kwargs = dict(batch_size=64, num_workers=NUM_WORKERS,
                     pin_memory=torch.cuda.is_available())

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = DataLoader(trainset, shuffle=True, **loader_kwargs)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform)
testloader = DataLoader(testset, shuffle=False, **loader_kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz


Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 14.2MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 464kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 3.26MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 18.6MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [4]:
class AverageMeter(object):
    """Track the most recent value and the running sample-weighted mean.

    Args:
        name: label printed in front of the numbers by ``__str__``.
        fmt: format spec (including the leading ``:``) applied to both the
            latest value and the running average, e.g. ``':f'``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, running average, running sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Builds e.g. '{name} {val:f} ({avg:f})' and fills it from the fields.
        template = f'{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})'
        return template.format(**vars(self))

def accuracy(output, target):
    """Return the top-1 accuracy of one batch, as a percentage.

    Args:
        output: per-class scores with the class dimension at dim 1
            (presumably ``(batch, num_classes)`` logits).
        target: integer class labels, one per row of ``output``.

    Returns:
        Python float: 100 * (number of rows whose highest-scoring class equals
        the label) / batch size.
    """
    with torch.no_grad():
        batch_size = target.size(0)
        # argmax over the class dim is the top-1 prediction (first index wins
        # on ties, like torch.max).
        pred = output.argmax(dim=1)
        correct = (pred == target.view_as(pred)).sum().item()
        return correct * 100.0 / batch_size

def print_size_of_model(model):
    """Print the serialized size of ``model.state_dict()`` in megabytes.

    The state dict is serialized with ``torch.save`` into an in-memory buffer,
    so no scratch file is created in (or left behind in) the working directory,
    and an existing file with the same name can never be clobbered.
    """
    import io  # local import keeps this helper self-contained

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    print('Size (MB):', len(buffer.getvalue()) / 1e6)

def load_model(quantized_model, model):
    """Copy the trained weights of ``model`` into ``quantized_model``.

    Args:
        quantized_model: freshly constructed module meant for quantization. It
            must have the same parameter/buffer names as ``model`` because
            ``load_state_dict`` is strict by default.
        model: the trained float model. It is left untouched; in particular it
            is not moved to another device.

    ``load_state_dict`` copies each tensor into the destination's existing
    storage, so a CUDA-resident source can be loaded into a CPU model directly.
    """
    # Deliberately no ``model.to(...)`` here: nn.Module.to moves a module in
    # place, which would silently relocate the caller's float model.
    quantized_model.load_state_dict(model.state_dict())

def fuse_modules(model):
    """Fuse each conv/linear layer of the model with the ReLU that follows it.

    Works in place on ``model`` through ``torch.quantization.fuse_modules``.
    Afterwards ``conv1``/``conv2``/``fc1``/``fc2`` hold the fused
    layer+activation modules. Returns None.
    """
    layer_relu_pairs = [
        ['conv1', 'relu1'],
        ['conv2', 'relu2'],
        ['fc1', 'relu3'],
        ['fc2', 'relu4'],
    ]
    torch.quantization.fuse_modules(model, layer_relu_pairs, inplace=True)

In [5]:
class Net(nn.Module):
    """LeNet-style CNN for 1x28x28 MNIST digits, optionally quantization-ready.

    Args:
        q: when True, a ``QuantStub``/``DeQuantStub`` pair wraps the forward
            pass so eager-mode quantization knows where tensors enter and leave
            the quantized region.

    The attribute names (``conv1``, ``relu1``, ..., ``fc3``) are load-bearing:
    they form the state_dict keys and are the names ``fuse_modules`` fuses.
    """

    def __init__(self, q=False):
        super(Net, self).__init__()
        # Features: 1x28x28 -> conv5 -> 6x24x24 -> pool -> 6x12x12
        #           -> conv5 -> 16x8x8 -> pool -> 16x4x4
        self.conv1 = nn.Conv2d(1, 6, 5, bias=False)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5, bias=False)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)
        # Classifier: 16 * 4 * 4 = 256 flattened features -> 10 logits.
        self.fc1 = nn.Linear(256, 120, bias=False)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84, bias=False)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10, bias=False)
        self.q = q
        if q:
            self.quant = QuantStub()
            self.dequant = DeQuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.q:
            x = self.quant(x)
        # Attributes are looked up on every call, so modules swapped in later
        # by fusion / prepare / convert are picked up automatically.
        for layer in (self.conv1, self.relu1, self.pool1,
                      self.conv2, self.relu2, self.pool2):
            x = layer(x)
        # Use reshape rather than view: reshape tolerates inputs whose strides
        # view would reject.
        x = x.reshape(x.shape[0], -1)
        for layer in (self.fc1, self.relu3, self.fc2, self.relu4, self.fc3):
            x = layer(x)
        if self.q:
            x = self.dequant(x)
        return x

In [6]:
# Seed the default torch RNG so weight initialisation is repeatable on
# Restart & Run All. The shuffling DataLoader also draws its per-epoch seed
# from this generator, so batch order becomes repeatable as well.
torch.manual_seed(0)
net = Net(q=False).cuda()
print_size_of_model(net)

Size (MB): 0.179057


In [7]:
def train(model: nn.Module, dataloader: DataLoader, cuda=False, q=False,
          epochs: int = 20, lr: float = 0.001, momentum: float = 0.9,
          freeze_observer_epoch: int = 3):
    """Train ``model`` with SGD and cross-entropy, printing progress every 100 batches.

    Args:
        model: module to train (updated in place). Only the batches are moved
            to the GPU here, so with ``cuda=True`` the model must already be
            on the GPU.
        dataloader: yields ``(inputs, labels)`` batches.
        cuda: move each batch to the current CUDA device before the forward pass.
        q: quantization-aware training mode. From 0-based epoch
            ``freeze_observer_epoch`` onward the observers are disabled, so the
            quantization ranges stop changing while the weights keep training.
        epochs: number of passes over ``dataloader``.
        lr: SGD learning rate.
        momentum: SGD momentum.
        freeze_observer_epoch: 0-based epoch at which observers are frozen
            (only used when ``q`` is True).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    model.train()
    for epoch in range(epochs):
        # Disabling observers is a model-wide toggle. Do it once before the
        # epoch's first batch instead of re-walking every submodule per batch.
        if q and epoch >= freeze_observer_epoch:
            model.apply(torch.quantization.disable_observer)

        running_loss = AverageMeter('loss')
        acc = AverageMeter('train_acc')
        for i, (inputs, labels) in enumerate(dataloader):
            if cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()

            # zero grads -> forward -> backward -> step
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Statistics are weighted by the actual batch size (the last batch
            # may be smaller).
            batch = outputs.shape[0]
            running_loss.update(loss.item(), batch)
            acc.update(accuracy(outputs, labels), batch)
            if i % 100 == 0:  # print every 100 mini-batches
                print('[%d, %5d] ' % (epoch + 1, i + 1), running_loss, acc)
    print('Finished Training')


def test(model: nn.Module, dataloader: DataLoader, cuda=False) -> float:
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for data in dataloader:
            inputs, labels = data

            if cuda:
              inputs = inputs.cuda()
              labels = labels.cuda()

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    return 100 * correct / total

In [8]:
# Train the FP32 baseline on the GPU (q is left False: no observers involved).
train(model=net, dataloader=trainloader, cuda=True)

[1,     1]  loss 2.301271 (2.301271) train_acc 4.687500 (4.687500)
[1,   101]  loss 2.295300 (2.300438) train_acc 15.625000 (8.029084)
[1,   201]  loss 2.292073 (2.298643) train_acc 21.875000 (10.820896)
[1,   301]  loss 2.290231 (2.296418) train_acc 18.750000 (14.477782)
[1,   401]  loss 2.276408 (2.293633) train_acc 28.125000 (17.471945)
[1,   501]  loss 2.271868 (2.290192) train_acc 34.375000 (19.732410)
[1,   601]  loss 2.249302 (2.285249) train_acc 37.500000 (22.665349)
[1,   701]  loss 2.181720 (2.276491) train_acc 50.000000 (26.279422)
[1,   801]  loss 1.985273 (2.257389) train_acc 71.875000 (29.993758)
[1,   901]  loss 1.470307 (2.203546) train_acc 62.500000 (33.226970)
[2,     1]  loss 1.074759 (1.074759) train_acc 73.437500 (73.437500)
[2,   101]  loss 0.789274 (0.919006) train_acc 67.187500 (73.623144)
[2,   201]  loss 0.682187 (0.777460) train_acc 82.812500 (76.951182)
[2,   301]  loss 0.327191 (0.694045) train_acc 89.062500 (79.074958)
[2,   401]  loss 0.458517 (0.630639) 

In [9]:
# Held-out accuracy of the trained FP32 baseline.
score = test(model=net, dataloader=testloader, cuda=True)
print('Accuracy of the network on the test images: {}% - FP32'.format(score))

Accuracy of the network on the test images: 98.74% - FP32


## Post-training quantization

In [10]:
# Quantization-ready copy of the trained FP32 weights, with each conv/linear
# layer fused with its ReLU before prepare() is called.
qnet = Net(q=True)
load_model(quantized_model=qnet, model=net)
fuse_modules(model=qnet)

In [11]:
# Fusion alone quantizes nothing: size and accuracy should track the FP32 model.
print_size_of_model(model=qnet)
score = test(model=qnet, dataloader=testloader, cuda=False)
print('Accuracy of the fused network on the test images: {}% - FP32'.format(score))

Size (MB): 0.179249
Accuracy of the fused network on the test images: 98.74% - FP32


In [12]:
# Static post-training quantization with the default (MinMax observer) qconfig.
qnet.qconfig = torch.quantization.default_qconfig
print(qnet.qconfig)

# 1) prepare: attach observers that will record activation ranges.
torch.quantization.prepare(qnet, inplace=True)
print('Post Training Quantization Prepare: Inserting Observers')
print('\n Conv1: After observer insertion \n\n', qnet.conv1)

# 2) calibrate: float inference over the training set only to feed the
#    observers; the accuracy returned by test() is not needed.
test(model=qnet, dataloader=trainloader, cuda=False)
print('Post Training Quantization: Calibration done')

# 3) convert: swap observed modules for their quantized counterparts.
torch.quantization.convert(qnet, inplace=True)
print('Post Training Quantization: Convert done')
print('\n Conv1: After fusion and quantization \n\n', qnet.conv1)
print("Size of model after quantization")
print_size_of_model(model=qnet)

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})
Post Training Quantization Prepare: Inserting Observers

 Conv1: After observer insertion 

 ConvReLU2d(
  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (1): ReLU()
  (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
)
Post Training Quantization: Calibration done
Post Training Quantization: Convert done

 Conv1: After fusion and quantization 

 QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.052600838243961334, zero_point=0, bias=False)
Size of model after quantization
Size (MB): 0.050084


In [13]:
# Held-out accuracy of the INT8 model produced with the default qconfig.
score = test(model=qnet, dataloader=testloader, cuda=False)
print('Accuracy of the fused and quantized network on the test images: {}% - INT8'.format(score))

Accuracy of the fused and quantized network on the test images: 98.76% - INT8


In [14]:
# Same PTQ pipeline, but with moving-average min/max observers and a reduced
# activation range. MovingAverageMinMaxObserver is imported in the top import cell.
qnet = Net(q=True)
load_model(qnet, net)
fuse_modules(qnet)

qnet.qconfig = torch.quantization.QConfig(
    activation=MovingAverageMinMaxObserver.with_args(reduce_range=True),
    weight=MovingAverageMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
print(qnet.qconfig)
torch.quantization.prepare(qnet, inplace=True)
print('Post Training Quantization Prepare: Inserting Observers')
print('\n Conv1: After observer insertion \n\n', qnet.conv1)

# Calibration pass: only the observer updates matter, not the returned accuracy.
test(qnet, trainloader, cuda=False)
print('Post Training Quantization: Calibration done')
torch.quantization.convert(qnet, inplace=True)
print('Post Training Quantization: Convert done')
print('\n Conv1: After fusion and quantization \n\n', qnet.conv1)
print("Size of model after quantization")
print_size_of_model(qnet)
score = test(qnet, testloader, cuda=False)
print('Accuracy of the fused and quantized network on the test images: {}% - INT8'.format(score))

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})
Post Training Quantization Prepare: Inserting Observers

 Conv1: After observer insertion 

 ConvReLU2d(
  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (1): ReLU()
  (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
)




Post Training Quantization: Calibration done
Post Training Quantization: Convert done

 Conv1: After fusion and quantization 

 QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.049229469150304794, zero_point=0, bias=False)
Size of model after quantization
Size (MB): 0.050084
Accuracy of the fused and quantized network on the test images: 98.78% - INT8


In [15]:
# Rebuild a fused, quantization-ready copy of the FP32 weights for the fbgemm run.
qnet = Net(q=True)
load_model(quantized_model=qnet, model=net)
fuse_modules(model=qnet)

In [16]:
# PTQ with the fbgemm (x86) default qconfig: histogram observer for
# activations, per-channel symmetric int8 weights (see the printed QConfig).
qnet.qconfig = torch.quantization.get_default_qconfig('fbgemm')
print(qnet.qconfig)

torch.quantization.prepare(qnet, inplace=True)
test(model=qnet, dataloader=trainloader, cuda=False)  # calibration pass
torch.quantization.convert(qnet, inplace=True)
print("Size of model after quantization")
print_size_of_model(model=qnet)

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})
Size of model after quantization
Size (MB): 0.05572


In [17]:
# Held-out accuracy of the fbgemm-qconfig INT8 model.
score = test(model=qnet, dataloader=testloader, cuda=False)
print('Accuracy of the fused and quantized network on the test images: {}% - INT8'.format(score))

Accuracy of the fused and quantized network on the test images: 98.7% - INT8


## Quantization aware training

In [18]:
# Quantization-aware training starts from a freshly initialised Net; the
# trained FP32 weights in `net` are not loaded here.
qnet = Net(q=True)
fuse_modules(qnet)
qnet.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
# prepare_qat swaps the fused modules for QAT versions with fake-quantize
# modules on weights and activations; nothing is converted to int8 yet.
torch.quantization.prepare_qat(qnet, inplace=True)
print('\n Conv1: After fusion and QAT preparation (fake-quant inserted) \n\n', qnet.conv1)
qnet = qnet.cuda()


 Conv1: After fusion and quantization 

 ConvReLU2d(
  1, 6, kernel_size=(5, 5), stride=(1, 1), bias=False
  (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
    fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False
    (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process): FusedMovingAvgObsFakeQuantize(
    fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
    (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  )
)


In [19]:
# QAT run with observers left enabled for every epoch (q defaults to False).
train(model=qnet, dataloader=trainloader, cuda=True)

[1,     1]  loss 2.304098 (2.304098) train_acc 14.062500 (14.062500)
[1,   101]  loss 2.299604 (2.301910) train_acc 14.062500 (11.834777)
[1,   201]  loss 2.296406 (2.299836) train_acc 17.187500 (13.456157)
[1,   301]  loss 2.291352 (2.297608) train_acc 20.312500 (15.484842)
[1,   401]  loss 2.288203 (2.294956) train_acc 25.000000 (17.740804)
[1,   501]  loss 2.280631 (2.291599) train_acc 18.750000 (20.602545)
[1,   601]  loss 2.257500 (2.286950) train_acc 46.875000 (24.178453)
[1,   701]  loss 2.221570 (2.279951) train_acc 57.812500 (28.203014)
[1,   801]  loss 2.159049 (2.268833) train_acc 57.812500 (32.020521)
[1,   901]  loss 2.048853 (2.251091) train_acc 59.375000 (35.328801)
[2,     1]  loss 2.010422 (2.010422) train_acc 54.687500 (54.687500)
[2,   101]  loss 1.732704 (1.872161) train_acc 64.062500 (67.651609)
[2,   201]  loss 1.267590 (1.712890) train_acc 79.687500 (69.986007)
[2,   301]  loss 1.061672 (1.520588) train_acc 68.750000 (72.030731)
[2,   401]  loss 0.729978 (1.34408

In [20]:
# Move the QAT model back to the CPU (evaluation below runs with cuda=False),
# replace fake-quant modules with real quantized ones, then measure it.
qnet = qnet.cpu()
torch.quantization.convert(qnet, inplace=True)
print("Size of model after quantization")
print_size_of_model(model=qnet)

score = test(model=qnet, dataloader=testloader, cuda=False)
print('Accuracy of the fused and quantized network (trained quantized) on the test images: {}% - INT8'.format(score))
     

Size of model after quantization
Size (MB): 0.05572
Accuracy of the fused and quantized network (trained quantized) on the test images: 98.64% - INT8


In [21]:
# Second QAT run: identical setup, but train() is called with q=True so the
# observers are frozen after the first few epochs.
qnet = Net(q=True)
fuse_modules(model=qnet)
qnet.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(qnet, inplace=True)
qnet = qnet.cuda()
train(model=qnet, dataloader=trainloader, cuda=True, q=True)

# Convert on the CPU and evaluate the resulting INT8 model.
qnet = qnet.cpu()
torch.quantization.convert(qnet, inplace=True)
print("Size of model after quantization")
print_size_of_model(model=qnet)

score = test(model=qnet, dataloader=testloader, cuda=False)
print('Accuracy of the fused and quantized network (trained quantized) on the test images: {}% - INT8'.format(score))
     

[1,     1]  loss 2.302431 (2.302431) train_acc 9.375000 (9.375000)
[1,   101]  loss 2.292588 (2.298551) train_acc 15.625000 (10.519802)
[1,   201]  loss 2.281883 (2.293629) train_acc 21.875000 (14.886505)
[1,   301]  loss 2.267573 (2.287043) train_acc 42.187500 (21.496055)
[1,   401]  loss 2.238542 (2.278980) train_acc 50.000000 (27.579489)
[1,   501]  loss 2.205835 (2.267949) train_acc 65.625000 (33.348927)
[1,   601]  loss 2.115737 (2.250990) train_acc 73.437500 (38.971506)
[1,   701]  loss 1.959956 (2.223528) train_acc 81.250000 (43.504815)
[1,   801]  loss 1.711493 (2.176959) train_acc 81.250000 (47.635768)
[1,   901]  loss 1.398292 (2.103770) train_acc 71.875000 (51.068257)
[2,     1]  loss 1.058854 (1.058854) train_acc 81.250000 (81.250000)
[2,   101]  loss 0.752661 (0.913124) train_acc 85.937500 (82.549505)
[2,   201]  loss 0.370132 (0.764782) train_acc 90.625000 (83.426617)
[2,   301]  loss 0.411058 (0.677074) train_acc 93.750000 (84.146595)
[2,   401]  loss 0.222220 (0.615775)

In [None]:
# Export the converted INT8 model to ONNX. Export runs the model on a dummy
# batch, so set eval mode explicitly instead of relying on whatever state the
# last test() call left behind.
# NOTE(review): this cell has not been executed (In [None]); exporter support
# for eager-mode quantized modules is limited, so confirm the export succeeds.
qnet.eval()
dummy_input = torch.randn(1, 1, 28, 28)  # one MNIST-shaped 1x28x28 image
print(qnet)
torch.onnx.export(
    qnet,
    dummy_input,
    "qat_model_mnist.onnx",
    opset_version=13,  # per-axis QuantizeLinear/DequantizeLinear need opset >= 13
    input_names=["input"],
    output_names=["output"],
)