**TODO**: Remove below (drive-specific)

In [1]:
#from google.colab import drive
#drive.mount('/content/drive')
#!ln -s /content/drive/MyDrive/Didattica/OENNE_notebooks/utils .

**TODO**: Remove below (server-specific)

In [2]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=2

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=2


## Hands-on #3: Quantization with PLiNIO



In this notebook, you will:
1. Load the optimized and pruned DNN found at the end of Hands-on #2
2. Apply Quantization-Aware Training (QAT) to it.
3. Export the final model in an ONNX format compatible with the AI Compiler that you will use in Hands-on #4.

Considering the flow seen in class, we are here:

![qat.png]()

# Part 0: Initial Setup

As usual, we start by importing the required libraries:

In [3]:
import os
import sys
import random
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from torchinfo import summary

from plinio.cost import params_bit
from plinio.methods import MPS
from plinio.methods.mps import get_default_qinfo
from plinio.methods.mps.quant.quantizers import PACTAct
from plinio.methods.mps.quant.backends import Backend, integerize_arch
from plinio.methods.mps.quant.backends.match import MATCHExporter

import pytorch_benchmarks.image_classification as icl
from pytorch_benchmarks.utils import CheckPoint, EarlyStopping

from utils.train import set_seed, try_load_checkpoint
from utils.plot import plot_learning_curves

Then, we repeat the initial configuration:

In [4]:
SAVE_DIR = Path("experiments/03/")  # separate folder, so the Hands-on #2 model loaded below is not overwritten

TRAINING_CONFIG = {
    'in_class': False,          # kept for compatibility with hands-on #1. Leave it as False!
    'epochs': 50,               # max epochs for normal trainings
    'nas_epochs': 100,          # max epochs for the NAS search loop
    'nas_no_stop_epochs': 20,   # initial epochs without early stopping for the NAS
    'batch_size': 32,           # batch size
    'lr': 0.1,                  # initial learning rate for normal trainings
    'search_lr_net': 0.001,     # learning rate for DNN weights during NAS
    'search_lr_nas': 0.001,     # learning rate for NAS parameters during NAS
    'weight_decay': 1e-4,       # weight decay for normal DNN parameters
    'patience': 10,             # early-stopping patience for normal trainings
    'nas_patience': 10,         # early-stopping patience for NAS search
}

set_seed(42)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Working on: {device}")

Working on: cuda:0


# Part 1: Dataset Preparation

Dataset preparation is identical to the previous notebook:

In [5]:
datasets = icl.get_data()
dataloaders = icl.build_dataloaders(datasets, batch_size=TRAINING_CONFIG['batch_size'])
train_dl, val_dl, test_dl = dataloaders

input_shape = datasets[0][0][0].numpy().shape
batch_shape = (1,) + input_shape

Files already downloaded and verified
Files already downloaded and verified


# Part 2: Quantization

All DNN models considered up to now used **32-bit floating point** both for internal operations and for storing weights and activations. However, our hardware target only supports quantized DNN inference, using **8-bit integers**. Therefore, we need to convert our model to that format before we can export and compile it.

Simply quantizing a model by replacing all floating point data with their closest integer approximation (the most basic form of Post-Training Quantization) could worsen its accuracy. Fortunately, this drop can often be recovered by running some epochs of the so-called **Quantization-Aware Training (QAT)**, as seen in class.
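To make the idea concrete, here is a minimal sketch of the most basic PTQ scheme: symmetric, per-tensor, round-to-nearest 8-bit quantization of a random stand-in weight tensor. It is only an illustration under our own assumptions (the function name and the scheme are ours, not PLiNIO's): each weight ends up within half a quantization step of its float value, and these small per-layer errors accumulate through the network, which is why accuracy can drop.

```python
import torch

def quantize_dequantize_int8(w: torch.Tensor) -> torch.Tensor:
    """Symmetric, per-tensor, round-to-nearest 8-bit quantization (illustrative)."""
    # Map the largest magnitude to 127 so every value fits in the signed range [-127, 127]
    scale = w.abs().max().clamp(min=1e-12) / 127.0
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    # Convert back to float to measure the approximation error
    return q.to(torch.float32) * scale

torch.manual_seed(0)
w = torch.randn(64, 3, 3, 3)  # stand-in for a conv weight tensor
w_hat = quantize_dequantize_int8(w)
err = (w - w_hat).abs().max().item()
half_step = (w.abs().max() / 127 / 2).item()
print(f"max abs error: {err:.6f} (half quantization step: {half_step:.6f})")
```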

PLiNIO can be used to perform QAT on our model, and also to export the final "full integer" model in a format compatible with the compiler used in one of the next sessions.

More precisely, PLiNIO's QAT functionality is embedded in the `MPS()` class, which implements a more advanced optimization: **Mixed-Precision Search (MPS)**. This optimization applies QAT at *multiple bit-widths* simultaneously, and uses a SuperNet-like method to select the *best precision assignment* for the weights and activations of different portions of a DNN (different layers, or even different channels of the same layer).
The optimization can be driven by a two-term loss function that accounts for both accuracy and cost, similar to the one used with SuperNet and PIT.
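In generic form, and with notation of our own choosing (not necessarily PLiNIO's), such a loss can be written as:

$$
\mathcal{L}(W, \alpha) = \mathcal{L}_{\text{task}}(W, \alpha) + \lambda \cdot \mathcal{C}(\alpha)
$$

where $W$ are the DNN weights, $\alpha$ are the trainable parameters that encode the precision choices, $\mathcal{L}_{\text{task}}$ is the task loss (here, cross-entropy), $\mathcal{C}(\alpha)$ is a differentiable estimate of the cost (e.g., the total number of weight bits), and $\lambda$ is the regularization strength that trades accuracy for cost.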

We will not use MPS in this session, since our target hardware and backend library do not support $<8$-bit inference (*). However, we can still use PLiNIO to perform a simple QAT run, by reducing it to a **"corner case" of MPS, with a single precision** (8-bit) to select from.
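Why a single precision reduces MPS to plain QAT can be sketched as follows, in the spirit of the differentiable MPS papers cited in the next paragraph: each layer keeps one learnable logit per candidate bit-width and uses a softmax-weighted mix of the weight fake-quantized at each precision. This is a simplified illustration, not PLiNIO's implementation; `fake_quant`, `mixed_precision_weight`, the per-tensor symmetric scheme and the straight-through estimator are our own assumptions. With a single candidate, the softmax weight is exactly 1, so the "search" degenerates to ordinary 8-bit fake quantization.

```python
import torch
import torch.nn.functional as F

def fake_quant(w: torch.Tensor, bits: int) -> torch.Tensor:
    """Symmetric per-tensor fake quantization to `bits` bits (illustrative)."""
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().max().clamp(min=1e-12) / qmax
    q = torch.clamp(torch.round(w / scale), -qmax, qmax) * scale
    # Straight-through estimator: forward returns q, backward treats rounding as identity
    return w + (q - w).detach()

def mixed_precision_weight(w, alpha, bit_options):
    """Softmax-weighted sum of the weight quantized at each candidate precision."""
    probs = F.softmax(alpha, dim=0)
    return sum(p * fake_quant(w, b) for p, b in zip(probs, bit_options))

torch.manual_seed(0)
w = torch.randn(16, 8)

# Multi-precision search: the mix depends on the learnable logits alpha
alpha_multi = torch.zeros(3, requires_grad=True)
w_multi = mixed_precision_weight(w, alpha_multi, (2, 4, 8))

# Corner case with a single option: softmax of one element is 1, i.e., plain 8-bit QAT
alpha_single = torch.zeros(1, requires_grad=True)
w_single = mixed_precision_weight(w, alpha_single, (8,))
print(torch.allclose(w_single, fake_quant(w, 8)))  # True
```

In the multi-precision case, gradients of the loss flow into `alpha_multi`, which is what lets the search move probability mass toward cheaper or more accurate bit-widths.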

If you're interested in the details of the MPS algorithm implemented in PLiNIO, check out these two papers: [link1](https://arxiv.org/abs/2206.08852), [link2](https://arxiv.org/abs/2004.05795). As an extra, feel free to try applying MPS with multiple precisions to our DNN. Although we won't be able to deploy models with precisions other than 8-bit, it could still be interesting to check how much we can compress the weights without losing too much accuracy.

 
(*) Actually, the DNN accelerator present in GAP9 would support those representations, but we will only deploy on the multi-core RISC-V cluster.

## Importing the Model

Let's start by loading the final model from Hands-on #2 (Optimized and Pruned):

In [6]:
MODEL_PATH = Path("./experiments/02/final_model.pt")
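# weights_only=False is needed on recent PyTorch versions to unpickle a full nn.Module
model = torch.load(MODEL_PATH, weights_only=False).eval()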

Quickly verify that it's correctly loaded:

In [7]:
criterion = nn.CrossEntropyLoss()
test_metrics = icl.evaluate(model, criterion, test_dl, device)
size = summary(model, batch_shape).total_params
print(f'Size: {size}, Test Loss: {test_metrics["loss"]}, Test Acc: {test_metrics["acc"]}')

Size: 50634, Test Loss: 1.9035611152648926, Test Acc: 28.84000015258789


## Preparing the Model

The constructor of PLiNIO's `MPS()` class resembles that of PIT: the parameters are largely the same, and the conversion is mostly transparent.

Note that we can ignore the `cost` parameter if we are only interested in QAT. When running an actual MPS optimization on the DNN weights, you can, for instance, set this parameter to `params_bit`, a cost model that accounts for the precision (in bits) of each DNN parameter.
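As a back-of-the-envelope illustration of what a bit-aware cost measures, the sketch below computes the storage of all parameters at a uniform precision. `weight_memory_bits` is a hypothetical helper, not PLiNIO's `params_bit`, which may account for details (e.g., bias precision or per-channel bit-widths) differently. Applied to the 50,634 parameters reported above, it shows the 4x memory reduction from float32 to int8.

```python
import torch.nn as nn

def weight_memory_bits(model: nn.Module, bits: int) -> int:
    """Storage for all parameters when every one uses `bits` bits (hypothetical helper)."""
    return sum(p.numel() for p in model.parameters()) * bits

# Toy layer: 10*5 weights + 5 biases = 55 parameters
toy = nn.Linear(10, 5)
print(weight_memory_bits(toy, 32), weight_memory_bits(toy, 8))  # 1760 440

# Same arithmetic for the 50,634-parameter model loaded above
n_params = 50634
for bits in (32, 8):
    print(f"{bits}-bit: {n_params * bits / 8 / 1024:.1f} KiB")
```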

The only key difference w.r.t. other methods is that `MPS` also expects a `qinfo` dictionary, containing the settings for the type of quantization to apply to different parts of the network.

The settings in `qinfo` include the quantization algorithm to use for weights and activations (e.g., min-max, PACT, etc.), and optional configuration parameters. In our case, the reasonable defaults provided by PLiNIO suffice, and we obtain them by calling the `get_default_qinfo()` function. This function takes as input the tuples of weight and activation bit-widths to include in the optimization (in our case, only 8-bit for both).

There is just one thing to customize in the default `qinfo`: the range of the DNN **input** quantizer. Since we know that our (float) data lies in the $[0, 1]$ range, we can initialize the quantizer range to the same interval. This should facilitate the conversion.
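For intuition on what `init_clip_val` controls, here is a minimal sketch of a PACT-style (Parameterized Clipping Activation) unsigned quantizer. It is our own simplified version, not PLiNIO's `PACTAct`: `PACTSketch` is a hypothetical name, and the per-tensor scale and straight-through estimator are assumptions. Inputs are clipped to $[0, \alpha]$, with $\alpha$ learnable and initialized to 1, then mapped onto $2^8 - 1$ uniform steps. Initializing $\alpha$ to the known data range means that, at the start, nothing is clipped and the full 8-bit resolution covers exactly $[0, 1]$.

```python
import torch
import torch.nn as nn

class PACTSketch(nn.Module):
    """Illustrative PACT-style unsigned activation fake-quantizer with a learnable clip value."""

    def __init__(self, bits: int = 8, init_clip_val: float = 1.0):
        super().__init__()
        self.levels = 2 ** bits - 1                      # 255 steps for 8-bit unsigned
        self.alpha = nn.Parameter(torch.tensor(init_clip_val))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Clip to [0, alpha]; using minimum() lets alpha receive a gradient where x >= alpha
        y = torch.minimum(torch.relu(x), self.alpha)
        scale = self.alpha / self.levels
        y_q = torch.round(y / scale) * scale
        # Straight-through estimator: forward uses y_q, backward treats rounding as identity
        return y + (y_q - y).detach()

quant = PACTSketch(bits=8, init_clip_val=1.0)
x = torch.rand(4, 3, 32, 32)                             # stand-in for inputs already in [0, 1]
out = quant(x)
print((out - x).abs().max().item() <= 0.5 / 255 + 1e-6)  # True: error within half a step
```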

In [8]:
# get the default qinfo dictionary, specifying 8-bit as the only precision for both weights and activations
qinfo = get_default_qinfo((8,), (8,))

# modify the default qinfo for the input layer, since our (unsigned) input data lies in the [0, 1] range
qinfo['input_default']['quantizer'] = PACTAct
qinfo['input_default']['kwargs'] = {'init_clip_val': +1}

# call the PLiNIO constructor
mps_model = MPS(model, input_shape, qinfo=qinfo)
mps_model = mps_model.to(device)

Estimated DNN cost: 49418.0


### Looking at the Pruning Masks

As with the SuperNet, we can look at the initial values of the PIT pruning masks:

In [9]:
with torch.no_grad():
    for p in nas_model.nas_parameters(): 
        print((torch.abs(p)>0.5).int().cpu().numpy())

[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1

**Question**: Is the number of mask values expected? Why?

### Setting the Regularization Strength

As in the SuperNet case, we have to set the value of $\lambda$ for our combined loss function. Note that PIT generally requires *lower strength* values. However, as mentioned before, there is unfortunately no golden rule here: some trial and error is required (or a more advanced regularization method such as [DUCCIO](https://ieeexplore.ieee.org/abstract/document/10278089)). As a suggestion, try values $\le 10^{-6}$.


In [10]:
TRAINING_CONFIG['reg_strength'] = 0 # (result around 83% acc - after fine-tuning, and 70k params - almost no pruning)
#TRAINING_CONFIG['reg_strength'] = 1e-06

## Run the NAS Loop

To run the NAS optimization, we can entirely reuse the code seen in Hands-on #1. The `ipynb` Python package lets us load definitions (classes, functions, etc.) from another Jupyter notebook, so we use it to import our NAS loop from Hands-on #1. Thanks to the transparent API of PLiNIO, this code, initially written for a SuperNet optimization, will also work fine for PIT. Clearly, obtaining optimal results would require tweaking the parameters, which in some cases could require some code rewriting. However, for this basic example, reusing 100% of the NAS loop will suffice.

The next cell runs the optimization:

In [None]:
from ipynb.fs.defs.I_SuperNet import nas_loop

criterion = nn.CrossEntropyLoss()
history = nas_loop(SAVE_DIR / 'nas', TRAINING_CONFIG, nas_model, criterion, train_dl, val_dl, device)

Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 190.15batch/s, loss=1.68, acc=37.4, val_loss=1.56, val_acc=42.5]
Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 88.05batch/s, loss=1.56, acc=42.9, val_loss=1.55, val_acc=42.8]


Network cost after epoch 1 = 49418.0


Epoch 2: 100%|█████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 186.96batch/s, loss=1.52, acc=44.3, val_loss=1.49, val_acc=46.1]
Epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 89.14batch/s, loss=1.49, acc=45.7, val_loss=1.49, val_acc=46.6]


Network cost after epoch 2 = 49418.0


Epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 189.66batch/s, loss=1.44, acc=47.4, val_loss=1.42, val_acc=48]
Epoch 3: 100%|██████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 87.90batch/s, loss=1.43, acc=48, val_loss=1.42, val_acc=48.6]


Network cost after epoch 3 = 49279.0


Epoch 4: 100%|█████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:07<00:00, 165.57batch/s, loss=1.39, acc=49.5, val_loss=1.41, val_acc=48.7]
Epoch 4: 100%|████████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 83.65batch/s, loss=1.42, acc=48, val_loss=1.46, val_acc=47]


Network cost after epoch 4 = 48444.0


Epoch 5: 100%|█████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 186.35batch/s, loss=1.36, acc=50.8, val_loss=1.37, val_acc=50.7]
Epoch 5: 100%|████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 84.92batch/s, loss=1.37, acc=51.3, val_loss=1.36, val_acc=51.1]


Network cost after epoch 5 = 48168.0


Epoch 6: 100%|█████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:07<00:00, 177.33batch/s, loss=1.33, acc=52.2, val_loss=1.31, val_acc=52.8]
Epoch 6: 100%|████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:04<00:00, 70.48batch/s, loss=1.31, acc=52.5, val_loss=1.32, val_acc=52.4]


Network cost after epoch 6 = 48168.0


Epoch 7: 100%|█████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:07<00:00, 176.89batch/s, loss=1.29, acc=53.2, val_loss=1.28, val_acc=53.3]
Epoch 7: 100%|████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:04<00:00, 75.02batch/s, loss=1.29, acc=53.8, val_loss=1.29, val_acc=53.2]


Network cost after epoch 7 = 48168.0


Epoch 8: 100%|███████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:07<00:00, 171.97batch/s, loss=1.27, acc=54.3, val_loss=1.27, val_acc=54]
Epoch 8: 100%|████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:04<00:00, 76.92batch/s, loss=1.27, acc=54.9, val_loss=1.26, val_acc=54.8]


Network cost after epoch 8 = 47585.0


Epoch 9: 100%|█████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:07<00:00, 166.85batch/s, loss=1.25, acc=54.8, val_loss=1.27, val_acc=54.6]
Epoch 9: 100%|████████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 87.64batch/s, loss=1.3, acc=53, val_loss=1.3, val_acc=52.9]


Network cost after epoch 9 = 47189.0


Epoch 10: 100%|██████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:07<00:00, 169.49batch/s, loss=1.23, acc=56.1, val_loss=1.24, val_acc=55]
Epoch 10: 100%|████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:04<00:00, 77.06batch/s, loss=1.27, acc=54.7, val_loss=1.9, val_acc=35.6]


Network cost after epoch 10 = 46707.0


Epoch 11: 100%|██████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:07<00:00, 175.05batch/s, loss=1.31, acc=53, val_loss=1.28, val_acc=54.1]
Epoch 11: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:04<00:00, 72.56batch/s, loss=1.38, acc=51.7, val_loss=1.29, val_acc=53.6]


Network cost after epoch 11 = 45471.0


Epoch 12: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 180.54batch/s, loss=1.27, acc=54.8, val_loss=1.27, val_acc=54.9]
Epoch 12: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:04<00:00, 73.06batch/s, loss=1.27, acc=54.5, val_loss=1.26, val_acc=54.7]


Network cost after epoch 12 = 45607.0


Epoch 13: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 186.60batch/s, loss=1.23, acc=55.8, val_loss=1.22, val_acc=56.3]
Epoch 13: 100%|█████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:04<00:00, 76.81batch/s, loss=1.25, acc=54.8, val_loss=1.26, val_acc=55]


Network cost after epoch 13 = 45073.0


Epoch 14: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 180.25batch/s, loss=1.22, acc=56.4, val_loss=1.25, val_acc=55.2]
Epoch 14: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 82.21batch/s, loss=1.25, acc=55.1, val_loss=1.24, val_acc=55.8]


Network cost after epoch 14 = 44627.0


Epoch 15: 100%|█████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:07<00:00, 163.74batch/s, loss=1.19, acc=57.3, val_loss=1.2, val_acc=57.2]
Epoch 15: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 88.66batch/s, loss=1.23, acc=56.2, val_loss=1.23, val_acc=56.2]


Network cost after epoch 15 = 44139.0


Epoch 16: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:07<00:00, 161.29batch/s, loss=1.18, acc=57.9, val_loss=1.22, val_acc=56.4]
Epoch 16: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 83.50batch/s, loss=1.23, acc=56.5, val_loss=1.22, val_acc=56.4]


Network cost after epoch 16 = 43537.0


Epoch 17: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 181.66batch/s, loss=1.17, acc=58.3, val_loss=1.18, val_acc=57.7]
Epoch 17: 100%|████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:04<00:00, 71.11batch/s, loss=1.2, acc=56.6, val_loss=1.18, val_acc=57.5]


Network cost after epoch 17 = 43319.0


Epoch 18: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 187.40batch/s, loss=1.16, acc=58.6, val_loss=1.22, val_acc=56.9]
Epoch 18: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 92.50batch/s, loss=1.25, acc=55.4, val_loss=1.24, val_acc=55.4]


Network cost after epoch 18 = 42319.0


Epoch 19: 100%|███████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 193.04batch/s, loss=1.16, acc=58.4, val_loss=1.2, val_acc=57]
Epoch 19: 100%|█████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 81.28batch/s, loss=1.22, acc=56.9, val_loss=1.24, val_acc=56]


Network cost after epoch 19 = 42163.0


Epoch 20: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 184.12batch/s, loss=1.15, acc=58.9, val_loss=1.19, val_acc=58.1]
Epoch 20: 100%|█████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 90.95batch/s, loss=1.18, acc=58, val_loss=1.18, val_acc=57.6]


Network cost after epoch 20 = 41150.0


Epoch 21: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 194.99batch/s, loss=1.14, acc=59.6, val_loss=1.15, val_acc=58.9]
Epoch 21: 100%|████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 90.37batch/s, loss=1.2, acc=57.6, val_loss=1.16, val_acc=58.9]


Network cost after epoch 21 = 40679.0


Epoch 22: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 193.69batch/s, loss=1.12, acc=60.1, val_loss=1.17, val_acc=58.6]
Epoch 22: 100%|██████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:04<00:00, 77.11batch/s, loss=1.32, acc=52.9, val_loss=1.2, val_acc=57]


Network cost after epoch 22 = 40253.0


Epoch 23: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 192.81batch/s, loss=1.12, acc=59.9, val_loss=1.16, val_acc=58.8]
Epoch 23: 100%|████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 88.12batch/s, loss=1.4, acc=49.8, val_loss=1.55, val_acc=44.9]


Network cost after epoch 23 = 39183.0


Epoch 24: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 185.16batch/s, loss=1.18, acc=57.8, val_loss=1.15, val_acc=58.9]
Epoch 24: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:04<00:00, 77.43batch/s, loss=1.14, acc=59.2, val_loss=1.15, val_acc=59.5]


Network cost after epoch 24 = 38933.0


Epoch 25: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 199.29batch/s, loss=1.12, acc=59.7, val_loss=1.14, val_acc=58.9]
Epoch 25: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 86.42batch/s, loss=1.15, acc=58.9, val_loss=1.13, val_acc=59.8]


Network cost after epoch 25 = 38847.0


Epoch 26: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 178.60batch/s, loss=1.11, acc=60.2, val_loss=1.16, val_acc=59.1]
Epoch 26: 100%|█████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 92.77batch/s, loss=1.15, acc=59.7, val_loss=1.17, val_acc=58]


Network cost after epoch 26 = 37898.0


Epoch 27: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 187.56batch/s, loss=1.12, acc=60.1, val_loss=1.13, val_acc=59.8]
Epoch 27: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:04<00:00, 75.83batch/s, loss=1.15, acc=58.8, val_loss=1.13, val_acc=59.7]


Network cost after epoch 27 = 36999.0


Epoch 28: 100%|█████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 199.01batch/s, loss=1.1, acc=60.9, val_loss=1.13, val_acc=59.2]
Epoch 28: 100%|█████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 90.92batch/s, loss=1.16, acc=58.6, val_loss=1.17, val_acc=58]


Network cost after epoch 28 = 36795.0


Epoch 29: 100%|██████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 198.73batch/s, loss=1.12, acc=60, val_loss=1.15, val_acc=58.8]
Epoch 29: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 83.00batch/s, loss=1.14, acc=59.5, val_loss=1.13, val_acc=59.6]


Network cost after epoch 29 = 36040.0


Epoch 30: 100%|█████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 187.59batch/s, loss=1.1, acc=60.8, val_loss=1.13, val_acc=59.3]
Epoch 30: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 89.50batch/s, loss=1.14, acc=59.5, val_loss=1.13, val_acc=59.6]


Network cost after epoch 30 = 35884.0


Epoch 31: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 198.53batch/s, loss=1.08, acc=61.2, val_loss=1.12, val_acc=60.2]
Epoch 31: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 92.54batch/s, loss=1.13, acc=59.7, val_loss=1.13, val_acc=59.8]


Network cost after epoch 31 = 34863.0


Epoch 32: 100%|████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 185.49batch/s, loss=1.09, acc=61, val_loss=1.11, val_acc=61]
Epoch 32: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:04<00:00, 77.93batch/s, loss=1.11, acc=59.8, val_loss=1.12, val_acc=59.9]


Network cost after epoch 32 = 34863.0


Epoch 33: 100%|█████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 190.75batch/s, loss=1.08, acc=61.4, val_loss=1.1, val_acc=61.2]
Epoch 33: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 80.07batch/s, loss=1.21, acc=57.8, val_loss=1.22, val_acc=57.1]


Network cost after epoch 33 = 34392.0


Epoch 34: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 185.00batch/s, loss=1.09, acc=61.1, val_loss=1.11, val_acc=60.7]
Epoch 34: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 84.90batch/s, loss=1.11, acc=60.5, val_loss=1.09, val_acc=60.8]


Network cost after epoch 34 = 33295.0


Epoch 35: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:07<00:00, 166.81batch/s, loss=1.08, acc=61.7, val_loss=1.13, val_acc=59.4]
Epoch 35: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 90.32batch/s, loss=1.12, acc=60.6, val_loss=1.13, val_acc=60.1]


Network cost after epoch 35 = 32892.0


Epoch 36: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:07<00:00, 168.51batch/s, loss=1.07, acc=61.7, val_loss=1.12, val_acc=59.8]
Epoch 36: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 89.68batch/s, loss=1.11, acc=60.6, val_loss=1.11, val_acc=60.5]


Network cost after epoch 36 = 32503.0


Epoch 37: 100%|██████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:07<00:00, 171.20batch/s, loss=1.07, acc=61.8, val_loss=1.09, val_acc=61]
Epoch 37: 100%|████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 85.15batch/s, loss=1.09, acc=60.8, val_loss=1.1, val_acc=60.7]


Network cost after epoch 37 = 31967.0


Epoch 38: 100%|█████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 198.29batch/s, loss=1.07, acc=61.8, val_loss=1.1, val_acc=60.8]
Epoch 38: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:04<00:00, 78.03batch/s, loss=1.09, acc=61.2, val_loss=1.09, val_acc=61.4]


Network cost after epoch 38 = 31853.0


Epoch 39: 100%|███████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 180.67batch/s, loss=1.06, acc=62, val_loss=1.1, val_acc=61.3]
Epoch 39: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 87.21batch/s, loss=1.11, acc=60.5, val_loss=1.09, val_acc=60.7]


Network cost after epoch 39 = 31219.0


Epoch 40: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 184.36batch/s, loss=1.06, acc=62.4, val_loss=1.08, val_acc=61.9]
Epoch 40: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 79.06batch/s, loss=1.07, acc=62.4, val_loss=1.08, val_acc=62.1]


Network cost after epoch 40 = 30771.0


Epoch 41: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 200.29batch/s, loss=1.05, acc=62.2, val_loss=1.09, val_acc=61.4]
Epoch 41: 100%|█████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 91.04batch/s, loss=1.09, acc=61.4, val_loss=1.07, val_acc=62]


Network cost after epoch 41 = 29929.0


Epoch 42: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:06<00:00, 192.51batch/s, loss=1.05, acc=62.6, val_loss=1.09, val_acc=61.4]
Epoch 42: 100%|███████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 80.33batch/s, loss=1.09, acc=61.6, val_loss=1.08, val_acc=61.9]


Network cost after epoch 42 = 29929.0


Epoch 43: 100%|████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:07<00:00, 171.46batch/s, loss=1.05, acc=62.6, val_loss=1.09, val_acc=61.4]
Epoch 43: 100%|████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:03<00:00, 92.80batch/s, loss=1.09, acc=60.6, val_loss=1.1, val_acc=60.8]


Network cost after epoch 43 = 29467.0


Epoch 44:  29%|██████████████████████████████▊                                                                           | 363/1250 [00:01<00:03, 228.83batch/s, loss=1.02044, acc=63.43]

### Evaluating the Pruned Model

Let's check the test accuracy of the pruned DNN after applying PIT.

In [None]:
test_metrics = icl.evaluate(nas_model, criterion, test_dl, device)
print(f'Final model cost: {nas_model.cost}, Test Loss: {test_metrics["loss"]}, Test Acc: {test_metrics["acc"]}')

Depending on the regularization strength that you set, you should see that the cost (number of parameters) has been significantly reduced once again, possibly at the expense of some accuracy degradation.

### Looking at the Masks (After the Search)

**Question:** Let's look again at the $\theta$ parameters. Have they changed? How? Which layers have been pruned the most? Is there one layer that *hasn't* been pruned at all? Which one and why?


In [None]:
with torch.no_grad():
    for p in nas_model.nas_parameters(): 
        print((torch.abs(p)>0.5).int().cpu().numpy())

## Final Model Export (and Fine-Tuning)

As with the SuperNet, we can use the `model.export()` method to obtain a standard `nn.Module` after the optimization implemented by PIT:

In [None]:
nas_model.train_net_and_nas()
final_model = nas_model.export()
final_model = final_model.to(device)

Let's look at the architecture of the optimized model using `torchinfo`.

In [None]:
print(summary(final_model, batch_shape, depth=5))

**Question**: Look at the exported model summary. Does the number of channels in each layer match with the mask values printed above?


In the case of PIT, fine-tuning the exported model for some epochs is *more important* than for the SuperNet. This is because, just as PIT *folds* BatchNorm layers into the preceding layers before the search, it *unfolds* them during the export. This ensures that the final model has the same architecture as the original one.
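BatchNorm folding relies on a standard identity: an inference-mode BatchNorm after a convolution can be merged into it as $W' = W \cdot \gamma / \sqrt{\sigma^2 + \epsilon}$ and $b' = (b - \mu) \cdot \gamma / \sqrt{\sigma^2 + \epsilon} + \beta$, applied per output channel. The sketch below shows this for a single `Conv2d` + `BatchNorm2d` pair and checks numerically that the fused layer matches the original. It is a generic illustration, not PIT's internal code; `fold_bn_into_conv` is a hypothetical helper.

```python
import torch
import torch.nn as nn

def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Return a Conv2d equivalent to conv followed by bn (bn in eval mode)."""
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    # Per-output-channel factor gamma / sqrt(var + eps)
    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    with torch.no_grad():
        fused.weight.copy_(conv.weight * factor.reshape(-1, 1, 1, 1))
        fused.bias.copy_((bias - bn.running_mean) * factor + bn.bias)
    return fused

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
# Give BN non-trivial statistics and affine parameters
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
conv.eval()
bn.eval()

x = torch.randn(2, 3, 16, 16)
ref = bn(conv(x))
fused = fold_bn_into_conv(conv, bn)
print(torch.allclose(fused(x), ref, atol=1e-5))  # True
```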

You can verify the effect of the BatchNorm unfolding by testing the model right after export: you will see a significant accuracy drop. However, a few epochs of fine-tuning should suffice to recover it, and possibly even improve the final accuracy (thanks to BatchNorm). Let's run them.

In [None]:
from ipynb.fs.defs.I_SuperNet import training_loop

criterion = nn.CrossEntropyLoss()
history = training_loop(SAVE_DIR / 'finetune', TRAINING_CONFIG, final_model, criterion, train_dl, val_dl, device)

Finally, let's evaluate our optimized model on the test set:

In [None]:
test_metrics = icl.evaluate(final_model, criterion, test_dl, device)
print(f'Test Loss: {test_metrics["loss"]}, Test Acc: {test_metrics["acc"]}')

**Question:** Considering SuperNet and PIT combined, by how much did you manage to compress the model size? At what cost in terms of accuracy?

## Saving the Final Model

Let's save the model in a separate location to reuse it more easily in later sessions:

In [None]:
torch.save(final_model, SAVE_DIR / 'final_model.pt')