# Weight quantization

Some (neuromorphic) hardware platforms only support limited numerical precision. To approximate these constraints in simulation, PyTorch provides methods to [quantize](https://pytorch.org/docs/stable/quantization.html) network parameters. That is, we artificially constrain parameters to specific number types or value ranges, such as 8-bit integers.
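
Here is a concrete example of what "constraining parameters" means. PyTorch's per-tensor affine scheme stores a float `x` as an integer `q = clamp(round(x / scale) + zero_point, q_min, q_max)` and computes with the dequantized value `(q - zero_point) * scale`. The sketch below applies `torch.quantize_per_tensor` to a few made-up weights. The scale and zero point are arbitrary values chosen for illustration, not what a calibrated model would use.

```python
import torch

# A few made-up weights; 2.0 lies outside the range representable below
w = torch.tensor([-0.51, 0.0, 0.26, 2.0])

# Hypothetical quantization parameters, chosen only for illustration:
# quint8 has 256 levels (0..255); scale 0.01 and zero point 128
# cover roughly the float range [-1.28, 1.27]
w_q = torch.quantize_per_tensor(w, scale=0.01, zero_point=128, dtype=torch.quint8)

print(w_q.int_repr())    # the stored 8-bit integers
print(w_q.dequantize())  # the float values the network effectively computes with
```

Values inside the representable range survive up to a rounding error of at most `scale / 2`. The out-of-range weight `2.0` is clamped to the largest representable value, `(255 - 128) * 0.01 = 1.27`.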

Please see the [PyTorch blog post on quantization](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/) for a proper introduction to quantization. Note that this tutorial covers eager mode quantization, *not* FX graph mode quantization.

In summary, three types of quantization exist (as described in the PyTorch documentation):

* dynamic quantization (weights quantized, activations read/stored in floating point and quantized for compute)
* static quantization (weights quantized, activations quantized, calibration required post training)
* quantization aware training (weights quantized, activations quantized, quantization numerics modeled during training)

## 1. Dynamic quantization

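Dynamic quantization is applied by calling `quantize_dynamic` from the `torch.quantization` module. We pass the set of layer types to quantize. Only types that have a quantized counterpart, such as `torch.nn.Linear`, are actually replaced, and all other layers keep running in floating point.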

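```python
import torch
import norse.torch as norse  # Norse's PyTorch API (LIFCell, LIFFeedForwardState, ...)


class SNNModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)  # dense layer: a candidate for quantization
        self.lif = norse.LIFCell()        # leaky integrate-and-fire neurons

    def forward(self, x, s):
        # x: input tensor, s: LIF state (None on the first step)
        return self.lif(self.lin(x), s)
```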

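```python
# create a model instance
model_fp32 = SNNModel()

# create a dynamically quantized copy of the model
# (LIFCell has no quantized counterpart, so listing it here would leave
# the model unchanged; the Linear layer is what gets quantized)
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,  # the original model
    {torch.nn.Linear},  # the layer types to quantize dynamically
    dtype=torch.qint8)  # the target dtype for quantized weights

# apply the model; a state of None lets the LIF cell initialize its own state
input_fp32 = torch.randn(4, 4, 4, 4)
z, s = model_int8(input_fp32, None)
```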
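
Because only supported module types are swapped, it is worth checking what `quantize_dynamic` actually changed. The following sketch uses a plain `torch.nn.Sequential` with one `Linear` layer as a hypothetical stand-in for `SNNModel`, so it runs without Norse, and then inspects the converted layer:

```python
import torch

# Hypothetical stand-in for the SNN model: a single Linear layer
model_fp32 = torch.nn.Sequential(torch.nn.Linear(4, 4))

# Returns a quantized copy; model_fp32 itself is left untouched
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32, {torch.nn.Linear}, dtype=torch.qint8
)

print(type(model_int8[0]))           # a dynamically quantized Linear module
print(model_int8[0].weight().dtype)  # weights are stored as int8

# Inputs and outputs stay in floating point; activations are quantized on the fly
x = torch.randn(4, 4)
out_fp32 = model_fp32(x)
out_int8 = model_int8(x)
print((out_fp32 - out_int8).abs().max())  # small quantization error
```

Note that the quantized layer exposes its weight through the `weight()` method rather than a `weight` attribute.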

## 2. Static quantization

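Static quantization *requires calibration* to achieve decent accuracy: the prepared model is run on representative data while observers record the value range of each activation, and these ranges determine the quantization parameters (scale and zero point). We also mark where tensors enter and leave the quantized part of the model with `QuantStub` and `DeQuantStub`.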
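
Calibration comes down to observers tracking value ranges. The sketch below feeds two made-up batches through a standalone `MinMaxObserver`, which is one of the observers PyTorch ships, and derives the quantization parameters from the range it saw. The default qconfig used later may pick a different observer type, but the principle is the same.

```python
import torch
from torch.quantization import MinMaxObserver

# An observer records the running min/max of every tensor passed through it
obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)

# "Calibration": feed a few representative batches (made-up values here)
for batch in [torch.tensor([-1.0, 0.5]), torch.tensor([2.0, 3.0])]:
    obs(batch)

# Derive scale and zero point from the observed range [-1, 3]
scale, zero_point = obs.calculate_qparams()
print(scale, zero_point)
```

For the observed range `[-1, 3]` spread over the 256 levels of `quint8`, this gives a scale of `4 / 255` (about 0.0157) and a zero point of 64, so that the float value 0 is represented exactly.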

```python
class SNNModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.lin = torch.nn.Linear(4, 4)
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()
        self.lif = norse.LIFCell()

    def forward(self, x, s):
        # Quantize the input so that the linear layer computes in int8
        x = self.lin(self.quant(x))
        # Dequantize before the LIF cell: eager mode quantization has no quantized
        # LIF implementation, and its state updates (additions, multiplications)
        # are not supported on quantized tensors, so the cell and its state
        # stay in floating point
        z, s = self.lif(self.dequant(x), s)
        return z, s
```

```python
# Create the input and an initial LIF state
input_fp32 = torch.randn(4, 1, 4, 4)
input_state = norse.LIFFeedForwardState(torch.randn(4, 4), torch.randn(4, 4))

# Static quantization needs the model in eval mode and a quantization config;
# 'fbgemm' targets x86 CPUs ('qnnpack' targets ARM)
model_fp32 = SNNModel()
model_fp32.eval()
model_fp32.qconfig = torch.quantization.get_default_qconfig("fbgemm")

# Insert observers and calibrate them by running representative data
model_fp32_prepared = torch.quantization.prepare(model_fp32)
model_fp32_prepared(input_fp32, input_state)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and zero point to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
model_int8 = torch.quantization.convert(model_fp32_prepared)

# run the model; the quantized linear layer now computes in int8
res = model_int8(input_fp32, input_state)
```
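
To verify which modules the prepare, calibrate and convert steps actually replaced, the same pipeline can be run on a small Norse-free model. In the sketch below, `TinyModel` is a hypothetical stand-in for `SNNModel` (quantize, `Linear`, dequantize), the calibration data is random, and the backend is chosen from whatever engines the local PyTorch build supports.

```python
import torch


class TinyModel(torch.nn.Module):
    # Hypothetical stand-in for SNNModel: quant -> Linear -> dequant
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.lin = torch.nn.Linear(4, 4)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.lin(self.quant(x)))


# Pick a quantization backend available on this machine
engine = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = engine

model = TinyModel().eval()
model.qconfig = torch.quantization.get_default_qconfig(engine)

prepared = torch.quantization.prepare(model)
for _ in range(8):  # calibrate on random data (a placeholder for real inputs)
    prepared(torch.randn(16, 4))

model_int8 = torch.quantization.convert(prepared)
print(model_int8)                      # lin is now a quantized Linear with scale/zero_point
print(model_int8.lin.weight().dtype)   # weights are stored as int8
out = model_int8(torch.randn(2, 4))
print(out.dtype)                       # float again after DeQuantStub
```

Printing the converted model is a quick way to spot layers that silently stayed in floating point, for example because no quantized counterpart exists.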

```python
res[0]  # output spikes z
```

```
tensor([[0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 0., 0.]])
```

## 3. Quantization aware training

... TBD