# Import the necessary libraries

In [6]:
import os
import tempfile
import time
from datetime import datetime
from pathlib import Path

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import Accuracy
from tqdm import tqdm

import lenet_model

# Load the MNIST dataset

In [2]:
# Make torch deterministic
# _ = torch.manual_seed(0)

In [16]:
train_val_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=False, transform=transforms.ToTensor())
class_names = train_val_dataset.classes

# Calculate mean and std of the train dataset
imgs = torch.stack([img for img, _ in train_val_dataset], dim=0)
mean = imgs.view(1, -1).mean(dim=1)    # or imgs.mean()
std = imgs.view(1, -1).std(dim=1)     # or imgs.std()
# create Transformation (converting from Image class to Tensor and normalize)
mnist_transforms = transforms.Compose([transforms.ToTensor(),
                                       transforms.Normalize(mean=mean, std=std)])
mnist_trainset = torchvision.datasets.MNIST(root="./data", train=True, download=False, transform=mnist_transforms)
# split to train dataset and validation dataset
train_size = int(0.8 * len(mnist_trainset))
val_size = len(mnist_trainset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset=mnist_trainset, lengths=[train_size, val_size])

# load dataset and set number of data per batch
BATCH_SIZE = 32
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=True)


# Define the model

In [10]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Instantiate LeNet-5 from the local lenet_model module and place it on that device.
net = lenet_model.LeNet5().to(device)

# Insert min-max observers in the model

In [11]:
# The model is put in training mode before prepare_qat is called.
net.train()
# Attach the default quantization config to the whole network.
# NOTE(review): the repr printed below shows plain MinMaxObserver modules (no
# FakeQuantize), so training only records value ranges and does not simulate
# quantization. If simulated quantization was intended, a QAT qconfig would be
# needed instead - confirm intent.
net.qconfig = torch.ao.quantization.default_qconfig
# Insert observers; inplace=False (the default) leaves `net` itself untouched.
net_quantized = torch.ao.quantization.prepare_qat(net, inplace=False)
net_quantized

LeNet5(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (dequant): DeQuantStub()
  (c1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (c2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (c3): Conv2d(
    16, 120, kernel_size=(5, 5), stride=(1, 1)
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(
    in_features=120, out_features=84, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-i

# Train the model

In [27]:
def train(train_loader, net, epochs=12, total_iterations_limit=None):
    """Optimize `net` on `train_loader` with Adam (lr=0.001) and cross-entropy loss.

    Args:
        train_loader: iterable yielding `(inputs, targets)` batches.
        net: module to train; set to training mode at the start of every epoch.
        epochs: number of passes over `train_loader`.
        total_iterations_limit: optional cap on the total number of optimizer
            steps across all epochs; training returns as soon as it is reached.

    Batches are moved to the notebook-level `device`. A tqdm bar per epoch shows
    the running mean loss of that epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    has_limit = total_iterations_limit is not None
    steps_done = 0

    for epoch_idx in range(epochs):
        net.train()

        running_loss = 0

        progress = tqdm(train_loader, desc=f'Epoch {epoch_idx+1}')
        if has_limit:
            # Make the bar count towards the step cap instead of the loader length.
            progress.total = total_iterations_limit

        for batch_no, (inputs, targets) in enumerate(progress, start=1):
            steps_done += 1
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            loss = criterion(net(inputs), targets)

            # Report the mean loss over the batches seen so far in this epoch.
            running_loss += loss.item()
            progress.set_postfix(loss=running_loss / batch_no)

            loss.backward()
            optimizer.step()

            if has_limit and steps_done >= total_iterations_limit:
                return
            
def print_size_of_model(model):
    """Print the on-disk size, in kilobytes, of `model`'s serialized state dict.

    The state dict is written to a scratch file inside a private temporary
    directory, so an existing `temp_delme.p` in the working directory is never
    overwritten or deleted, and the scratch file is removed even if saving or
    sizing raises.

    Args:
        model: any object exposing `state_dict()` (e.g. an `nn.Module`).
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        # Same file name as before, only the directory changes.
        scratch_file = os.path.join(scratch_dir, "temp_delme.p")
        torch.save(model.state_dict(), scratch_file)
        print('Size (KB):', os.path.getsize(scratch_file)/1e3)

# Train the observer-instrumented model for 12 epochs on the training split.
train(train_loader=train_dataloader, net=net_quantized, epochs=12)

Epoch 1: 100%|██████████| 1500/1500 [00:31<00:00, 46.90it/s, loss=0.0548]
Epoch 2: 100%|██████████| 1500/1500 [00:34<00:00, 43.95it/s, loss=0.0395]
Epoch 3: 100%|██████████| 1500/1500 [00:37<00:00, 40.19it/s, loss=0.032] 
Epoch 4: 100%|██████████| 1500/1500 [00:32<00:00, 46.14it/s, loss=0.0253]
Epoch 5: 100%|██████████| 1500/1500 [00:36<00:00, 41.51it/s, loss=0.0217]
Epoch 6: 100%|██████████| 1500/1500 [00:37<00:00, 40.38it/s, loss=0.0179]
Epoch 7: 100%|██████████| 1500/1500 [00:36<00:00, 41.23it/s, loss=0.018] 
Epoch 8: 100%|██████████| 1500/1500 [00:37<00:00, 39.60it/s, loss=0.0134]
Epoch 9: 100%|██████████| 1500/1500 [00:31<00:00, 47.34it/s, loss=0.015] 
Epoch 10: 100%|██████████| 1500/1500 [00:33<00:00, 44.22it/s, loss=0.0108]
Epoch 11: 100%|██████████| 1500/1500 [00:32<00:00, 45.71it/s, loss=0.0135]
Epoch 12: 100%|██████████| 1500/1500 [00:34<00:00, 43.40it/s, loss=0.00882]


# Define the testing loop

In [36]:
def test(model: nn.Module, total_iterations: int = None):
    """Evaluate `model` on the notebook-level `val_dataloader` and print its accuracy.

    Args:
        model: network to evaluate; it is switched to eval mode.
        total_iterations: optional cap on the number of batches to evaluate;
            `None` evaluates the whole validation loader.

    Prints `Accuracy: <value>` rounded to three decimals. Returns nothing.
    """
    correct = 0
    total = 0

    model.eval()

    # Feed the batches on the device the model itself lives on instead of the
    # global `device`: after conversion the quantized model runs on the CPU even
    # when training happened on a GPU.
    # NOTE(review): a converted model is expected to expose no nn.Parameter
    # (weights are packed), hence the CPU fallback - confirm on your PyTorch version.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        for iterations, (x, y) in enumerate(tqdm(val_dataloader, desc='Testing'), start=1):
            x = x.to(eval_device)
            y = y.to(eval_device)
            output = model(x)
            # One vectorized comparison per batch instead of a Python loop per sample.
            predictions = output.flatten(start_dim=1).argmax(dim=1)
            correct += (predictions == y).sum().item()
            total += output.size(0)
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')

# Check the collected statistics during training

In [29]:
# Display the model so the min/max ranges recorded by each observer are visible.
print('Check statistics of the various layers')
net_quantized

Check statistics of the various layers


LeNet5(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.42407387495040894, max_val=2.8215432167053223)
  )
  (dequant): DeQuantStub()
  (c1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)
    (weight_fake_quant): MinMaxObserver(min_val=-0.5149684548377991, max_val=0.3841245174407959)
    (activation_post_process): MinMaxObserver(min_val=-8.90822982788086, max_val=7.721226692199707)
  )
  (c2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (weight_fake_quant): MinMaxObserver(min_val=-0.6840593218803406, max_val=0.4230343997478485)
    (activation_post_process): MinMaxObserver(min_val=-27.688060760498047, max_val=14.46810531616211)
  )
  (c3): Conv2d(
    16, 120, kernel_size=(5, 5), stride=(1, 1)
    (weight_fake_quant): MinMaxObserver(min_val=-0.8825253844261169, max_val=0.5730369687080383)
    (activation_post_process): MinMaxObserver(min_val=-43.66505432128906, max_val=31.585952758789062)
  )
  (relu): ReLU()
  (max_pool

# Quantize the model using the statistics collected

In [30]:
# Switch to eval mode before converting.
net_quantized.eval()
# The eager-mode quantized modules produced by convert run on the CPU, so move
# the model there first; otherwise this cell breaks when training ran on a GPU.
net_quantized = torch.ao.quantization.convert(net_quantized.to('cpu'))

In [31]:
# Display the converted model: layers now report their scale and zero_point.
print('Check statistics of the various layers')
net_quantized

Check statistics of the various layers


LeNet5(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (dequant): DeQuantize()
  (c1): QuantizedConv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.13094060122966766, zero_point=68, padding=(2, 2))
  (c2): QuantizedConv2d(6, 16, kernel_size=(5, 5), stride=(1, 1), scale=0.33193832635879517, zero_point=83)
  (c3): QuantizedConv2d(16, 120, kernel_size=(5, 5), stride=(1, 1), scale=0.5925276279449463, zero_point=74)
  (relu): ReLU()
  (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): QuantizedLinear(in_features=120, out_features=84, scale=0.5568221807479858, zero_point=67, qscheme=torch.per_tensor_affine)
  (fc2): QuantizedLinear(in_features=84, out_features=10, scale=1.0586904287338257, zero_point=84, qscheme=torch.per_tensor_affine)
)

# Print weights and size of the model after quantization

In [None]:
# Print the integer representation of the first conv layer's weights after quantization
print('Weights after quantization')
print(torch.int_repr(net_quantized.c1.weight()))

Weights after quantization
tensor([[[[ 63,  51,  38, -45,   4],
          [-30, -46, -67, -70, -36],
          [-57, -81, -64,  52,  45],
          [  4,   5,  67,  79,  41],
          [  9,  56,  38, -42, -67]]],


        [[[ 41,   6,   4, -35, -19],
          [ 31,  53, -30, -52, -28],
          [ 95,  -1, -54, -69,   3],
          [ 59, -47, -69, -50,   6],
          [ 27, -58, -17,   2,   8]]],


        [[[ 19,  21,  64,  29,  24],
          [ 11,  82,  37,  44,  74],
          [ 42,  -7,   2,  32,  26],
          [  8, -95, -65, -92, -35],
          [-73, -95, -60, -74, -39]]],


        [[[ 34,   2,  25,   2, -49],
          [-32,  79,   2,  35,  -4],
          [-94, -24,  25, -17, -33],
          [-63,  11,  -7,  39,   3],
          [ 13,  -7,   5,  50, -47]]],


        [[[-16, -29,  13,  20,  27],
          [-61, -13,  36, -25,  56],
          [-38, -61,  38,  61, -19],
          [-66,   0,  16,  58,   9],
          [-77,  -1,  32,  34, -23]]],


        [[[ -8, -27, -39, -3

In [37]:
# Measure validation accuracy of the converted (quantized) model.
print('Testing the model after quantization')
test(model=net_quantized)

Testing the model after quantization


Testing: 100%|██████████| 375/375 [00:03<00:00, 108.96it/s]

Accuracy: 0.989





# Save model weights

In [46]:
# Directory that receives the saved weights; created on first use.
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(exist_ok=True, parents=True)

# Timestamp (minute resolution) that makes the file name unique per run.
time_start = datetime.now().strftime('%Y%m%d_%H%M')
MODEL_NAME = f"lenet5_mnist_{time_start}.pth"
MODEL_SAVE_PATH = MODEL_PATH.joinpath(MODEL_NAME)

# The quantized state dict goes next to MODEL_SAVE_PATH under a "quantized_" prefix;
# MODEL_SAVE_PATH itself is not written by this cell.
quantized_weights_file = MODEL_SAVE_PATH.with_name(f"quantized_{MODEL_NAME}")
torch.save(net_quantized.state_dict(), f=quantized_weights_file)