In [5]:
import os
import tempfile
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm

In [6]:
# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible across kernel restarts. This only seeds the RNG; it does not force
# deterministic kernels.
SEED = 0
_ = torch.manual_seed(SEED)

In [14]:
# Commonly used MNIST training-set mean / std; normalising with them gives roughly
# zero-mean, unit-variance pixel values.
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
DATA_ROOT = './data'
BATCH_SIZE = 10

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MNIST_MEAN, MNIST_STD)])

# Load the MNIST training set; shuffle it every epoch for SGD.
mnist_trainset = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=BATCH_SIZE, shuffle=True)

# Load the MNIST test set. Evaluation does not shuffle, so `test(..., total_iterations=k)`
# always sees the same first k batches. That keeps accuracies of different models
# (e.g. float vs. quantized) comparable on identical samples.
mnist_testset = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=BATCH_SIZE, shuffle=False)

# Eager-mode quantized operators target CPU backends (fbgemm/qnnpack), so the whole
# notebook stays on CPU.
device = "cpu"

In [15]:
class simpleModel(nn.Module):
    """Float32 MLP classifier for flattened 28x28 MNIST images.

    Architecture: 784 -> hidden_size_1 -> hidden_size_2 -> 10. A ReLU follows each of
    the two hidden layers, and the output layer returns raw logits.
    """

    def __init__(self, hidden_size_1 = 100, hidden_size_2 = 100):
        super().__init__()
        in_features = 28 * 28
        # Keep this creation order: seeded initialisation then matches, and the
        # state_dict keys (linear1/2/3) line up with QuantizedSimpleModel.
        self.linear1 = nn.Linear(in_features, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten ``img`` to rows of 784 features and return one row of 10 logits per sample."""
        hidden = img.view(-1, 28 * 28)
        for hidden_layer in (self.linear1, self.linear2):
            hidden = self.relu(hidden_layer(hidden))
        return self.linear3(hidden)

In [16]:
# Float32 baseline with explicit (default) hidden sizes, placed on `device`.
net = simpleModel(hidden_size_1=100, hidden_size_2=100).to(device)

In [18]:
def train(train_loader, net, epochs = 5, total_iterations_limit = None):
    """Train ``net`` in place with Adam (lr=1e-3) and cross-entropy loss.

    Args:
        train_loader: iterable of ``(x, y)`` batches; both are moved to the global ``device``.
        net: model to train; it is put into train mode at the start of every epoch.
        epochs: number of passes over ``train_loader``.
        total_iterations_limit: if not None, stop after this many optimiser steps in
            total (counted across all epochs).

    The running average loss of the current epoch is shown in the tqdm progress bar.
    """
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr = 0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0.0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc = f'Epoch {epoch + 1}')

        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit

        for x, y in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss.backward()
            optimizer.step()

            # Accumulate a Python float, not the loss tensor. Summing tensors keeps each
            # iteration's autograd nodes alive through `loss_sum` for the whole epoch,
            # and makes the bar print `tensor(..., grad_fn=...)`.
            loss_sum += loss.item()
            data_iterator.set_postfix(loss = loss_sum / num_iterations)

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                # Finalise the progress bar before leaving early.
                data_iterator.close()
                return
            
def print_size_of_model(model):
    """Print the serialized size of ``model.state_dict()`` in kilobytes (1 KB = 1000 bytes).

    The state dict is saved with ``torch.save`` to a throw-away file inside a fresh
    temporary directory. The directory is removed afterwards even if saving fails, so
    nothing in the working directory is overwritten or left behind.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the historical basename. torch.save's zip format may embed a name derived
        # from the file name, so this keeps the reported size identical to before.
        path = os.path.join(tmp_dir, "temp_delme.p")
        torch.save(model.state_dict(), path)
        size_kb = os.path.getsize(path) / 1e3
    print('Size (KB):', size_kb)

MODEL_FILENAME = 'simpleModel.pt'

# Cache the trained float model so re-running the notebook skips training.
if Path(MODEL_FILENAME).exists():
    # map_location keeps loading working even if the checkpoint was saved on another
    # device; weights_only=True refuses arbitrary pickled objects (only tensors are needed).
    net.load_state_dict(torch.load(MODEL_FILENAME, map_location=device, weights_only=True))
    print('Loaded model from disk')
else:
    train(train_loader, net, epochs=1)
    # Save the model to disk
    torch.save(net.state_dict(), MODEL_FILENAME)

Epoch 1: 100%|██████████| 6000/6000 [00:23<00:00, 259.57it/s, loss=tensor(0.2224, grad_fn=<DivBackward0>)]


In [19]:
def test(model: nn.Module, total_iterations: Optional[int] = None, loader=None):
    """Print the top-1 accuracy of ``model``, rounded to 3 decimals.

    Args:
        model: classifier returning one row of class scores per sample; it is put into
            eval mode and run under ``torch.no_grad()``.
        total_iterations: if not None, evaluate only the first this-many batches.
        loader: iterable of ``(x, y)`` batches. Defaults to the notebook's global
            ``test_loader``, so existing calls are unchanged.

    Running a prepared (observer-equipped) model through this function also feeds its
    observers, which is how the notebook calibrates ``net_quantized``.
    """
    if loader is None:
        loader = test_loader
    correct = 0
    total = 0

    model.eval()
    with torch.no_grad():
        for iterations, (x, y) in enumerate(tqdm(loader, desc = "Testing"), start=1):
            x, y = x.to(device), y.to(device)
            output = model(x.view(-1, 28*28))
            # Whole-batch top-1 comparison. On ties, argmax returns the first maximal index.
            correct += (output.argmax(dim=1) == y).sum().item()
            total += y.size(0)
            if total_iterations is not None and iterations >= total_iterations:
                break
    if total == 0:
        # Avoid ZeroDivisionError on an empty loader.
        print('Accuracy: n/a (no samples evaluated)')
        return
    print(f'Accuracy: {round(correct/total, 3)}')

In [20]:
# Show the first layer's float weights, for comparison with the int8 values after quantization.
print('Weights before quantization')
linear1_weight = net.linear1.weight
print(linear1_weight)
print(linear1_weight.dtype)

Weights before quantization
Parameter containing:
tensor([[ 0.0319,  0.0403,  0.0439,  ...,  0.0155,  0.0443,  0.0091],
        [ 0.0340,  0.0353,  0.0034,  ...,  0.0424,  0.0112, -0.0212],
        [-0.0038, -0.0030,  0.0518,  ...,  0.0362,  0.0360,  0.0053],
        ...,
        [ 0.0590,  0.0271,  0.0055,  ...,  0.0082,  0.0328,  0.0242],
        [ 0.0397,  0.0019,  0.0072,  ..., -0.0057, -0.0257, -0.0216],
        [ 0.0643,  0.0465,  0.0525,  ...,  0.0143,  0.0785,  0.0345]],
       requires_grad=True)
torch.float32


In [21]:
# Serialized size of the float32 model: the baseline for the quantized model later on.
size_label = 'Size of the model before quantization'
print(size_label)
print_size_of_model(net)

Size of the model before quantization
Size (KB): 361.401


In [22]:
# Float32 accuracy on the full test set (reference for the quantized model).
print('Accuracy of the model before quantization: ')
test(net)

Accuracy of the model before quantization: 


Testing: 100%|██████████| 1000/1000 [00:01<00:00, 705.46it/s]

Accuracy: 0.948





Insert min-max observers into the model

In [23]:
class QuantizedSimpleModel(nn.Module):
    """``simpleModel`` wrapped in quantization stubs for eager-mode static quantization.

    Layer names and shapes match ``simpleModel``, so its ``state_dict`` loads directly.
    ``quant`` / ``dequant`` mark where tensors enter and leave the quantized region.
    ``prepare`` attaches observers to them, and ``convert`` replaces them (and the
    Linear layers) with quantized modules.
    """

    def __init__(self, hidden_size_1 = 100, hidden_size_2 = 100):
        super(QuantizedSimpleModel, self).__init__()
        # Use torch.ao.quantization, consistent with prepare/convert/default_qconfig
        # elsewhere in this notebook; torch.quantization is the legacy alias.
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(28 * 28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, img):
        """Flatten ``img`` to rows of 784 features and return float logits (10 per sample)."""
        x = img.view(-1, 28 * 28)
        x = self.quant(x)  # float -> quint8 once the model is converted
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)  # back to float so callers receive ordinary logits
        return x

In [24]:
# Build the stub-wrapped model and copy in the trained float weights
# (layer names match simpleModel, so a strict load_state_dict succeeds).
net_fp32_copy = QuantizedSimpleModel().to(device)
net_fp32_copy.load_state_dict(net.state_dict())
# Post-training static quantization is prepared on an eval-mode model.
net_fp32_copy.eval()

# default_qconfig attaches MinMaxObserver to activations, as the repr below shows.
net_fp32_copy.qconfig = torch.ao.quantization.default_qconfig
# prepare() returns an observer-equipped copy (inplace=False by default).
net_quantized = torch.ao.quantization.prepare(net_fp32_copy)
net_quantized

QuantizedSimpleModel(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [25]:
# Calibration pass: running the test set through the prepared model lets each
# observer record the value range later used to pick scale / zero_point.
test(model=net_quantized)

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 675.84it/s]

Accuracy: 0.948





Quantize the model using the collected statistics

In [26]:
# Swap observed float modules for quantized ones, using the calibrated ranges.
net_quantized = torch.ao.quantization.convert(net_quantized, inplace=False)

In [27]:
print('Check statistics of the various layers')
# The repr lists each quantized layer's observer-derived scale and zero_point.
net_quantized

Check statistics of the various layers


QuantizedSimpleModel(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.5807303190231323, zero_point=70, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.437996506690979, zero_point=74, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.4465687870979309, zero_point=78, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

In [28]:
# Print the weights matrix of the model after quantization
print('Weights after quantization')
# weight() returns a quantized tensor; int_repr exposes its raw integer storage.
linear1_weight_q = net_quantized.linear1.weight()
print(torch.int_repr(linear1_weight_q))

Weights after quantization
tensor([[ 7,  8,  9,  ...,  3,  9,  2],
        [ 7,  7,  1,  ...,  9,  2, -4],
        [-1, -1, 11,  ...,  7,  7,  1],
        ...,
        [12,  6,  1,  ...,  2,  7,  5],
        [ 8,  0,  1,  ..., -1, -5, -4],
        [13, 10, 11,  ...,  3, 16,  7]], dtype=torch.int8)


In [29]:
# Serialized size after quantization, to compare against the float baseline.
size_label = 'Size of the model after quantization'
print(size_label)
print_size_of_model(net_quantized)

Size of the model after quantization
Size (KB): 95.797
