In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from torch import nn
from torch.ao import quantization as q

In [3]:
zero_point = 0
quant_min = -128
quant_max = 127

x = torch.rand(5)
x.requires_grad = True
print(f"x: {x}")

mn = x.min()
mx = x.max()
scale = ((mx - mn) / (quant_max - quant_min)).item()

print(f"mn: {mn}")
print(f"mx: {mx}")
print(f"scale: {scale}")


print("Forward")
x_fq = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, quant_min, quant_max)
print(f"Input:  {x.detach().numpy()}")
print(f"Output: {x_fq.detach().numpy()}")

print("Backward")
loss = x_fq.sum()
loss.backward()

print(f"grad: {x.grad}")

x: tensor([0.2069, 0.1993, 0.8558, 0.7171, 0.0387], requires_grad=True)
mn: 0.038726806640625
mx: 0.8558149337768555
scale: 0.0032042672391980886
Forward
Input:  [0.20688236 0.19932824 0.85581493 0.7171295  0.03872681]
Output: [0.20827737 0.19866458 0.40694195 0.40694195 0.03845121]
Backward
grad: tensor([1., 1., 0., 0., 1.])


In [13]:
class MyFakeQuantize(nn.Module):
    def __init__(self, observer, quant_min=-128, quant_max=127):
        super().__init__()
        self.observer = observer
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.scale = None
        self.zero_point = None

    def forward(self, x):
        if self.training:
            self.observer(x)
            self.scale, self.zero_point = self.observer.calculate_qparams()

            x_fq = torch.fake_quantize_per_tensor_affine(
                x,
                scale=self.scale,
                zero_point=self.zero_point,
                quant_min=self.quant_min,
                quant_max=self.quant_max,
            )

            return x_fq
        else:
            return torch.fake_quantize_per_tensor_affine(
                x, self.scale, self.zero_point, self.quant_min, self.quant_max
            )


In [14]:
class QuantConv2d(nn.Module):    
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        bias=True,
        weight_quant_min=-128,
        weight_quant_max=127,
        activation_quant_min=0,
        activation_quant_max=255,
    ):
        super().__init__()
        
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        
        self.weight_fake_quant = MyFakeQuantize(
            observer=q.MinMaxObserver(
                dtype=torch.qint8,
                qscheme=torch.per_tensor_symmetric,
            ),
            quant_min=weight_quant_min,
            quant_max=weight_quant_max,
        )
        
        self.activation_fake_quant = MyFakeQuantize(
            observer=q.MinMaxObserver(
                dtype=torch.quint8,
                qscheme=torch.per_tensor_affine,
            ),
            quant_min=activation_quant_min,
            quant_max=activation_quant_max,
        )
        
    def forward(self, x):
        x_quantized = self.activation_fake_quant(x)
        
        weight_quantized = self.weight_fake_quant(self.conv.weight)
        
        output = torch.nn.functional.conv2d(
            x_quantized,
            weight_quantized,
            bias=self.conv.bias,
            stride=self.conv.stride,
            padding=self.conv.padding,
        )
        
        return output

In [None]:
qconv = QuantConv2d(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    padding=1,
)

x_test = torch.randn(1, 3, 32, 32)

qconv.train()
output_train = qconv(x_test)
print(f"Input shape: {x_test.shape}")
print(f"Output shape: {output_train.shape}")
print(f"Weight scale: {qconv.weight_fake_quant.scale}")
print(f"Weight zero_point: {qconv.weight_fake_quant.zero_point}")
print(f"Activation scale: {qconv.activation_fake_quant.scale}")
print(f"Activation zero_point: {qconv.activation_fake_quant.zero_point}")

Input shape: torch.Size([1, 3, 32, 32])
Output shape: torch.Size([1, 16, 32, 32])
Weight scale: tensor([0.0015])
Weight zero_point: tensor([0])
Activation scale: tensor([0.0285])
Activation zero_point: tensor([125], dtype=torch.int32)
