## Links:

Introductory video:
https://www.youtube.com/watch?v=IPQmGzYuxmc

Introductory documentation:
https://pytorch.org/docs/stable/quantization.html

In PyTorch 1.3 where code was run it is said to be an EXPERIMENTAL feature.

Links below were helpful in understanding the details
* https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
* https://pytorch.org/docs/master/quantization.html#quantized-torch-tensor-operations
* https://github.com/pytorch/pytorch/blob/master/torch/quantization/quantize.py

In [1]:
import numpy as np
import os
import time
import torch
import torch.quantization as quantization
import torch.nn as nn

from mnist import MNIST
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm_notebook as tqdm

In [2]:
torch.__version__

'1.3.0.post2'

# Prepare dataset

In [3]:
class MnistDataset(Dataset):
    def __init__(self, mnist_data_path):
        self.mndata = MNIST(mnist_data_path)
        self.images, self.labels = self.mndata.load_training()
        
        self.images = np.array(self.images).reshape(-1, 1, 28, 28)/255
        self.labels = np.array(self.labels)
        
    def __len__(self):
         return self.images.shape[0]

    def __getitem__(self, index):
        return torch.tensor(self.images[index], dtype=torch.float), int(self.labels[index])

In [4]:
data = MnistDataset('./python-mnist/data/')

# Prepare model

In [5]:
class MyModel(nn.Module):
    def __init__(self, num_classes=10):
        super(MyModel, self).__init__()
        self.qconfig = quantization.default_qconfig
        self.conv1 = nn.Conv2d( 1,  8, kernel_size=3, stride=2, padding=2, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
        
        self.conv2 = nn.Conv2d( 8, 16, kernel_size=2, stride=1, padding=2, bias=False)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
            
        self.conv3 = nn.Conv2d(16, 32, kernel_size=2, stride=1, padding=2, bias=False)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, num_classes)
        
    def forward(self, x):
        x = self.maxpool1(self.relu1(self.conv1(x)))
        x = self.maxpool2(self.relu2(self.conv2(x)))
        x = self.maxpool3(self.relu3(self.conv3(x)))
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

model = MyModel()

# Train model (dummy, not caring about super results here)

In [6]:
optimizer = SGD(model.parameters(), lr=0.1, momentum=0, dampening=0, weight_decay=0, nesterov=False)
criterion = CrossEntropyLoss()
scheduler = StepLR(optimizer, 10, 0.1)
epochs = 30
batch_size = 16

dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)

if os.path.isfile('orig.pth'):
    model.load_state_dict(torch.load('orig.pth'))
else:
    model.train()
    for e in range(epochs):
        losses = []
        tloader = tqdm(dataloader, leave=False)
        number = 0
        hits = 0

        for img, label in tloader:
            optimizer.zero_grad()  
            output = model(img)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()

            losses.append(float(loss))
            number += output.shape[0]
            output = output.detach().numpy()
            label = label.detach().numpy()
            hits += np.sum( np.argmax(output, axis=1) == label )
            tloader.desc = f"E{e+1}/{epochs} {np.mean(losses):0.5f} ACC: {(hits/number):0.5f}"    
    scheduler.step()

HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))



# Quantization

In [7]:
dummy = torch.zeros(1,1,28,28, dtype=torch.float)
test_iterations = 10000

### test time & size before quantization

In [8]:
model.eval()

time_start = time.time()
for i in range(test_iterations):
    result = model.forward(dummy)
time_stop = time.time()
orig_time = int(1000*(time_stop-time_start))

print(f"Result shape: {result.shape} Test time: {orig_time}ms",)

# Check acc
tloader = tqdm(dataloader)
number = 0
hits = 0
for img, label in tloader: 
    output = model(img)
    number += output.shape[0]
    output = output.detach().numpy()
    label = label.detach().numpy()
    hits += np.sum( np.argmax(output, axis=1) == label )
print(f"Accuracy: {(hits/number):0.5f}")

torch.save(model.state_dict(),"orig.pth")
print(f"Model file size: {os.path.getsize('orig.pth')}")

Result shape: torch.Size([1, 10]) Test time: 4908ms


HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))


Accuracy: 0.98225
Model file size: 12992


### perform quantization

In [9]:
%%time

pmodel = quantization.fuse_modules(model, [['conv1','relu1'],['conv2','relu2'],['conv3','relu3']])

# Prepare for stats collection
qmodel = quantization.prepare(pmodel,{"":quantization.default_qconfig})

# Gather stats
tloader = tqdm(dataloader)
for img, _ in tloader: 
    output = qmodel(img)

## Convert model
qmodel = quantization.convert(qmodel)

# first layer scale
print(qmodel.conv1)
print(qmodel.fc)

HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))


QuantizedConvReLU2d(1, 8, kernel_size=(3, 3), stride=(2, 2), scale=0.03981924429535866, zero_point=0, padding=(2, 2), bias=False)
QuantizedLinear(in_features=32, out_features=10, scale=0.4616350829601288, zero_point=63)
CPU times: user 7.88 s, sys: 123 ms, total: 8 s
Wall time: 8.62 s


### test time & size after quantization

In [10]:
qmodel.eval()
qdummy = torch.quantize_per_tensor(dummy, scale=1, zero_point=0, dtype=torch.quint8) 

time_start = time.time()
for i in range(test_iterations):
    result = qmodel.forward(qdummy)
time_stop = time.time()
q_time = int(1000*(time_stop-time_start))

print(f"Result shape: {result.shape} Test time: {q_time}ms",)

# Check acc
tloader = tqdm(dataloader)
number = 0
hits = 0
for img, label in tloader: 
    img = torch.quantize_per_tensor(img, scale=1, zero_point=0, dtype=torch.quint8) 
    output = qmodel(img)
    number += output.shape[0]
    output = torch.dequantize(output).detach().numpy()
    label = label.detach().numpy()
    hits += np.sum( np.argmax(output, axis=1) == label )
print(f"Accuracy: {(hits/number):0.5f}")

torch.save(qmodel.state_dict(),"q.pth")
print(f"QModel file size: {os.path.getsize('q.pth')}")

Result shape: torch.Size([1, 10]) Test time: 2447ms


HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))


Accuracy: 0.95723
QModel file size: 5243
