In [1]:

import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os

In [2]:
# Make torch deterministic
_ = torch.manual_seed(0)

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = "cuda"

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:19<00:00, 500641.92it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 94452.86it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:04<00:00, 380528.03it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2233094.45it/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw



In [7]:
# Define the model
class VerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(VerySimpleNet, self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
    
    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

In [8]:
net = VerySimpleNet().to(device)

#### Train the model

In [9]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

In [10]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove("temp_delme.p")

MODEL_FILENAME = 'simplenet_ptq.pt'

if Path(MODEL_FILENAME).exists():
    net.load_state_dict(torch.load(MODEL_FILENAME))
    print('Loaded model form disk')
else:
    train(train_loader, net, epochs=1)
    torch.save(net.state_dict(), MODEL_FILENAME)

Epoch 1: 100%|██████████| 6000/6000 [04:15<00:00, 23.46it/s, loss=0.219]


## Defie the testing loop

In [11]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0
    
    iterations = 0
    
    model.eval()
    
    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data 
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                total += 1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break 
    print(f'Accuracy: {round(correct/total, 3)}')

Print the weights

In [15]:
print('Weights before quantization')
print(net.linear1.weight)
print(net.linear1.weight.dtype)
print(net.linear1.weight.shape)

Weights before quantization
Parameter containing:
tensor([[ 0.0063,  0.0074,  0.0210,  ...,  0.0199,  0.0036,  0.0322],
        [ 0.0327,  0.0265,  0.0057,  ...,  0.0362,  0.0017, -0.0027],
        [ 0.0433,  0.0419,  0.0472,  ...,  0.0078,  0.0528,  0.0139],
        ...,
        [ 0.0480,  0.0226, -0.0004,  ...,  0.0114,  0.0536,  0.0587],
        [ 0.0399,  0.0004, -0.0184,  ..., -0.0054,  0.0094,  0.0184],
        [ 0.0131,  0.0252, -0.0057,  ..., -0.0183,  0.0114,  0.0165]],
       device='cuda:0', requires_grad=True)
torch.float32
torch.Size([100, 784])


In [16]:
print('Size of model before quantization')
print_size_of_model(net)

Size of model before quantization
Size (KB): 361.062


In [17]:
print(f'Accuracy of the model before quantization:')
test(net)

Accuracy of the model before quantization:


Testing: 100%|██████████| 1000/1000 [00:15<00:00, 66.52it/s]

Accuracy: 0.957





## Insert min-max observers in the model

In [20]:
class QuantizedVerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(QuantizedVerySimpleNet, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()
    
    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

In [23]:
net_quantized = QuantizedVerySimpleNet().to(device)
# Copy weights from quantized model
net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()

net_quantized.qconfig = torch.ao.quantization.default_qconfig
net_quantized = torch.ao.quantization.prepare(net_quantized) # Insert observers
net_quantized

QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

#### Calibrate the model using the test set

In [24]:
test(net_quantized)

Testing: 100%|██████████| 1000/1000 [00:06<00:00, 144.40it/s]


Accuracy: 0.957


In [25]:
print(f'Check the statistics of the various layers')
net_quantized

Check the statistics of the various layers


QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-45.925804138183594, max_val=31.942672729492188)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-27.5479736328125, max_val=21.284912109375)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-27.22031021118164, max_val=20.436033248901367)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [26]:
## Quantize the model using statistics
net_quantized = torch.ao.quantization.convert(net_quantized)

In [27]:
print('Stats of various layers')
print(net_quantized)

Stats of various layers
QuantizedVerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256], device='cuda:0'), zero_point=tensor([17], device='cuda:0'), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.6131376028060913, zero_point=75, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.38451090455055237, zero_point=72, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.3752467930316925, zero_point=73, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)


In [28]:
# Print the weights of the model after quantization
print('Weights after quantization')
print(torch.int_repr(net_quantized.linear1.weight()))

Weights after quantization
tensor([[ 1,  2,  5,  ...,  4,  1,  7],
        [ 7,  6,  1,  ...,  8,  0, -1],
        [10,  9, 10,  ...,  2, 12,  3],
        ...,
        [11,  5,  0,  ...,  3, 12, 13],
        [ 9,  0, -4,  ..., -1,  2,  4],
        [ 3,  6, -1,  ..., -4,  3,  4]], device='cuda:0', dtype=torch.int8)


In [32]:
print('Oringial weights:')
print(net.linear1.weight)
print('')
print(f'Dequantized weights')
print(torch.dequantize(net_quantized.linear1.weight()))
print('')

Oringial weights:
Parameter containing:
tensor([[ 0.0063,  0.0074,  0.0210,  ...,  0.0199,  0.0036,  0.0322],
        [ 0.0327,  0.0265,  0.0057,  ...,  0.0362,  0.0017, -0.0027],
        [ 0.0433,  0.0419,  0.0472,  ...,  0.0078,  0.0528,  0.0139],
        ...,
        [ 0.0480,  0.0226, -0.0004,  ...,  0.0114,  0.0536,  0.0587],
        [ 0.0399,  0.0004, -0.0184,  ..., -0.0054,  0.0094,  0.0184],
        [ 0.0131,  0.0252, -0.0057,  ..., -0.0183,  0.0114,  0.0165]],
       device='cuda:0', requires_grad=True)

Dequantized weights
tensor([[ 0.0045,  0.0091,  0.0227,  ...,  0.0182,  0.0045,  0.0318],
        [ 0.0318,  0.0272,  0.0045,  ...,  0.0363,  0.0000, -0.0045],
        [ 0.0454,  0.0409,  0.0454,  ...,  0.0091,  0.0545,  0.0136],
        ...,
        [ 0.0499,  0.0227,  0.0000,  ...,  0.0136,  0.0545,  0.0590],
        [ 0.0409,  0.0000, -0.0182,  ..., -0.0045,  0.0091,  0.0182],
        [ 0.0136,  0.0272, -0.0045,  ..., -0.0182,  0.0136,  0.0182]],
       device='cuda:0')



Print size and accuracy of the quantized model

In [33]:
print('Size of the model after quantization')
print_size_of_model(net_quantized)

Size of the model after quantization
Size (KB): 95.458


In [34]:
print('Testing the model after quantization')
test(net_quantized)

Testing the model after quantization


Testing:   0%|          | 0/1000 [00:00<?, ?it/s]


RuntimeError: Unable to find an engine to execute this computation Quantized Linear Cudnn