In [None]:
import torch

## Quantization schemes
<img src="./img/q_scheme.png" width="600" />

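Quantization schemes vary along two independent dimensions.

How the input range is mapped to integers:
* Symmetric
* Affine

How many sets of quantization parameters are used:
* Per-channel
* Per-tensor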

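### Per-Channel and Per-Tensor

A per-tensor scheme uses one scale and one zero point for the whole tensor. A per-channel scheme uses a separate scale/zero-point pair for each slice along a chosen axis (the "channel" axis).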

<img src="./img/per_t_c.png" width="600" />

In [None]:
# Small 3x2 demo tensor reused by the quantization examples below.
x = torch.tensor(
    [
        [0.5827, 0.8619],
        [0.3827, -0.1982],
        [-0.8213, 0.6351],
    ]
)

print(x.shape)

In [None]:
# per-tensor: a single scale and a single zero point are shared by every element of x
scale, zero_pt = torch.tensor(1e-2), torch.tensor(0)

xq = torch.quantize_per_tensor(x, scale, zero_pt, dtype=torch.qint8)
print(xq)

In [None]:
# per-channel: every slice of x along `channel_axis` gets its own scale / zero point
channel_axis = 0
scale = torch.tensor([1e-2, 1e-3, 5e-2])  # one scale per row of x (channel_axis = 0)
# Zero points must be an *integer* tensor for per-channel affine quantization.
# A float tensor (e.g. plain torch.zeros(3)) makes PyTorch pick its float-qparams
# quantizer instead, and xq would report qscheme=torch.per_channel_affine_float_qparams.
# Sizing it from x keeps it in sync with the tensor being quantized.
zero_pt = torch.zeros(x.size(channel_axis), dtype=torch.int64)

# NOTE(review): with scale 1e-3, row 1 maps to ~383 (0.3827 / 1e-3) and ~-198 (-0.1982 / 1e-3),
# both outside qint8's [-128, 127], so both values saturate. Use a larger scale if that is unintended.
xq = torch.quantize_per_channel(x, scale, zero_pt, dtype=torch.qint8, axis=channel_axis)
print(xq)

### Symmetric and Affine

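Symmetric
* Input range is calculated symmetrically around 0, i.e. $[-\max|x|, \max|x|]$
* Good for quantizing weights
* Wasteful for quantizing activations — why? (Hint: many activations, e.g. ReLU outputs, are never negative.)

Affine
* Fits the quantization range tightly to the input's observed $[\min x, \max x]$, using a zero point to shift it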


<img src="./img/affine-symmetric.png" width="600" />

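### Observers

Observers watch the tensors passed through them and track statistics such as the running min/max or a histogram. From those statistics they derive the quantization parameters (scale and zero point) via `calculate_qparams()`.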

<img src="./img/observer.png" width="600" />

In [None]:
from torch.ao.quantization.observer import MovingAverageMinMaxObserver, HistogramObserver, MovingAveragePerChannelMinMaxObserver

# Seed the RNG so the sampled inputs -- and therefore the qparams computed from them -- are reproducible.
SEED = 0
torch.manual_seed(SEED)

size = (3, 4)
normal = torch.distributions.normal.Normal(0, 1)
# Three random (3, 4) batches drawn from a standard normal distribution.
# NOTE: `input` shadows the Python built-in; the name is kept because the next cell refers to it.
input = [normal.sample(size) for _ in range(3)]

# Three observer flavours to compare: per-tensor min/max, histogram-based, and per-channel symmetric.
observers = [
    MovingAverageMinMaxObserver(qscheme=torch.per_tensor_affine),
    HistogramObserver(),
    MovingAveragePerChannelMinMaxObserver(qscheme=torch.per_channel_symmetric),
]



In [None]:
# Feed every sample batch through each observer, then print the qparams it derived.
# The inner loop variable is deliberately NOT named `x`: a notebook-level `for x in ...`
# would overwrite the demo tensor `x` defined earlier, silently changing the output of
# the per-tensor / per-channel cells if they are re-run afterwards.
# NOTE: observers keep their statistics between calls, so re-running this cell without
# re-running the previous one feeds the same data in a second time.
for obs in observers:
    for sample in input:
        obs(sample)
    print(obs.__class__.__name__, obs.calculate_qparams())


### QConfig

* High-level abstraction that wraps these knobs (observer, dtype, qscheme) in a single object
* Allows the activations and the weights of a layer to be configured separately

In [None]:
from torch.ao.quantization.observer import MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver
from torch.ao.quantization.qconfig import QConfig

# Custom QConfig: activations and weights each get their own observer factory.
# `with_args` only binds constructor arguments; calling the result
# (e.g. `my_qconfig.weight()`) builds a fresh observer.
my_qconfig = QConfig(
    # Activations: one scale/zero point per tensor, affine, unsigned 8-bit.
    activation=MovingAverageMinMaxObserver.with_args(
        qscheme=torch.per_tensor_affine,
        dtype=torch.quint8,
    ),
    # Weights: per-channel, symmetric around 0, signed 8-bit.
    # The observer's default dtype is quint8, but symmetric weights are quantized to
    # qint8 (as in PyTorch's own default_per_channel_qconfig), so set it explicitly.
    weight=MovingAveragePerChannelMinMaxObserver.with_args(
        qscheme=torch.per_channel_symmetric,
        dtype=torch.qint8,
    ),
)


#### Default QConfigs available out of the box

In [None]:
# Canonical namespace is torch.ao.quantization (torch.quantization is a legacy alias that
# re-exports the same objects), matching the imports used earlier in this notebook.
torch.ao.quantization.qconfig.default_per_channel_qconfig

In [None]:
# Default QConfig for dynamic quantization, taken from the canonical torch.ao namespace
# (torch.quantization is a legacy alias re-exporting the same object).
torch.ao.quantization.qconfig.default_dynamic_qconfig

In [None]:
# Per-channel variant of the dynamic-quantization QConfig, taken from the canonical
# torch.ao namespace (torch.quantization is a legacy alias re-exporting the same object).
torch.ao.quantization.qconfig.per_channel_dynamic_qconfig