In [3]:
import Ipynb_importer
from a_basic_quant import *

importing Jupyter notebook from a_basic_quant.ipynb


### 1、定义基本模型（无bn）

In [2]:
class Net(nn.Module):
    """Two-convolution CNN (no batch norm) with a manual quantization workflow.

    The float network is conv1 -> ReLU -> 2x2 max-pool -> conv2 (grouped,
    groups=20) -> ReLU -> 2x2 max-pool -> flatten -> fc. Both convolutions
    are 3x3, stride 1, without padding, and the flatten step expects
    5*5*40 values per sample (which a 28x28 input produces).

    Intended call order, as laid out by the methods below:

    1. ``forward``            - ordinary float inference.
    2. ``quantize_init``      - build the quantized wrapper modules
       (``quantize`` is an alias, matching the name used by ``NetBN``).
    3. ``quantize_forward``   - run data through the wrappers so that
       min/max statistics, scale and zero_point are collected.
    4. ``freeze``             - quantize weights and biases once the
       scale/zero_point values are known.
    5. ``quantize_inference`` - inference on quantized values.

    The ``Q*`` wrapper classes are not defined in this cell; presumably they
    come from the ``a_basic_quant`` star import - verify there.
    """

    # Values per sample after the second pooling stage: 5x5 positions times
    # 40 channels. Must equal the in_features of ``fc``.
    _FLAT_FEATURES = 5 * 5 * 40

    def __init__(self, num_channels=1):
        """Create the float layers.

        Args:
            num_channels: number of channels of the input images.
        """
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.conv2 = nn.Conv2d(40, 40, 3, 1, groups=20)
        self.fc = nn.Linear(self._FLAT_FEATURES, 10)

    def _flatten(self, x):
        """Reshape pooled feature maps to rows of ``_FLAT_FEATURES`` values.

        A bare ``view(-1, 5*5*40)`` silently re-partitions the batch whenever
        the total element count happens to be divisible by 1000 (e.g. a
        25x1x32x32 batch turns into 36 rows), mixing data across samples.
        Checking the trailing (channel, height, width) dims first turns that
        into an explicit error.

        Raises:
            RuntimeError: if one sample does not hold ``_FLAT_FEATURES`` values.
        """
        per_sample = x.shape[-3:].numel()
        if per_sample != self._FLAT_FEATURES:
            raise RuntimeError(
                "expected %d features per sample before the fc layer, got %d "
                "(tensor shape %s); check the input image size"
                % (self._FLAT_FEATURES, per_sample, tuple(x.shape))
            )
        return x.view(-1, self._FLAT_FEATURES)

    # Ordinary (float) model inference
    def forward(self, x):
        """Run the float network and return the output of ``fc``."""
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = self._flatten(x)
        x = self.fc(x)
        return x

    def quantize_init(self, num_bits=8):
        """Build the quantized wrapper modules for every layer.

        Initializes, according to ``num_bits``, the objects that store each
        layer's quantized parameters and activations. Only ``qconv1`` is
        created with ``has_qin=True``; ``qconv2`` and ``qfc`` are given their
        input quantizer later, in ``freeze``.

        Args:
            num_bits: bit width handed to the quantized conv/linear wrappers.
        """
        self.qconv1 = QConv2d(self.conv1, has_qin=True, has_qout=True, num_bits=num_bits)
        self.qrelu1 = QReLU()
        self.qmaxpool2d_1 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
        self.qconv2 = QConv2d(self.conv2, has_qin=False, has_qout=True, num_bits=num_bits)
        self.qrelu2 = QReLU()
        self.qmaxpool2d_2 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
        self.qfc = QLinear(self.fc, has_qin=False, has_qout=True, num_bits=num_bits)

    def quantize(self, num_bits=8):
        """Alias of ``quantize_init`` so Net and NetBN share an entry point."""
        self.quantize_init(num_bits=num_bits)

    def quantize_forward(self, x):
        """Forward pass through the quantized wrappers.

        Used to collect the min/max of parameters and activations and to
        compute scale and zero_point. Requires ``quantize_init`` first.
        """
        x = self.qconv1(x)
        x = self.qrelu1(x)
        x = self.qmaxpool2d_1(x)
        x = self.qconv2(x)
        x = self.qrelu2(x)
        x = self.qmaxpool2d_2(x)
        x = self._flatten(x)
        x = self.qfc(x)
        return x

    def freeze(self):
        """Quantize weights and biases once scale and zero_point are known.

        ReLU and max-pool are frozen with the output quantizer of the
        convolution that precedes them; ``qconv2`` and ``qfc`` receive the
        previous convolution's ``q_out`` as their ``q_in``.
        """
        self.qconv1.freeze()
        self.qrelu1.freeze(self.qconv1.q_out)
        self.qmaxpool2d_1.freeze(self.qconv1.q_out)
        self.qconv2.freeze(q_in=self.qconv1.q_out)
        self.qrelu2.freeze(self.qconv2.q_out)
        self.qmaxpool2d_2.freeze(self.qconv2.q_out)
        self.qfc.freeze(q_in=self.qconv2.q_out)

    def quantize_inference(self, x):
        """Quantized inference: float input in, float output out.

        The input is quantized with ``qconv1.q_in``, pushed through each
        wrapper's ``quantize_inference`` and finally dequantized with
        ``qfc.q_out``. Requires ``quantize_init`` and ``freeze`` first.
        """
        qx = self.qconv1.q_in.quantize_tensor(x)  # quantize the input
        qx = self.qconv1.quantize_inference(qx)
        qx = self.qrelu1.quantize_inference(qx)
        qx = self.qmaxpool2d_1.quantize_inference(qx)
        qx = self.qconv2.quantize_inference(qx)
        qx = self.qrelu2.quantize_inference(qx)
        qx = self.qmaxpool2d_2.quantize_inference(qx)
        qx = self._flatten(qx)
        qx = self.qfc.quantize_inference(qx)
        out = self.qfc.q_out.dequantize_tensor(qx)  # dequantize the output to float
        return out


### 2、定义基本模型（bn）

In [3]:
class NetBN(nn.Module):
    """Two-convolution CNN with batch norm and a manual quantization workflow.

    The float network is conv1 -> bn1 -> ReLU -> 2x2 max-pool -> conv2 ->
    bn2 -> ReLU -> 2x2 max-pool -> flatten -> fc. Both convolutions are 3x3,
    stride 1, without padding, and the flatten step expects 5*5*40 values
    per sample (which a 28x28 input produces).

    In the quantized path each conv/bn/ReLU triple is handled by a single
    ``QConvBNReLU`` wrapper. Intended call order, as laid out by the methods:
    ``quantize`` (alias ``quantize_init``, the name used by ``Net``) ->
    ``quantize_forward`` -> ``freeze`` -> ``quantize_inference``.

    The ``Q*`` wrapper classes are not defined in this cell; presumably they
    come from the ``a_basic_quant`` star import - verify there.
    """

    # Values per sample after the second pooling stage: 5x5 positions times
    # 40 channels. Must equal the in_features of ``fc``.
    _FLAT_FEATURES = 5 * 5 * 40

    def __init__(self, num_channels=1):
        """Create the float layers.

        Args:
            num_channels: number of channels of the input images.
        """
        super(NetBN, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.bn1 = nn.BatchNorm2d(40)
        self.conv2 = nn.Conv2d(40, 40, 3, 1)
        self.bn2 = nn.BatchNorm2d(40)
        self.fc = nn.Linear(self._FLAT_FEATURES, 10)

    def _flatten(self, x):
        """Reshape pooled feature maps to rows of ``_FLAT_FEATURES`` values.

        A bare ``view(-1, 5*5*40)`` silently re-partitions the batch whenever
        the total element count happens to be divisible by 1000 (e.g. a
        25x1x32x32 batch turns into 36 rows), mixing data across samples.
        Checking the trailing (channel, height, width) dims first turns that
        into an explicit error.

        Raises:
            RuntimeError: if one sample does not hold ``_FLAT_FEATURES`` values.
        """
        per_sample = x.shape[-3:].numel()
        if per_sample != self._FLAT_FEATURES:
            raise RuntimeError(
                "expected %d features per sample before the fc layer, got %d "
                "(tensor shape %s); check the input image size"
                % (self._FLAT_FEATURES, per_sample, tuple(x.shape))
            )
        return x.view(-1, self._FLAT_FEATURES)

    def forward(self, x):
        """Run the float network and return the output of ``fc``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self._flatten(x)
        x = self.fc(x)
        return x

    def quantize(self, num_bits=8):
        """Build the quantized wrapper modules for every layer.

        Only ``qconv1`` is created with ``has_qin=True``; ``qconv2`` and
        ``qfc`` are given their input quantizer later, in ``freeze``.

        Args:
            num_bits: bit width handed to the quantized conv/linear wrappers.
        """
        self.qconv1 = QConvBNReLU(self.conv1, self.bn1, has_qin=True, has_qout=True, num_bits=num_bits)
        self.qmaxpool2d_1 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
        self.qconv2 = QConvBNReLU(self.conv2, self.bn2, has_qin=False, has_qout=True, num_bits=num_bits)
        self.qmaxpool2d_2 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
        self.qfc = QLinear(self.fc, has_qin=False, has_qout=True, num_bits=num_bits)

    def quantize_init(self, num_bits=8):
        """Alias of ``quantize`` so NetBN and Net share an entry point."""
        self.quantize(num_bits=num_bits)

    def quantize_forward(self, x):
        """Forward pass through the quantized wrappers.

        Presumably the calibration pass that gathers the statistics
        ``freeze`` relies on - confirm against the wrapper implementations.
        Requires ``quantize`` to have been called first.
        """
        x = self.qconv1(x)
        x = self.qmaxpool2d_1(x)
        x = self.qconv2(x)
        x = self.qmaxpool2d_2(x)
        x = self._flatten(x)
        x = self.qfc(x)
        return x

    def freeze(self):
        """Freeze every quantized wrapper.

        Each max-pool is frozen with the output quantizer of the conv block
        that precedes it; ``qconv2`` and ``qfc`` receive the previous conv
        block's ``q_out`` as their ``q_in``.
        """
        self.qconv1.freeze()
        self.qmaxpool2d_1.freeze(self.qconv1.q_out)
        self.qconv2.freeze(q_in=self.qconv1.q_out)
        self.qmaxpool2d_2.freeze(self.qconv2.q_out)
        self.qfc.freeze(q_in=self.qconv2.q_out)

    def quantize_inference(self, x):
        """Quantized inference: quantize ``x``, run the wrappers, dequantize.

        The input is quantized with ``qconv1.q_in`` and the final result is
        dequantized with ``qfc.q_out``. Requires ``quantize`` and ``freeze``
        to have been called first.
        """
        qx = self.qconv1.q_in.quantize_tensor(x)
        qx = self.qconv1.quantize_inference(qx)
        qx = self.qmaxpool2d_1.quantize_inference(qx)
        qx = self.qconv2.quantize_inference(qx)
        qx = self.qmaxpool2d_2.quantize_inference(qx)
        qx = self._flatten(qx)

        qx = self.qfc.quantize_inference(qx)

        out = self.qfc.q_out.dequantize_tensor(qx)
        return out
