# Advanced Usage

This page covers advanced usage of QSPARSE, organized by topic.

In [5]:
import logging
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler()])
from qsparse import set_qsparse_options
set_qsparse_options(log_on_created=False)

## Quantization

### Channel-wise Quantization

Channel-wise quantization uses different decimal bits across channels, i.e., it quantizes each channel independently. It is well known that channel-wise quantization can drastically reduce quantization error, especially when numerical ranges vary widely between channels. By default, `quantize` applies channel-wise quantization along dimension `1`; this can be configured through the `channelwise` parameter:

In [2]:
import torch
from qsparse import quantize
tensor = torch.rand(10, 20, 30)


In [3]:
quantize_layer = quantize(bits=8)
quantize_layer(tensor)
quantize_layer.decimal.shape

torch.Size([20])

In [4]:
quantize_layer = quantize(bits=8, channelwise=2) # quantize along dimension 2
quantize_layer(tensor)
quantize_layer.decimal.shape

torch.Size([30])

In [5]:
quantize_layer = quantize(bits=8, channelwise=-1) # disable channel-wise quantization
quantize_layer(tensor)
quantize_layer.decimal.shape

torch.Size([1])

Inference engines support different channel-wise quantization schemes, which can also vary by layer type, e.g. [Conv2d](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) and [ConvTranspose2d](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html) differ in weight layout. The `channelwise` parameter can be used to adjust network training to match the inference setup.
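To make the benefit concrete, here is a minimal sketch in plain Python, independent of QSPARSE internals. It assumes a simple fixed-point scheme (round to a step of $2^{-d}$, then saturate to the signed 8-bit range). `quantize_value`, `mse`, and `best_decimal` are hypothetical helpers introduced for illustration, not QSPARSE functions. The sketch compares the error on a narrow-range channel when it shares decimal bits with a wide-range channel and when it gets its own.

```python
def quantize_value(x, decimal, bits=8):
    """Assumed fixed-point scheme: round to a step of 2**-decimal, then saturate."""
    scale = 2 ** decimal
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return max(lo, min(hi, round(x * scale))) / scale


def mse(values, decimal, bits=8):
    """Mean squared quantization error for a given number of decimal bits."""
    return sum((quantize_value(v, decimal, bits) - v) ** 2 for v in values) / len(values)


def best_decimal(values, bits=8, candidates=range(16)):
    """Brute-force search for the decimal bits with the lowest error."""
    return min(candidates, key=lambda d: mse(values, d, bits))


small = [0.011, -0.023, 0.037, 0.004]  # channel with a narrow numerical range
large = [11.3, -23.7, 37.1, 4.9]       # channel with a wide numerical range

shared = best_decimal(small + large)                      # one setting for both channels
per_channel = [best_decimal(small), best_decimal(large)]  # one setting per channel

print("shared decimal bits:", shared)
print("per-channel decimal bits:", per_channel)
print("error on narrow channel (shared):", mse(small, shared))
print("error on narrow channel (per-channel):", mse(small, per_channel[0]))
```

In this toy setting the shared decimal bits are dictated by the wide channel, so every value of the narrow channel rounds to zero. Choosing decimal bits per channel avoids that.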

### Bias Quantization

By default, for weight quantization, `quantize` only quantizes the `weight` parameter and leaves the `bias` parameter at full precision ([Jacob et al.](https://arxiv.org/abs/1712.05877)). The reason is that the bias can be used to initialize the high-precision accumulator for the multiply-add operations. Bias can be quantized in QSPARSE by:


In [6]:
import torch.nn as nn
from qsparse import quantize

quantize(nn.Conv2d(1, 1, 1), bits=8, bias_bits=12)

Conv2d(
  1, 1, kernel_size=(1, 1), stride=(1, 1)
  (quantize): QuantizeLayer()
  (quantize_bias): QuantizeLayer()
)
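The role of the bias as an accumulator initializer can be sketched in plain Python. This is an illustration under assumptions, not QSPARSE code: the decimal bits `d_in` and `d_w` and all values are made up, and the bias is assumed to be stored at the accumulator scale $2^{d_{in} + d_w}$, as is common in integer inference.

```python
d_in, d_w = 7, 6        # assumed decimal bits of the inputs and the weights
x_int = [12, -40, 77]   # 8-bit integer inputs
w_int = [25, -3, 90]    # 8-bit integer weights
bias = 0.3125           # bias kept at higher precision than the weights

# Products of inputs and weights carry d_in + d_w decimal bits,
# so the bias is scaled to that precision before accumulation.
acc_scale = 2 ** (d_in + d_w)
acc = round(bias * acc_scale)  # the bias pre-loads the wide accumulator
for x, w in zip(x_int, w_int):
    acc += x * w               # integer multiply-add

result = acc / acc_scale
reference = sum((x / 2 ** d_in) * (w / 2 ** d_w) for x, w in zip(x_int, w_int)) + bias
print(result, reference)
```

Because the accumulator is much wider than 8 bits, the bias can keep more bits than the weights without extra cost in the multiply-add loop.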

### Saturated Quantization

When quantizing generative adversarial networks (GANs), we found that the estimation of the optimal decimal bits $d^*$ can be sensitive to outliers in the activation distribution, which leads to an under-representation of quantized values.


<figure style="text-align:center;font-style:italic"> 
  <img src="../docs/assets/long-tail.jpg" />
  <figcaption>Long-tailed distribution in the activation space of GANs.</figcaption>
</figure>

We find that this under-representation issue can be simply and effectively mitigated by clipping the outliers. Let the hyperparameters $q_l, q_u$ denote the lower and upper quantiles of the activation distribution. The optimal decimal bits $d^*$ are then calculated by:

$$
d^* = \arg \min_{d} \Vert Q_u(\mathbf{x}_{t_q}, d) - S_{q_l, q_u}(\mathbf{x}_{t_q}) \Vert
$$

$$
S_{q_l, q_u}(\mathbf{x}) = \text{clip}(\mathbf{x}, \small{\text{quantile}}(\mathbf{x}, q_l), \small{\text{quantile}}(\mathbf{x}, q_u)) 
$$

This can be implemented in QSPARSE by:

In [19]:
from qsparse import quantize

# saturate_range=(q_l, q_u)
# the default saturate_range=(0, 1) means no saturation is applied
(quantize(bits=8).saturate_range, 
 quantize(bits=8, saturate_range=(0.0001, 0.9999)).saturate_range)

((0, 1), (0.0001, 0.9999))
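The effect of saturation on $d^*$ can be reproduced with a small plain-Python sketch of the two formulas above. It is an illustration under assumptions, not the QSPARSE implementation: `quantile`, `saturate`, `quantize_value`, and `optimal_decimal` are hypothetical helpers, the quantizer is assumed to be round-and-saturate fixed point, and the search minimizes the squared distance (which has the same minimizer as the norm).

```python
def quantile(values, q):
    """Linear-interpolation quantile of a list, for q in [0, 1]."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def saturate(values, q_l, q_u):
    """S_{q_l, q_u}: clip values to the range between the two quantiles."""
    lower, upper = quantile(values, q_l), quantile(values, q_u)
    return [min(max(v, lower), upper) for v in values]


def quantize_value(x, decimal, bits=8):
    """Assumed fixed-point quantizer: round to 2**-decimal, then saturate."""
    scale = 2 ** decimal
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return max(lo, min(hi, round(x * scale))) / scale


def optimal_decimal(values, target, bits=8, candidates=range(16)):
    """Search d minimizing the squared distance between Q(values, d) and target."""
    def error(d):
        return sum((quantize_value(v, d, bits) - t) ** 2 for v, t in zip(values, target))
    return min(candidates, key=error)


# 200 evenly spaced activations in [-1, 1) plus a single large outlier
activations = [i / 100 - 1 for i in range(200)] + [50.0]

d_plain = optimal_decimal(activations, activations)  # no saturation
d_saturated = optimal_decimal(activations, saturate(activations, 0.01, 0.99))
print(d_plain, d_saturated)
```

Without saturation, the single outlier forces a coarse grid for all values. With the clipped target, the search settles on more decimal bits, so the bulk of the distribution is represented more finely.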

## Pruning

### Structured Pruning

Structured pruning enforces a topology (e.g. pruning along certain dimensions) on the binary mask created by the pruning procedure. By default, `prune` applies unstructured pruning, but it supports dimension-wise structured pruning through the following configuration:

In [8]:
import torch
from qsparse import prune, structured_prune_callback
tensor = torch.rand(10, 20, 30, 40)

In [9]:
# pruning on all dimensions
prune_layer = prune(sparsity=0.5, collapse=-1)
prune_layer(tensor)
prune_layer.mask.shape

torch.Size([10, 20, 30, 40])

In [10]:
# prune only along dimensions 0 and 1
prune_layer = prune(sparsity=0.5, collapse=-1,
    callback=lambda *args: structured_prune_callback(*args, prunable={0, 1}))
prune_layer(tensor)
prune_layer.mask.shape

torch.Size([10, 20, 1, 1])


The `callback` parameter can be used to extend the functionality of both `quantize` and `prune`, which we revisit in detail in [Extending Methods of Quantization And Pruning](#extending-methods-of-quantization-and-pruning). The `collapse` parameter configures which dimension to ignore (e.g. the batch dimension) when creating binary masks, and is set automatically according to the pruning target (i.e. weight or activation). We explicitly set it to `-1` for clarity in this example.
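The idea behind a structured mask can be sketched in two dimensions with plain Python. This is an illustration under assumptions, not the logic of `structured_prune_callback`: the helpers `structured_row_mask` and `apply_mask` are hypothetical, and rows are ranked by their summed magnitude. The mask keeps one entry per index of the prunable dimension and has size 1 along the other dimension, analogous to the `[10, 20, 1, 1]` mask above.

```python
def structured_row_mask(matrix, sparsity):
    """Return a mask of shape (rows, 1) that prunes whole rows."""
    scores = [sum(abs(v) for v in row) for row in matrix]  # importance per row
    n_pruned = round(sparsity * len(scores))
    weakest = sorted(range(len(scores)), key=lambda i: scores[i])[:n_pruned]
    pruned = set(weakest)
    return [[0 if i in pruned else 1] for i in range(len(scores))]


def apply_mask(matrix, mask):
    """Broadcast the (rows, 1) mask across the columns of each row."""
    return [[v * m[0] for v in row] for row, m in zip(matrix, mask)]


weights = [
    [0.9, -0.8, 0.7],
    [0.1, 0.0, -0.1],
    [-0.5, 0.6, 0.4],
    [0.05, -0.02, 0.01],
]
mask = structured_row_mask(weights, sparsity=0.5)
pruned_weights = apply_mask(weights, mask)
print(mask)
print(pruned_weights)
```

Removing whole rows rather than scattered elements is what makes the resulting sparsity pattern regular.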

### Window Size for Activation Pruning

Unlike weights, which are static during inference, activations depend on the input to the network. To create a stable binary mask for activation pruning, we introduce a sliding window technique:

$$
\mathbf{M}_{\mathbf{h}_t,s}(i,j) = \begin{cases}
		1 & \sum\limits_{n=0}^{T-1} |\mathbf{h}_{t-n}(i, j)| \ge \text{quantile}(\sum\limits_{n=0}^{T-1} |\mathbf{h}_{t-n}|, s) \\
		0 & \text{otherwise}
\end{cases}  
$$

where $\mathbf{h}_t$ denotes the activations at time $t$ and $T$ denotes the size of the sliding window. This can be implemented in QSPARSE by:

In [11]:
from qsparse import prune

print(prune(sparsity=0.5).window_size) # by default T = 1
print(prune(sparsity=0.5, window_size=10).window_size) # T = 10

1
10
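The formula above can be traced with a small plain-Python sketch on one-dimensional activations. It is an illustration under assumptions rather than QSPARSE code: `WindowedMask` and `quantile` are hypothetical names, $s$ is taken to be the target sparsity, and the quantile uses linear interpolation.

```python
from collections import deque


def quantile(values, q):
    """Linear-interpolation quantile of a list, for q in [0, 1]."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


class WindowedMask:
    """Binary mask from the summed magnitudes of the last T activations."""

    def __init__(self, sparsity, window_size=1):
        self.sparsity = sparsity
        self.history = deque(maxlen=window_size)  # keeps only the last T steps

    def update(self, activation):
        self.history.append([abs(v) for v in activation])
        magnitude = [sum(column) for column in zip(*self.history)]
        threshold = quantile(magnitude, self.sparsity)
        return [1 if m >= threshold else 0 for m in magnitude]


frames = [
    [0.9, 0.1, 0.8, 0.2],
    [0.8, 0.2, 0.9, 0.1],
    [0.1, 0.9, 0.7, 0.3],  # an atypical input
]
per_step = WindowedMask(sparsity=0.5, window_size=1)
windowed = WindowedMask(sparsity=0.5, window_size=3)
for frame in frames:
    mask_per_step = per_step.update(frame)
    mask_windowed = windowed.update(frame)

print(mask_per_step)
print(mask_windowed)
```

With $T = 1$ the mask follows the atypical last frame, while with $T = 3$ it still reflects the positions that were active across the whole window.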


### Changing Input Sizes During Evaluation

By default, QSPARSE learns a binary mask for each feature map. The shape of the binary mask is identical to that of the corresponding activation. Therefore, an error will occur if we vary the input size:

In [6]:
import torch
from qsparse import prune
data = torch.rand((1, 10, 32, 32))
data2x = torch.rand((1, 10, 64, 64))

prune_layer = prune(sparsity=0.5)
prune_layer(data)
prune_layer.eval()
try:
    prune_layer(data2x)
except RuntimeError as e:
    print(f"Catch error as expect: {e}")

Catch error as expect: The expanded size of the tensor (64) must match the existing size (32) at non-singleton dimension 3.  Target sizes: [1, 10, 64, 64].  Tensor sizes: [10, 32, 32]


QSPARSE provides an option to expand the binary mask to match the input tensor shape during evaluation, which is useful for tasks like super resolution. This option can be enabled by setting the `strict` parameter to `False`:

In [7]:
prune_layer = prune(sparsity=0.5, strict=False)
prune_layer(data)
prune_layer.eval()
prune_layer(data2x)
print('Able to change input shape during evaluation')

Able to change input shape during evaluation


## More Functionalities

### Resuming from Checkpoint 

Both `quantize` and `prune` layers support resuming training from a checkpoint. However, two facts need to be taken into account:

1. QSPARSE determines the shape of its parameters (e.g. `decimal`, `mask`) at the first forward pass.
2. `load_state_dict` does not allow shape mismatch ([pytorch/issues#40859](https://github.com/pytorch/pytorch/issues/40859)).

Therefore, the QSPARSE parameters must be initialized before loading the state dict, which can easily be done with a forward pass:

In [12]:
import torch
import torch.nn as nn
from qsparse import quantize, prune

def make_conv():
    return quantize(prune(nn.Conv2d(1, 32, 3), 
                        sparsity=0.5, start=200, 
                        interval=10, repetition=4), 
                bits=8, timeout=100)

conv = make_conv()

for _ in range(241):
    conv(torch.rand(10, 1, 28, 28))

try:
    conv2 = make_conv()
    conv2.load_state_dict(conv.state_dict())
except RuntimeError as e:
    print(f'\nCatch error as expected: {e}\n' )

conv3 = make_conv()
conv3(torch.rand(10, 1, 28, 28))
conv3.load_state_dict(conv.state_dict())

tensor = torch.rand(10, 1, 28, 28)
assert torch.all(conv(tensor) == conv3(tensor)).item() == True
print('successfully loading from checkpoint')

[Quantize] (channelwise) avg decimal = 8.0
[Prune] [Step 210] active 0.72, pruned 0.28, window_size = 1
[Prune] [Step 220] active 0.57, pruned 0.43, window_size = 1
[Prune] [Step 230] active 0.51, pruned 0.49, window_size = 1
[Prune] [Step 240] active 0.50, pruned 0.50, window_size = 1

Catch error as expected: Error(s) in loading state_dict for Conv2d:
	size mismatch for prune.mask: copying a param with shape torch.Size([32, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([]).

successfully loading from checkpoint


### Better Logging

QSPARSE logs quantization and pruning progress to the console. The function `auto_name_prune_quantize_layers` associates these logs with specific layer names, which makes the training program more transparent:

In [8]:
import torch
import torch.nn as nn
from collections import OrderedDict
from qsparse import auto_name_prune_quantize_layers, quantize, prune

net = nn.Sequential(
    OrderedDict([
        ('conv', quantize(prune(nn.Conv2d(1, 32, 3, bias=False), 
                        sparsity=0.5, start=200, 
                        interval=10, repetition=4), 
                    bits=8, timeout=100, channelwise=0)),
        ('conv_output_pruning', prune(sparsity=0.5, start=200, interval=10, repetition=4)),
        ('conv_output_quantization', quantize(bits=8, timeout=100)),
    ])
)

net = auto_name_prune_quantize_layers(net)

for _ in range(241):
    net(torch.rand(10, 1, 28, 28))

INFO:root:[Quantize @ conv.quantize] (channelwise) avg decimal = 8.125
INFO:root:[Quantize @ conv_output_quantization] (channelwise) avg decimal = 7.34375
INFO:root:[Prune @ conv.prune] [Step 210] active 0.72, pruned 0.28, window_size = 1
INFO:root:[Prune @ conv_output_pruning] [Step 210] active 0.71, pruned 0.29, window_size = 1
INFO:root:[Prune @ conv.prune] [Step 220] active 0.57, pruned 0.43, window_size = 1
INFO:root:[Prune @ conv_output_pruning] [Step 220] active 0.56, pruned 0.44, window_size = 1
INFO:root:[Prune @ conv.prune] [Step 230] active 0.51, pruned 0.49, window_size = 1
INFO:root:[Prune @ conv_output_pruning] [Step 230] active 0.51, pruned 0.49, window_size = 1
INFO:root:[Prune @ conv.prune] [Step 240] active 0.50, pruned 0.50, window_size = 1
INFO:root:[Prune @ conv_output_pruning] [Step 240] active 0.50, pruned 0.50, window_size = 1


### Inspecting Parameters of a Pruned/Quantized Model

Parameters of a quantized and pruned network can be easily inspected and post-processed for use cases such as compiling for neural engines:

In [14]:
# continue from the `net` variable from the previous example
state_dict = net.state_dict()
for k,v in state_dict.items():
    print(k, v.numpy().shape)

net.state_dict()['conv.quantize.decimal'].numpy()

conv.weight (32, 1, 3, 3)
conv.prune.mask (32, 1, 3, 3)
conv.prune._n_updates (1,)
conv.prune._cur_sparsity (1,)
conv.quantize.decimal (32,)
conv.quantize._n_updates (1,)
conv.quantize.bits ()
conv.quantize._quantized (1,)
conv_output_pruning.mask (32, 26, 26)
conv_output_pruning._n_updates (1,)
conv_output_pruning._cur_sparsity (1,)
conv_output_quantization.decimal (32,)
conv_output_quantization._n_updates (1,)
conv_output_quantization.bits ()
conv_output_quantization._quantized (1,)


array([9., 8., 8., 8., 8., 8., 8., 8., 9., 8., 8., 8., 8., 8., 8., 8., 8.,
       8., 8., 8., 8., 8., 9., 8., 8., 8., 8., 8., 8., 8., 8., 8.],
      dtype=float32)


| Param           | Description                                                                  |
|-----------------|------------------------------------------------------------------------------|
| `mask`          | binary mask for pruning                                                      |
| `decimal`       | number of bits used to represent the fractional part                         |
| `bits`          | total number of bits                                                         |
| `_n_updates`    | internal counter for number of training steps                                |
| `_cur_sparsity` | internal variable to record current sparsity                                 |
| `_quantized`    | internal boolean variable to record whether quantization has been triggered. |
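How `bits` and `decimal` together define a number format can be sketched in plain Python. This assumes a signed fixed-point interpretation, consistent with the clamping to `[-128, 127]` used in the integer arithmetic example on this page. `fixed_point_range` and `to_integer` are hypothetical helpers, not part of QSPARSE.

```python
def fixed_point_range(bits, decimal):
    """Smallest value, largest value, and step of a signed fixed-point format."""
    step = 2.0 ** -decimal
    return -(2 ** (bits - 1)) * step, (2 ** (bits - 1) - 1) * step, step


def to_integer(value, bits, decimal):
    """Integer code of a real value: scale by 2**decimal, round, saturate."""
    code = round(value * 2 ** decimal)
    return max(-(2 ** (bits - 1)), min(2 ** (bits - 1) - 1, code))


lowest, highest, step = fixed_point_range(bits=8, decimal=7)
code = to_integer(0.3, bits=8, decimal=7)
restored = code / 2 ** 7  # real value represented by the integer code

print(lowest, highest, step)
print(code, restored)
```

Each additional decimal bit halves the step size and also halves the representable range, which is why `decimal` is estimated per layer or per channel.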


### Integer Arithmetic Verification

Using the parameters from the previous section, the following example implements the computation with 8-bit integer arithmetic and matches the floating-point results exactly:

In [15]:
decimal_inp = 7 # input decimal bits
input_int = torch.randint(-128, 127, size=(10, 1, 28, 28)).int()
input_float = input_int.float() / 2 ** decimal_inp
output_float = net(input_float)

In [16]:
import torch.nn.functional as F
weight_int = torch.clamp((state_dict['conv.weight'] * 
            ((2.0 ** state_dict['conv.quantize.decimal']).view(-1, 1, 1, 1))).int(), 
        -128, 127) * state_dict['conv.prune.mask'].int()

output_int_high_precision = F.conv2d(input_int, weight_int)

# simulating bit shifting in integer arithmetic computation engine
for i in range(output_int_high_precision.shape[1]):
    output_int_high_precision[:, i] = (
        output_int_high_precision[:, i].float() /
         2 ** (decimal_inp + state_dict['conv.quantize.decimal'][i]
                 - state_dict['conv_output_quantization.decimal'][i])
    ).int()

output_int = torch.clamp(output_int_high_precision, min=-128, max=127) * state_dict['conv_output_pruning.mask'].int()

In [17]:
import numpy as np
diff = (
    output_float.detach().numpy() - (output_int.float() / 2 ** state_dict['conv_output_quantization.decimal'].view(1, -1, 1, 1)).detach().numpy()
)
assert np.all(diff == 0)
print('Fully match with integer arithmetic!')

Fully match with integer arithmetic!
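The rescaling step in the loop above can be isolated in a small plain-Python sketch. The accumulator carries `d_in + d_w` decimal bits and the output expects `d_out`, so it is divided by $2^{d_{in} + d_w - d_{out}}$. `requantize` is a hypothetical helper and the numbers are made up. The sketch truncates toward zero to mirror the `.int()` cast used above; whether a given integer engine truncates or floors is an assumption to check against the target hardware.

```python
def requantize(acc, d_in, d_w, d_out, bits=8):
    """Rescale a wide accumulator to the output format and saturate."""
    shift = d_in + d_w - d_out
    value = int(acc / 2 ** shift)  # truncates toward zero, like a float-to-int cast
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return max(lo, min(hi, value))


d_in, d_w, d_out = 7, 6, 5  # assumed decimal bits of input, weight, and output

positive = requantize(7350, d_in, d_w, d_out)
negative = requantize(-7350, d_in, d_w, d_out)
saturated = requantize(100000, d_in, d_w, d_out)

# An arithmetic right shift floors instead of truncating,
# so it differs from the cast for negative accumulators.
shifted = -7350 >> (d_in + d_w - d_out)

print(positive, negative, saturated, shifted)
```

The two rounding conventions agree for non-negative accumulators and can differ by one for negative ones, which matters when an exact match with the floating-point model is required.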


### Extending Methods of Quantization And Pruning

Both `quantize` and `prune` have the `callback` parameter, which is a function that operates on the tensor level to implement the quantization and pruning operations. Currently, QSPARSE provides the following built-in callbacks.

- Quantization:
    - [linear_quantize_callback](../reference/quantize/#qsparse.quantize.linear_quantize_callback)
- Pruning:
    - [unstructured_prune_callback](../reference/sparse/#qsparse.sparse.unstructured_prune_callback)
    - [structured_prune_callback](../reference/sparse/#qsparse.sparse.structured_prune_callback)

The type signatures for quantization and pruning can be found at [QuantizeCallback](../reference/common/#qsparse.common.QuantizeCallback) and [PruneCallback](../reference/common/#qsparse.common.PruneCallback). The `callback` parameter can be used as an interface to extend QSPARSE.
