# Tutorial

In this tutorial, we briefly introduce the quantization and pruning techniques on which QSPARSE is built. Using the library, we then guide you through building an image classification neural network whose weights and activations are both fully quantized and pruned to a given sparsity level.

> If you are already familiar with quantization and pruning methods and want to learn the programming syntax, please fast forward to [Building Network with QSPARSE](#building-network-with-qsparse).

## Preliminaries

Quantization and pruning are core techniques for reducing the inference cost of deep neural networks and have been studied extensively. Approaches to quantization are often divided into two categories:

1. Post-training quantization
2. Quantization-aware training

The former applies quantization after a network has been trained. The latter quantizes the network during training, which reduces the quantization error throughout the training process and usually yields superior performance.

Pruning techniques are often divided into unstructured and structured approaches, depending on whether and how a pre-defined topology (e.g., channel-wise pruning) is imposed on the pruned elements.

Here, we focus on applying quantization and unstructured pruning during training.

<figure style="text-align:center;font-style:italic"> 
  <img src="../docs/assets/network_diagram-p1.svg" />
  <figcaption>Conceptual diagram of the computational graph of a network whose weights and activations are quantized and pruned using QSPARSE.</figcaption>
</figure>


In QSPARSE, we implement quantization and pruning as independent operators that can be applied to both weights and activations, as shown in the figure above.

### Uniform Quantization

We denote the uniform quantization operation as $Q_u(\mathbf{x}, d)$, where $\mathbf{x}$ denotes the input to the operator (i.e., weights or activations), $N$ denotes the total number of bits used to represent weights and activations, and $d$ denotes the number of bits used to represent the fractional part (i.e., the number of binary digits to the right of the radix point; we refer to $d$ as the decimal bits).

$$
Q_u(\mathbf{x}, d) = \text{clip}(\lfloor\mathbf{x} \times 2^{d}\rfloor, -2^{N-1}, 2^{N-1}-1) / 2^d
$$
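The formula above can be sketched in a few lines of NumPy. This is a minimal illustration written for this tutorial, not the QSPARSE implementation; the function name `quantize_uniform` and the default `n_bits=8` are assumptions made for the example.

```python
import numpy as np


def quantize_uniform(x, d, n_bits=8):
    """Q_u(x, d): fixed-point rounding with n_bits total bits and d fractional bits."""
    # Scale so that one integer step corresponds to 2**-d, then round down.
    scaled = np.floor(np.asarray(x, dtype=np.float64) * 2.0 ** d)
    # Saturate to the signed n_bits integer range.
    clipped = np.clip(scaled, -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1)
    # Scale back to the original value range.
    return clipped / 2.0 ** d


x = np.array([0.30, -1.26, 3.99, 100.0])
print(quantize_uniform(x, d=4))
```

With $N = 8$ and $d = 4$, the representable values are multiples of $2^{-4}$ in $[-8, 7.9375]$, so the last input saturates at the upper bound.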

A straight-through estimator (STE) is applied to calculate gradients in the backward pass.

$$
\frac{\partial Loss}{\partial \mathbf{x}} = \text{clip}(\frac{\partial Loss}{\partial Q_u(\mathbf{x}, d)}, -2^{N-d-1}, 2^{N-d-1} - 2^{-d})
$$
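As a rough sketch, the forward and backward rules above can be written as a custom PyTorch autograd function. The class name `QuantizeSTE` is hypothetical, and the code is a literal transcription of the two formulas rather than a description of how QSPARSE implements them internally.

```python
import torch


class QuantizeSTE(torch.autograd.Function):
    """Q_u in the forward pass, clipped pass-through gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, d, n_bits):
        # Gradient clipping bounds taken from the backward formula.
        ctx.bounds = (-(2.0 ** (n_bits - d - 1)),
                      2.0 ** (n_bits - d - 1) - 2.0 ** (-d))
        scaled = torch.floor(x * 2.0 ** d)
        return torch.clamp(scaled, -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1) / 2.0 ** d

    @staticmethod
    def backward(ctx, grad_output):
        lo, hi = ctx.bounds
        # One return value per forward input; d and n_bits receive no gradient.
        return torch.clamp(grad_output, lo, hi), None, None


x = torch.tensor([0.30, -1.26, 3.99], requires_grad=True)
y = QuantizeSTE.apply(x, 4, 8)
y.sum().backward()
print(y, x.grad)
```

The floor operation has zero gradient almost everywhere, so without the pass-through rule no learning signal would reach `x`.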

However, STE is known to be sensitive to weight initialization. We therefore design the quantization operator $\text{Quantize}$ as follows. Starting with the original full-precision network, we delay quantization to later training stages and calculate the optimal decimal bits $d^*$ by minimizing the quantization error after a given number of update steps $t_q$.

$$
\text{Quantize}(\mathbf{x}_t)  = \begin{cases} 
    \mathbf{x}_t & t < t_q \\
    Q_u(\mathbf{x}_t, d^*)  &    t \ge t_q   \\
    \end{cases} 
$$

$$
d^* = \arg \min_{d} \Vert Q_u(\mathbf{x}_{t_q}, d) - \mathbf{x}_{t_q} \Vert^2
$$
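One simple way to evaluate the $\arg\min$ above is an exhaustive search over a small set of candidate values for $d$. The sketch below assumes this brute-force strategy and a candidate range of 0 to 15; both are choices made for illustration, and QSPARSE may select $d^*$ differently.

```python
import numpy as np


def quantize_uniform(x, d, n_bits=8):
    """Q_u(x, d) as defined above."""
    scaled = np.floor(np.asarray(x, dtype=np.float64) * 2.0 ** d)
    return np.clip(scaled, -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1) / 2.0 ** d


def best_decimal_bits(x, n_bits=8, candidates=range(0, 16)):
    """Return the d in `candidates` with the smallest squared quantization error."""
    x = np.asarray(x, dtype=np.float64)
    errors = {d: float(np.sum((quantize_uniform(x, d, n_bits) - x) ** 2))
              for d in candidates}
    return min(errors, key=errors.get)


rng = np.random.default_rng(0)
small = rng.uniform(-0.4, 0.4, size=1000)    # narrow range: more fractional bits fit
large = rng.uniform(-20.0, 20.0, size=1000)  # wide range: more integer bits needed
print(best_decimal_bits(small), best_decimal_bits(large))
```

The search balances two error sources: a larger $d$ gives finer resolution but a narrower representable range, so values outside that range are clipped.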


### Magnitude-based Unstructured Pruning

We denote the unstructured pruning operator $P(\mathbf{x}, s)$ as the element-wise multiplication between $\mathbf{x}$ and $\mathbf{M}_{\mathbf{x},s}$, where $\mathbf{x}$ denotes the input to the operator (i.e., weights or activations), $s$ denotes the target sparsity, measured as the fraction of zero-valued elements, and $\mathbf{M}_{\mathbf{x},s}$ denotes a binary mask.

$$
P(\mathbf{x}, s)  = \mathbf{x} \circ \mathbf{M}_{\mathbf{x},s}
$$

Given that $(i,j)$ are the row and column indices, respectively, the binary mask $\mathbf{M}_{\mathbf{x},s}$ is calculated as below, where $\text{quantile}(\mathbf{x}, a)$ is the $a$-th quantile of $\mathbf{x}$.

$$
\mathbf{M}_{\mathbf{x},s}^{(i,j)}  = \begin{cases}
		1 & |\mathbf{x}^{(i, j)}| \ge \text{quantile}(|\mathbf{x}|, s) \\
		0 & \text{otherwise}
		\end{cases}
$$
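The mask definition above translates almost directly into NumPy. The helper name `prune_mask` is made up for this sketch, and `np.quantile` with its default linear interpolation is assumed as the quantile function.

```python
import numpy as np


def prune_mask(x, sparsity):
    """Binary mask keeping entries whose magnitude is at or above the `sparsity` quantile."""
    magnitude = np.abs(x)
    threshold = np.quantile(magnitude, sparsity)
    return (magnitude >= threshold).astype(x.dtype)


x = np.array([[0.1, -0.8, 0.05, 0.6],
              [-0.3, 0.02, 0.9, -0.4]])
mask = prune_mask(x, sparsity=0.5)
pruned = x * mask  # P(x, s): element-wise product with the mask
print(mask)
print(pruned)
```

Here the median magnitude is 0.35, so the four entries with the largest magnitudes survive and the other four are set to zero.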


As proposed by [Zhu et al.](https://arxiv.org/pdf/1710.01878.pdf), the sparsity level $s$ is controlled and updated according to a sparsification schedule at time steps $t_p + i \Delta t_p$ with $i \in \{1, 2, \ldots, n\}$, where $t_p$, $\Delta t_p$, and $n$ are hyperparameters that represent the starting pruning step, the pruning interval, and the total number of pruning iterations, respectively.
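To make the schedule concrete, the sketch below enumerates the pruning steps $t_p + i \Delta t_p$ and assigns each a sparsity using the cubic ramp described by Zhu et al. That QSPARSE follows exactly this ramp, and that it starts from an initial sparsity of zero, are assumptions of this example; the function name `sparsity_schedule` is hypothetical.

```python
def sparsity_schedule(target, start, interval, repetition, initial=0.0):
    """Yield (step, sparsity) pairs for a cubic ramp from `initial` to `target`."""
    for i in range(1, repetition + 1):
        progress = i / repetition
        # Prune quickly at first, then slow down as the target is approached.
        sparsity = target + (initial - target) * (1.0 - progress) ** 3
        yield start + i * interval, sparsity


schedule = list(sparsity_schedule(target=0.5, start=200, interval=10, repetition=4))
for step, s in schedule:
    print(f"step {step}: sparsity {s:.4f}")
```

With $t_p = 200$, $\Delta t_p = 10$, and $n = 4$, pruning happens at steps 210, 220, 230, and 240, and the target sparsity is reached at the last of these.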

## Building Network with QSPARSE

With the above methods in mind, we will now use QSPARSE to build a quantized and sparse network on top of the full-precision network below, which is borrowed from the official PyTorch [MNIST example](https://github.com/pytorch/examples/blob/master/mnist/main.py).

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.bn2 = nn.BatchNorm2d(64)

        self.fc1 = nn.Linear(9216, 128)
        self.bn3 = nn.BatchNorm1d(128)

        self.fc2 = nn.Linear(128, 10)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.bn3(self.fc1(x)))
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

Net()
```

```
Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
)
```

### Weight Quantization and Pruning

<figure style="text-align:center;font-style:italic"> 
  <img src="../docs/assets/network_diagram-p2.svg" />
  <figcaption>The part of diagram in red corresponds to weight quantization and pruning.</figcaption>
</figure>


With QSPARSE, we can easily create a layer whose weights are quantized and pruned. Take a convolution as an example:


```python
from qsparse import prune, quantize, set_qsparse_options
set_qsparse_options(log_on_created=False)

conv = quantize(prune(nn.Conv2d(1, 32, 3),
                      sparsity=0.5, start=200,
                      interval=10, repetition=4),
                bits=8, timeout=100)
conv
```

```
Conv2d(
  1, 32, kernel_size=(3, 3), stride=(1, 1)
  (prune): PruneLayer()
  (quantize): QuantizeLayer()
)
```

We can see that `prune` and `quantize` layers are injected. The resulting layer behaves identically to `nn.Conv2d`, except that `conv.weight` returns a quantized and pruned version of the vanilla weight. The hyperparameters map to QSPARSE arguments as shown in the table below.


| Param        | QSPARSE Argument |
|--------------|-----------------------|
| $N$          | `bits`                |
| $t_q$        | `timeout`             |
| $s$          | `sparsity`            |
| $t_p$        | `start`               |
| $n$          | `repetition`          |
| $\Delta t_p$ | `interval`            |


Both the `prune` and `quantize` layers maintain an internal counter that records the number of training steps that have passed through. The counter value can be accessed through the `_n_updates` attribute. With the arguments specified above, `conv.weight` is quantized from step 100 onward and reaches its final 50% sparsity at step 240, which can be verified by:

```python
data = torch.rand((1, 1, 32, 32))
for _ in range(241):
    conv(data)

conv.quantize._n_updates
```

```
[Quantize] (channelwise) avg decimal = 8.0
[Prune] [Step 210] active 0.72, pruned 0.28, window_size = 1
[Prune] [Step 220] active 0.57, pruned 0.43, window_size = 1
[Prune] [Step 230] active 0.51, pruned 0.49, window_size = 1
[Prune] [Step 240] active 0.50, pruned 0.50, window_size = 1

Parameter containing:
tensor([241], dtype=torch.int32)
```

```python
(conv.weight * (2**conv.quantize.decimal)
- (conv.weight * (2**conv.quantize.decimal)).int()).sum().item()
```

```
0.0
```

```python
print(len(conv.prune.mask.nonzero()) / np.prod(conv.prune.mask.shape))
print(np.all((conv.weight.detach().numpy() == 0)
          == (conv.prune.mask.detach().numpy() == 0)))
```

```
0.5034722222222222
True
```


`mask` and `decimal` denote the binary mask for pruning and the number of fractional bits for quantization, respectively. We will revisit them in [Inspecting Parameters of a Pruned/Quantized Model](../advanced_usage/#inspecting-parameters-of-a-prunedquantized-model). The `prune` and `quantize` functions are compatible with any PyTorch module whose parameters can be accessed through its `weight` attribute. Take a fully-connected layer as another example:

```python
quantize(prune(nn.Linear(128, 10), 0.5), 8)
```

```
Linear(
  in_features=128, out_features=10, bias=True
  (prune): PruneLayer()
  (quantize): QuantizeLayer()
)
```

### Activation Quantization and Pruning


<figure style="text-align:center;font-style:italic"> 
  <img src="../docs/assets/network_diagram-p3.svg" />
  <figcaption>The part of diagram in red corresponds to activation quantization and pruning.</figcaption>
</figure>

To prune and quantize the output of a convolution, we can insert `prune` and `quantize` directly into the computation graph:


```python
nn.Sequential(
    conv,
    prune(sparsity=0.5, start=200, interval=10, repetition=4),
    quantize(bits=8, timeout=100),
    nn.ReLU()
)
```

```
Sequential(
  (0): Conv2d(
    1, 32, kernel_size=(3, 3), stride=(1, 1)
    (prune): PruneLayer()
    (quantize): QuantizeLayer()
  )
  (1): PruneLayer()
  (2): QuantizeLayer()
  (3): ReLU()
)
```

Similarly, the output of `conv` will be quantized from step 100 onward and pruned to 50% sparsity by step 240.

### Building a Network with Both Weight and Activation Quantized and Pruned

Using the techniques introduced above, we can reimplement `Net` with joint quantization and pruning training capabilities, with full transparency and minimal effort:

```python
class NetPQ(nn.Module):
    def __init__(self, epoch_size=100):
        super(NetPQ, self).__init__()
        # Input quantization, enabled at epoch 10.
        self.qin = quantize(bits=8, timeout=epoch_size * 10)

        # For the sake of simplicity, we omit the `timeout`, `start`,
        # `repetition`, and `interval` arguments in the following.
        self.conv1 = quantize(nn.Conv2d(1, 32, 3, 1), 8)
        self.bn1 = nn.BatchNorm2d(32)
        self.p1, self.q1 = prune(sparsity=0.5), quantize(bits=8)

        self.conv2 = quantize(prune(nn.Conv2d(32, 64, 3, 1), 0.5), 8)
        self.bn2 = nn.BatchNorm2d(64)
        self.p2, self.q2 = prune(sparsity=0.5), quantize(bits=8)

        self.fc1 = quantize(prune(nn.Linear(9216, 128), 0.5), 8)
        self.bn3 = nn.BatchNorm1d(128)
        self.p3, self.q3 = prune(sparsity=0.5), quantize(bits=8)

        self.fc2 = quantize(nn.Linear(128, 10), 8)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)

    def forward(self, x):
        x = self.qin(x)
        x = F.relu(self.q1(self.p1(self.bn1(self.conv1(x)))))
        x = F.relu(self.q2(self.p2(self.bn2(self.conv2(x)))))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.q3(self.p3(self.bn3(self.fc1(x)))))
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

NetPQ()
```

```
NetPQ(
  (qin): QuantizeLayer()
  (conv1): Conv2d(
    1, 32, kernel_size=(3, 3), stride=(1, 1)
    (quantize): QuantizeLayer()
  )
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (p1): PruneLayer()
  (q1): QuantizeLayer()
  (conv2): Conv2d(
    32, 64, kernel_size=(3, 3), stride=(1, 1)
    (prune): PruneLayer()
    (quantize): QuantizeLayer()
  )
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (p2): PruneLayer()
  (q2): QuantizeLayer()
  (fc1): Linear(
    in_features=9216, out_features=128, bias=True
    (prune): PruneLayer()
    (quantize): QuantizeLayer()
  )
  (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (p3): PruneLayer()
  (q3): QuantizeLayer()
  (fc2): Linear(
    in_features=128, out_features=10, bias=True
    (quantize): QuantizeLayer()
  )
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
)
```


The network created by `NetPQ` is a PyTorch module that consists only of its original components and the `PruneLayer` / `QuantizeLayer` modules introduced by `prune` and `quantize`.
It does not require you to modify the training loop or even the weight initialization code, and it also supports [resuming training from checkpoints](../advanced_usage/#resuming-from-checkpoint).
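To illustrate what an unmodified training loop looks like, here is a minimal sketch of a standard PyTorch loop. The `train` helper, the random data, and the small stand-in model are all assumptions made so that the snippet runs without QSPARSE installed; an instance of `NetPQ` would be passed to `train` in the same way.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def train(model, batches, lr=0.1):
    """Plain PyTorch training loop; nothing here is specific to QSPARSE."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    model.train()
    losses = []
    for inputs, targets in batches:
        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


torch.manual_seed(0)
# Stand-in model with the same input/output shapes as the MNIST networks above.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
# Random tensors in place of real MNIST batches.
batches = [(torch.rand(16, 1, 28, 28), torch.randint(0, 10, (16,))) for _ in range(5)]
losses = train(model, batches)
print(losses)
```

Because the step counters live inside the injected layers, each forward pass in training mode is what advances the quantization and pruning schedules, as in the `conv(data)` loop shown earlier.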

The full example of training an MNIST classifier with different pruning and quantization configurations can be found at [examples/mnist.py](https://github.com/mlzxy/qsparse/blob/main/examples/). More examples can be found [here](https://github.com/mlzxy/qsparse-examples).

## Summary

In this tutorial, we introduced the basics of joint quantization and pruning training and showed how to implement this training paradigm with QSPARSE. Next, we cover more [advanced usage](../advanced_usage/).