# Advanced Usage

This page covers advanced usage of QSPARSE, organized by topic. More information can be found in the API Reference.

In [1]:
%load_ext autoreload
%autoreload 2
import logging
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler()])
from qsparse import set_qsparse_options
set_qsparse_options(log_on_created=False)

## Layerwise Pruning 

The function `devise_layerwise_pruning_schedule` traverses every pruning operator in the network, starting from the input, and assigns each operator the step at which it is activated, so that a pruning operator is activated only after all of its preceding layers have been pruned. The motivation and algorithm details can be found in our MDPI publication.

In [2]:
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from qsparse.sparse import prune, devise_layerwise_pruning_schedule


net = nn.Sequential(nn.Conv2d(3, 3, 3), 
                    prune(sparsity=0.5),  # no need to specify `start, repetition, interval`
                    nn.Conv2d(3, 3, 3), 
                    prune(sparsity=0.5))

devise_layerwise_pruning_schedule(net, start=1, interval=10) # notice the `start` of each prune layer increases 

Pruning stops at iteration - 23


Sequential(
  (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  (1): PruneLayer(sparsity=0.5, start=1, interval=10, repetition=1, dimensions={1})
  (2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  (3): PruneLayer(sparsity=0.5, start=12, interval=10, repetition=1, dimensions={1})
)
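
The numbers printed above are consistent with a simple rule: an operator finishes `interval * repetition` steps after its `start`, and the next operator starts one step later. The sketch below reproduces them in plain Python. The rule is inferred from this output rather than taken from the QSPARSE source, and `layerwise_starts` is a hypothetical helper, not part of the library.

```python
def layerwise_starts(num_layers, start, interval, repetition=1):
    """Return (starts, stop) for a chain of pruning operators.

    Assumed rule (inferred from the printed output, not from the library):
    an operator that starts at step s finishes at s + interval * repetition,
    and the next operator starts one step after that.
    """
    starts = []
    current = start
    for _ in range(num_layers):
        starts.append(current)
        current += interval * repetition + 1
    # `current` is now the step at which a further operator would start.
    return starts, current


starts, stop = layerwise_starts(num_layers=2, start=1, interval=10)
print(starts, stop)  # [1, 12] 23
```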

## Network Conversion

The function `convert` produces pruned and quantized network instances without touching the existing floating-point network implementation. Some common use cases follow.

### 1. Inserting pruning operator after all ReLU layers

In [3]:
from collections import OrderedDict
from qsparse import convert, quantize, prune

net = nn.Sequential(OrderedDict([
        ("first_half", nn.Sequential(nn.Conv2d(3, 3, 3), nn.ReLU())),
        ("second_half", nn.Sequential(nn.Conv2d(3, 3, 3), nn.ReLU()))]))

convert(net, prune(sparsity=0.5), activation_layers=[nn.ReLU], inplace=False)

Apply `prunesparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1}` on the .first_half.1 activation
Apply `prunesparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1}` on the .second_half.1 activation


Sequential(
  (first_half): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): Sequential(
      (0): ReLU()
      (1): PruneLayer(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})
    )
  )
  (second_half): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): Sequential(
      (0): ReLU()
      (1): PruneLayer(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})
    )
  )
)

### 2. Applying the quantization operator on the weight of all Conv2D layers

In [4]:
convert(net, quantize(bits=4), weight_layers=[nn.Conv2d], inplace=False)

Apply `quantizebits=4, timeout=1000, callback=scalerquantizer, channelwise=1` on the .first_half.0 weight
Apply `quantizebits=4, timeout=1000, callback=scalerquantizer, channelwise=1` on the .second_half.0 weight


Sequential(
  (first_half): Sequential(
    (0): Conv2d(
      3, 3, kernel_size=(3, 3), stride=(1, 1)
      (quantize): QuantizeLayer(bits=4, timeout=1000, callback=ScalerQuantizer, channelwise=1)
    )
    (1): ReLU()
  )
  (second_half): Sequential(
    (0): Conv2d(
      3, 3, kernel_size=(3, 3), stride=(1, 1)
      (quantize): QuantizeLayer(bits=4, timeout=1000, callback=ScalerQuantizer, channelwise=1)
    )
    (1): ReLU()
  )
)

### 3. Applying (1) and (2), but excluding the last ReLU and the first Conv2D layer

In [5]:
convert(convert(net, prune(sparsity=0.5), activation_layers=[nn.ReLU], 
                excluded_activation_layer_indexes=[(nn.ReLU, [-1])], inplace=False), 
        quantize(bits=4), weight_layers=[nn.Conv2d],
        excluded_weight_layer_indexes=[(nn.Conv2d, [0])], inplace=False)

Apply `prunesparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1}` on the .first_half.1 activation
Exclude .second_half.1 activation
Exclude .first_half.0 weight
Apply `quantizebits=4, timeout=1000, callback=scalerquantizer, channelwise=1` on the .second_half.0 weight


Sequential(
  (first_half): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): Sequential(
      (0): ReLU()
      (1): PruneLayer(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})
    )
  )
  (second_half): Sequential(
    (0): Conv2d(
      3, 3, kernel_size=(3, 3), stride=(1, 1)
      (quantize): QuantizeLayer(bits=4, timeout=1000, callback=ScalerQuantizer, channelwise=1)
    )
    (1): ReLU()
  )
)
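
The log above suggests that the excluded indexes count the matched layers of a given type in traversal order, with negative values counting from the end as in a Python list. The sketch below illustrates that reading on the layer names from the log. `split_by_index` is a hypothetical helper, and the index semantics are an assumption based on the output, not on the library source.

```python
def split_by_index(matched_names, excluded_indexes):
    """Split matched layer names into (applied, excluded).

    Indexes follow Python list semantics, so -1 refers to the last match.
    """
    if not matched_names:
        return [], []
    n = len(matched_names)
    excluded_positions = {i % n for i in excluded_indexes}
    applied = [m for i, m in enumerate(matched_names) if i not in excluded_positions]
    excluded = [m for i, m in enumerate(matched_names) if i in excluded_positions]
    return applied, excluded


relu_names = [".first_half.1", ".second_half.1"]
conv_names = [".first_half.0", ".second_half.0"]

print(split_by_index(relu_names, [-1]))  # (['.first_half.1'], ['.second_half.1'])
print(split_by_index(conv_names, [0]))   # (['.second_half.0'], ['.first_half.0'])
```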

### 4. Inserting pruning only in the first half of the network


In [6]:
convert(net, prune(sparsity=0.5), activation_layers=[nn.ReLU], include=['first'], inplace=False)
# or convert(net, prune(sparsity=0.5), activation_layers=[nn.ReLU], exclude=['second'], inplace=False)

Apply `prunesparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1}` on the .first_half.1 activation
Exclude .second_half.1 activation


Sequential(
  (first_half): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): Sequential(
      (0): ReLU()
      (1): PruneLayer(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})
    )
  )
  (second_half): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
  )
)
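
In the example above, `include=['first']` selects `.first_half.1` and `exclude=['second']` has the same effect. One way to read this is as a substring test on the qualified layer name, sketched below. Substring matching is an assumption made for illustration (the library may use another matching rule), and `is_selected` is a hypothetical helper.

```python
def is_selected(name, include=None, exclude=None):
    """Decide whether a layer name passes the include/exclude filters.

    Assumption: a filter matches when it occurs as a substring of the name.
    """
    if include is not None and not any(key in name for key in include):
        return False
    if exclude is not None and any(key in name for key in exclude):
        return False
    return True


names = [".first_half.1", ".second_half.1"]
by_include = [n for n in names if is_selected(n, include=["first"])]
by_exclude = [n for n in names if is_selected(n, exclude=["second"])]
print(by_include, by_exclude)  # ['.first_half.1'] ['.first_half.1']
```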

### 5. Inserting pruning operator before all Conv2D layers

In [7]:
convert(net, prune(sparsity=0.5), activation_layers=[nn.Conv2d], order="pre", inplace=False)

Apply `prunesparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1}` on the .first_half.0 activation
Apply `prunesparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1}` on the .second_half.0 activation


Sequential(
  (first_half): Sequential(
    (0): Sequential(
      (0): PruneLayer(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})
      (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    )
    (1): ReLU()
  )
  (second_half): Sequential(
    (0): Sequential(
      (0): PruneLayer(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})
      (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    )
    (1): ReLU()
  )
)

## More Quantization

### Symmetric Quantization with Scaler

The class `ScalerQuantizer` implements Algorithm 3 of our MDPI paper. The class `DecimalQuantizer` shares the same implementation, except that the scaling factor is always restricted to a power of 2. Instances of either class can be passed to the `callback` argument of `quantize`, for example:

In [8]:
from qsparse.quantize import DecimalQuantizer

quantize(bits=8, callback=DecimalQuantizer())

QuantizeLayer(bits=8, timeout=1000, callback=DecimalQuantizer, channelwise=1)

`ScalerQuantizer` and `DecimalQuantizer` cover both inference and parameter learning. To use only the inference function to quantize tensors, call `quantize_with_scaler` or `quantize_with_decimal`:

In [9]:
import torch
from qsparse.quantize import quantize_with_decimal, quantize_with_scaler

data = torch.rand(1000)

((data - quantize_with_decimal(data, bits=8, decimal=6))**2).mean(), ((data - quantize_with_scaler(data, bits=8, scaler=0.01))**2).mean()

(tensor(8.4718e-05), tensor(8.5786e-06))
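
For intuition, symmetric quantization with a scaling factor can be sketched on plain Python floats as follows. The function names are hypothetical, and the rounding mode (round to nearest) and clamping range are assumptions of this sketch; the library functions may round differently. A power-of-two scaling factor `2 ** -decimal` gives the same result as specifying `decimal` fractional bits.

```python
import math


def scaled_quantize(x, bits, scaler):
    """Map x to the nearest multiple of `scaler` within the signed range."""
    qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    q = math.floor(x / scaler + 0.5)  # round to nearest (assumed)
    q = max(qmin, min(qmax, q))       # clamp to the integer range
    return q * scaler


def fixed_point_quantize(x, bits, decimal):
    """Same as above, with the scaling factor restricted to 2 ** -decimal."""
    return scaled_quantize(x, bits, 2.0 ** -decimal)


print(fixed_point_quantize(0.3, bits=8, decimal=6))  # 0.296875 (= 19 / 64)
print(fixed_point_quantize(3.0, bits=8, decimal=6))  # 1.984375 (clamped to 127 / 64)
```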

### Asymmetric Quantization

The class `AdaptiveQuantizer` implements Algorithm 2 of our MDPI paper, which estimates the lower and upper bounds of the incoming data stream and applies asymmetric quantization. Its inference function is available as `quantize_with_line`.


In [10]:
from qsparse.quantize import AdaptiveQuantizer

quantize(bits=8, callback=AdaptiveQuantizer())

QuantizeLayer(bits=8, timeout=1000, callback=AdaptiveQuantizer, channelwise=1)

In [11]:
from qsparse.quantize import quantize_with_line

((data - quantize_with_line(data, bits=8, lines=(0, 1)))**2).mean() # lines specify the (lower, upper) bounds. 

tensor(1.3243e-06)
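
Asymmetric quantization between a lower and an upper bound can be sketched on plain Python floats as follows. `quantize_between` is a hypothetical helper; the use of `2 ** bits - 1` steps and round-to-nearest are assumptions of this sketch rather than a description of `quantize_with_line`.

```python
import math


def quantize_between(x, bits, lower, upper):
    """Snap x to one of 2 ** bits evenly spaced levels in [lower, upper]."""
    levels = 2 ** bits - 1                 # number of steps between the bounds
    step = (upper - lower) / levels
    x = max(lower, min(upper, x))          # clamp to the estimated range
    index = math.floor((x - lower) / step + 0.5)
    return lower + index * step


step = 1 / (2 ** 8 - 1)
worst = max(abs(quantize_between(i / 1000, 8, 0.0, 1.0) - i / 1000) for i in range(1001))
print(worst <= step / 2 + 1e-12)  # True: the error never exceeds half a step
```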

### Channelwise Quantization

Channelwise quantization uses different decimal bits across channels, i.e., each channel is quantized independently. It is well known that channelwise quantization can reduce quantization error drastically, especially when numerical ranges vary widely between channels.

To enable channelwise quantization with dimension 1 as the channel dimension:

In [12]:
quantize(bits=8, channelwise=1)

QuantizeLayer(bits=8, timeout=1000, callback=ScalerQuantizer, channelwise=1)

To disable channelwise quantization:

In [13]:
quantize(bits=8, channelwise=-1)

QuantizeLayer(bits=8, timeout=1000, callback=ScalerQuantizer, channelwise=-1)
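
The benefit of channelwise scaling can be illustrated with a small stdlib-only sketch. It compares one shared scaling factor against one scaling factor per channel on two channels with very different ranges. The helpers are hypothetical, and deriving the scaling factor from the maximum magnitude is an assumption made for illustration; QSPARSE learns its scaling factors during training.

```python
import math


def quantize_symmetric(x, bits, scale):
    """Round x / scale to the nearest signed integer, clamp, and rescale."""
    qmax = 2 ** (bits - 1) - 1
    q = math.floor(x / scale + 0.5)
    q = max(-qmax - 1, min(qmax, q))
    return q * scale


def max_abs_scale(values, bits):
    """Scale that maps the largest magnitude onto the largest integer."""
    return max(abs(v) for v in values) / (2 ** (bits - 1) - 1)


def mse(values, bits, scale):
    return sum((v - quantize_symmetric(v, bits, scale)) ** 2 for v in values) / len(values)


# Two channels whose numerical ranges differ by a factor of 500.
channels = [
    [0.001 * i for i in range(-50, 51)],  # roughly [-0.05, 0.05]
    [0.5 * i for i in range(-50, 51)],    # [-25, 25]
]

shared_scale = max_abs_scale([v for ch in channels for v in ch], bits=8)
shared_err = [mse(ch, 8, shared_scale) for ch in channels]
channel_err = [mse(ch, 8, max_abs_scale(ch, 8)) for ch in channels]

# The small-range channel is far better served by its own scaling factor.
print(shared_err[0] > channel_err[0])  # True
```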

### Groupwise Quantization

Channelwise quantization allocates one set of scaling factor and zero-point per channel, which can complicate the inference implementation when both weights and activations are quantized channelwise, especially for networks with a large number of channels. To address this, we provide a technique that we call _groupwise quantization_. Specifically, we cluster the channelwise quantization parameters (scaling factors and zero-points) into groups and share one set of quantization parameters within each group. We empirically find that groupwise quantization yields little to no performance drop compared to channelwise quantization, even with an extremely small number of groups, e.g. 4.

In [14]:

layer = quantize(bits=8, channelwise=1, 
                 callback=DecimalQuantizer(group_num=4, 
                                   # `group_timeout` denotes the steps when the clustering starts after the activation of the quantization operator. 
                                   group_timeout=10), timeout=10)
for _ in range(21):
    layer(torch.rand(1, 1024, 3, 3))

quantizing  with 8 bits
clustering 1024 channels into 4 groups


For a convolution layer with 1024 channels, groupwise quantization with 4 groups reduces the number of quantization parameters by a factor of 256.
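
The grouping idea can be sketched as a one-dimensional clustering of per-channel scaling factors. The example below uses a plain k-means loop on synthetic scaling factors; both the clustering algorithm and the data are assumptions made for illustration and may differ from what QSPARSE does internally. `cluster_1d` is a hypothetical helper.

```python
def cluster_1d(values, group_num, iterations=20):
    """Plain 1-D k-means; returns (centers, assignment)."""
    ordered = sorted(values)
    # Deterministic init: spread the initial centers over the sorted values.
    centers = [ordered[(2 * g + 1) * len(ordered) // (2 * group_num)]
               for g in range(group_num)]
    assignment = [0] * len(values)
    for _ in range(iterations):
        # Assign each value to its nearest center.
        assignment = [min(range(group_num), key=lambda g: abs(v - centers[g]))
                      for v in values]
        # Move each center to the mean of its members.
        for g in range(group_num):
            members = [v for v, a in zip(values, assignment) if a == g]
            if members:
                centers[g] = sum(members) / len(members)
    return centers, assignment


# Synthetic per-channel scaling factors for 1024 channels (illustrative data).
scales = [2.0 ** -(1 + c % 4) * (1 + 0.01 * (c % 5)) for c in range(1024)]
centers, assignment = cluster_1d(scales, group_num=4)

# Each channel now uses the scaling factor of its group.
shared = [centers[a] for a in assignment]
print(len(scales) // len(centers))  # 256
```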

### Bias Quantization

By default, when quantizing weights, `quantize` only quantizes the weight parameter and leaves the bias parameter in full precision [(Jacob et al.)](https://arxiv.org/abs/1712.05877). The reason is that the bias can be used to initialize the high-precision accumulator for the multiply-add operations. To quantize the bias as well in QSPARSE:

In [15]:
from qsparse import quantize

quantize(nn.Conv2d(1, 1, 1), bits=8, bias_bits=12)

Conv2d(
  1, 1, kernel_size=(1, 1), stride=(1, 1)
  (quantize): QuantizeLayer(bits=8, timeout=1000, callback=ScalerQuantizer, channelwise=1)
  (quantize_bias): QuantizeLayer(bits=12, timeout=1000, callback=ScalerQuantizer, channelwise=0)
)

### Integer Arithmetic Verification

The following example demonstrates that floating-point simulated quantization matches 8-bit integer arithmetic exactly.

In [16]:
ni = 7 # input shift
no = 6 # output shift

input = torch.randint(-128, 127, size=(3, 10, 32, 32))
input_float = input.float() / 2 ** ni

Quantization computation simulated with floating-point:

In [17]:
timeout = 5
qconv = quantize(
    torch.nn.Conv2d(10, 30, 3, bias=False), bits=8, timeout=timeout, channelwise=0, callback=DecimalQuantizer()
) 
qconv.train()
for _ in range(timeout + 1):  # ensure the quantization has been triggered
    qconv(input_float)
output_float = quantize_with_decimal(qconv(input_float), 8, no)

quantizing  with 8 bits


Reproduce the above computation in 8-bit arithmetic:

In [18]:
decimal = (1 / qconv.quantize.weight).nan_to_num(posinf=1, neginf=1).log2().round().int()
weight = qconv.weight * (2.0 ** decimal).view(-1, 1, 1, 1)
output_int = F.conv2d(input.int(), weight.int())
for i in range(output_int.shape[1]):
    output_int[:, i] = (
        output_int[:, i].float() / 2 ** (ni + decimal[i] - no)
    ).int()

diff = (
    output_float.detach().numpy() - (output_int.float() / 2 ** no).detach().numpy()
)
assert np.all(diff == 0)

print("Fully match with integer arithmetic")

Fully match with integer arithmetic
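
The shift arithmetic behind this check can be followed on a single dot product in plain Python. In the sketch below, `nw` (the weight shift) and the integer values are illustrative assumptions, and truncation toward zero is used on both sides, mirroring the `.int()` conversion in the cell above.

```python
ni, nw, no = 7, 5, 6        # fractional bits of input, weight, and output

x_int = [37, -128, 40, 5]   # 8-bit integer inputs
w_int = [-45, 12, 100, -3]  # 8-bit integer weights

# Floating-point simulation: dequantize, multiply-accumulate, requantize.
x_float = [v / 2 ** ni for v in x_int]
w_float = [v / 2 ** nw for v in w_int]
acc_float = sum(x * w for x, w in zip(x_float, w_float))
out_float = int(acc_float * 2 ** no) / 2 ** no

# Integer arithmetic: the accumulator carries ni + nw fractional bits,
# so dividing by 2 ** (ni + nw - no) brings it to the output format.
acc_int = sum(x * w for x, w in zip(x_int, w_int))
out_int = int(acc_int / 2 ** (ni + nw - no))

print(out_float, out_int)   # 0.1875 12
```

Every intermediate value here is a small integer divided by a power of two, so the floating-point side is exact and the two results agree bit for bit.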


## Extras

### Resuming from Checkpoint 

Both `quantize` and `prune` layers support resuming training from a checkpoint. However, two facts complicate loading:

1. QSPARSE determines the shape of its parameters (e.g. `scaling factor`, `mask`) at the first forward pass.
2. `load_state_dict` currently does not allow shape mismatch ([pytorch/issues#40859](https://github.com/pytorch/pytorch/issues/40859))

We therefore provide `preload_qsparse_state_dict`, which should be called before `load_state_dict` to work around this issue.

In [60]:
from qsparse.util import preload_qsparse_state_dict

def make_conv():
    return quantize(prune(nn.Conv2d(16, 32, 3), 
                        sparsity=0.5, start=200, 
                        interval=10, repetition=4), 
                bits=8, timeout=100)

conv = make_conv()

for _ in range(241):
    conv(torch.rand(10, 16, 7, 7))

try:
    conv2 = make_conv()
    conv2.load_state_dict(conv.state_dict())
except RuntimeError as e:
    print(f'\nCatch error as expected: {e}\n' )

conv3 = make_conv()
preload_qsparse_state_dict(conv3, conv.state_dict())
conv3.load_state_dict(conv.state_dict())

tensor = torch.rand(10, 16, 7, 7)
assert np.allclose(conv(tensor).detach().numpy(), conv3(tensor).detach().numpy(), atol=1e-6)
print('successfully loading from checkpoint')

[Prune] start = 200 interval = 10 repetition = 4 sparsity = 0.5 dimensions = {1}
[Quantize] bits=8 channelwise=1 timeout=100
quantizing  with 8 bits
[Prune] [Step 200] pruned 0.29
Start pruning at  @ 200
[Prune] [Step 210] pruned 0.44
[Prune] [Step 220] pruned 0.49
[Prune] [Step 230] pruned 0.50
[Prune] start = 200 interval = 10 repetition = 4 sparsity = 0.5 dimensions = {1}
[Quantize] bits=8 channelwise=1 timeout=100

Catch error as expected: Error(s) in loading state_dict for Conv2d:
	Unexpected key(s) in state_dict: "prune.callback.magnitude", "quantize.weight", "quantize._n_updates". 
	size mismatch for prune.mask: copying a param with shape torch.Size([1, 16, 1, 1]) from checkpoint, the shape in current model is torch.Size([]).

[Prune] start = 200 interval = 10 repetition = 4 sparsity = 0.5 dimensions = {1}
[Quantize] bits=8 channelwise=1 timeout=100
successfully loading from checkpoint


### Inspecting Parameters of a Pruned/Quantized Model

The parameters of a quantized and pruned network can be easily inspected and post-processed, for use cases such as compiling for neural engines:

In [24]:
state_dict = conv.state_dict()
for k,v in state_dict.items():
    print(k, v.numpy().shape)

weight (32, 16, 3, 3)
bias (32,)
prune.mask (1, 16, 1, 1)
prune._n_updates (1,)
prune._cur_sparsity (1,)
prune.callback.t (1,)
prune.callback.magnitude (1, 16, 1, 1)
quantize.weight (16, 1)
quantize._n_updates (1,)



| Param           | Description                                                                  |
|-----------------|------------------------------------------------------------------------------|
| `quantize.weight` | scaling factors                                     |
| `*._n_updates`    | internal counter for number of training steps                                |
| `prune.mask`            | binary mask for pruning                                                      |
| `prune._cur_sparsity` | internal variable to record current sparsity                                 |
| `prune.callback.magnitude`    | internal per-channel magnitude buffer of the pruning callback (same shape as `prune.mask`) |