# 1、训练后量化（Post-Training Quantization, PTQ）
## 1.1 线性量化
https://zhuanlan.zhihu.com/p/156835141

scale:
> s = (rmax - rmin) / (qmax - qmin)

zero point:
> z = round(qmax - rmax / s)

### 1. 量化基本公式

In [1]:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Basic formulas for computing scale and zero_point
def calcu_scale_and_zeropoint(min_val, max_val, num_bits=8):
    """Compute unsigned affine quantization parameters for [min_val, max_val].

    Uses s = (max - min) / (q_max - q_min) and z = round(q_max - max / s),
    with q_min = 0 and q_max = 2**num_bits - 1; z is clipped to that range.

    Args:
        min_val: lower end of the real range (Python number or scalar tensor).
        max_val: upper end of the real range (Python number or scalar tensor).
        num_bits: bit width of the unsigned integer grid.

    Returns:
        (scale, zero_point): ``scale`` is a Python float, ``zero_point`` an
        integer-valued NumPy float produced by ``np.clip``.
    """
    q_min = 0.
    q_max = 2. ** num_bits - 1
    scale = float((max_val - min_val) / (q_max - q_min))
    if scale == 0.0:
        # Degenerate range (e.g. all-zero statistics): any scale represents it
        # exactly, so pick 1.0 rather than dividing by zero below.
        scale = 1.0
    # Round to nearest as the formula states; int() would truncate toward zero.
    zero_point = np.clip(round(float(q_max - max_val / scale)), q_min, q_max)

    return scale, zero_point

In [3]:
# Tensor quantization and dequantization
def quantize_tensor(x, scale, zero_point, num_bits=8, signed=False):
    """Quantize a real tensor: q = clamp(round(x / scale + zero_point)).

    Args:
        x: real-valued tensor.
        scale: quantization scale S.
        zero_point: quantization zero point Z.
        num_bits: bit width of the integer grid.
        signed: if True clamp to [-2**(b-1), 2**(b-1) - 1], else to [0, 2**b - 1].

    Returns:
        A float32 tensor holding integer values. Per the original author's
        note, integers are carried as floats because the PyTorch ops used
        downstream are run on float tensors here.
    """
    if signed:
        q_min = - 2. ** (num_bits - 1)
        q_max = 2. ** (num_bits - 1) - 1.
    else:
        q_min = 0.
        q_max = 2. ** num_bits - 1.

    q_x = x / scale + zero_point  # new tensor; the caller's x is not modified
    # Both ops must be in place: a plain .round() here returns a new tensor
    # that would be thrown away, leaving q_x unrounded.
    q_x.clamp_(q_min, q_max).round_()  # q = round(r / S + Z)

    return q_x.float()

def dequantize_tensor(q_x, scale, zero_point):
    """Map quantized values back to reals: r = S * (q - Z)."""
    offset = q_x - zero_point
    return scale * offset

### 2. 封装成类

In [4]:
class QParam:
    """Running range tracker plus the affine quantization parameters derived from it.

    ``update`` widens a running [min, max] range, which always contains 0,
    and recomputes ``scale`` / ``zero_point`` from it.
    ``quantize_tensor`` / ``dequantize_tensor`` apply those parameters.
    """

    def __init__(self, num_bits=8):
        self.num_bits = num_bits
        # Affine parameters; None until the first update().
        self.scale = None
        self.zero_point = None
        # Running range observed so far; None until the first update().
        self.min = None
        self.max = None

    def update(self, tensor):
        """Fold *tensor*'s min/max into the running range and refresh scale / zero_point."""
        batch_max, batch_min = tensor.max(), tensor.min()
        # 0 is a candidate on both sides, so the range always includes 0.
        high = [0, batch_max] if self.max is None else [0, self.max, batch_max]
        low = [0, batch_min] if self.min is None else [0, self.min, batch_min]
        self.max = max(high)
        self.min = min(low)
        self.scale, self.zero_point = calcu_scale_and_zeropoint(self.min, self.max, self.num_bits)

    def quantize_tensor(self, tensor):
        """Quantize *tensor* with this object's scale, zero point and bit width."""
        scale, zero_point = self.scale, self.zero_point
        return quantize_tensor(tensor, scale, zero_point, num_bits=self.num_bits)

    def dequantize_tensor(self, q_x):
        """Map quantized values *q_x* back to reals with this object's scale / zero point."""
        scale, zero_point = self.scale, self.zero_point
        return dequantize_tensor(q_x, scale, zero_point)

### 3、量化网络基类定义

In [5]:
class QModule(nn.Module):
    """Base class for the quantization wrappers.

    Besides the bit width, a subclass chooses whether it owns an input
    QParam (``q_in``) and/or an output QParam (``q_out``). Not every module
    needs its own input statistics: most intermediate layers reuse the
    previous layer's ``q_out`` as their ``q_in``. Some activation layers
    reuse the previous layer's quantization parameters directly.
    """

    def __init__(self, has_qin=True, has_qout=True, num_bits=8):
        super().__init__()
        for wanted, attr_name in ((has_qin, 'q_in'), (has_qout, 'q_out')):
            if wanted:
                setattr(self, attr_name, QParam(num_bits))

    def freeze(self):
        """Hook run after min/max statistics have been collected.

        Subclasses precompute the terms that stay fixed at inference time and
        convert their weights from floats to integers. The base version does nothing.
        """

    def quantize_inference(self, x):
        """Integer-domain forward pass; every subclass must implement it."""
        raise NotImplementedError('quantize_inference should be implemented.')

### 4、量化卷积层类的定义
<img src="https://www.zhihu.com/equation?tex=a%3D%5Csum_%7Bi%7D%5EN+w_i+x_i%2Bb+%5Ctag%7B1%7D+" alt="[公式]" style="zoom:80%;" />
由此得到量化的公式
<img src="https://www.zhihu.com/equation?tex=S_a+%28q_a-Z_a%29%3D%5Csum_%7Bi%7D%5EN+S_w%28q_w-Z_w%29S_x%28q_x-Z_x%29%2BS_b%28q_b-Z_b%29+%5Ctag%7B2%7D+" alt="[公式]" style="zoom:80%;" />
<img src="https://www.zhihu.com/equation?tex=q_a%3D%5Cfrac%7BS_w+S_x%7D%7BS_a%7D%5Csum_%7Bi%7D%5EN+%28q_w-Z_w%29%28q_x-Z_x%29%2B%5Cfrac%7BS_b%7D%7BS_a%7D%28q_b-Z_b%29%2BZ_a+%5Ctag%7B3%7D+" alt="[公式]" style="zoom:80%;" />


> <img src="./image.png" style="zoom:60%;" />

令 S_b = S_w S_x、Z_b = 0（即偏置以 S_w S_x 为 scale、零点为 0 量化为 32 位有符号整数），经过调整：

<img src="https://www.zhihu.com/equation?tex=%5Cbegin%7Balign%7D+q_a%26%3D%5Cfrac%7BS_w+S_x%7D%7BS_a%7D%28%5Csum_%7Bi%7D%5EN%28q_w-Z_w%29%28q_x-Z_x%29%2Bq_b%29%2BZ_a+%5Cnotag+%5C%5C+%26%3DM%28%5Csum_%7Bi%7D%5EN+q_wq_x-%5Csum_i%5EN+q_wZ_x-%5Csum_i%5EN+q_xZ_w%2B%5Csum_i%5ENZ_wZ_x%2Bq_b%29%2BZ_a+%5Ctag%7B4%7D+%5Cend%7Balign%7D+" alt="[公式]" style="zoom:80%;" />

In [6]:
class QConv2d(QModule):
    """Post-training-quantization wrapper around an ``nn.Conv2d``.

    ``forward`` (calibration) records input/output/weight ranges and runs the
    conv with fake-quantized weights. ``freeze`` converts the wrapped conv's
    weights and bias to integers. ``quantize_inference`` then evaluates
    eq. (4) in the integer domain, with integers stored as floats.
    """

    def __init__(self, conv_module, has_qin=True, has_qout=True, num_bits=8):
        super(QConv2d, self).__init__(has_qin, has_qout, num_bits)
        self.num_bits = num_bits
        self.conv_module = conv_module
        self.q_weight = QParam(num_bits=num_bits)

    def freeze(self, q_in=None, q_out=None):
        """Fix the integer parameters after calibration.

        Args:
            q_in: input QParam to share (e.g. the previous layer's ``q_out``);
                required exactly when this layer was built with ``has_qin=False``.
            q_out: output QParam to share; required exactly when built with
                ``has_qout=False``.

        Raises:
            ValueError: if a QParam is both owned and passed in, or neither.
        """
        if hasattr(self, 'q_in') and q_in is not None:
            raise ValueError('q_in has been provided in init function.')
        if not hasattr(self, 'q_in') and q_in is None:
            raise ValueError('q_in is not existed, should be provided.')

        if hasattr(self, 'q_out') and q_out is not None:
            raise ValueError('q_out has been provided in init function.')
        if not hasattr(self, 'q_out') and q_out is None:
            raise ValueError('q_out is not existed, should be provided.')

        if q_in is not None:
            self.q_in = q_in
        if q_out is not None:
            self.q_out = q_out

        # M = S_w * S_in / S_out, the rescale factor of eq. (4).
        self.M = self.q_weight.scale * self.q_in.scale / self.q_out.scale

        # Store q_w - Z_w so inference can convolve with zero-centred integer weights.
        self.conv_module.weight.data = self.q_weight.quantize_tensor(self.conv_module.weight.data) \
                                        - self.q_weight.zero_point
        # Bias uses S_b = S_in * S_w and Z_b = 0 (32-bit signed), so it adds
        # directly to the integer accumulator. A conv built with bias=False has none.
        if self.conv_module.bias is not None:
            self.conv_module.bias.data = quantize_tensor(self.conv_module.bias.data,
                                                         scale=self.q_in.scale * self.q_weight.scale,
                                                         zero_point=0, num_bits=32, signed=True)

    def forward(self, x):
        """Calibration pass: update range statistics and return the conv output.

        The output is computed with fake-quantized (quantize-then-dequantize)
        weights. The wrapped module's float weights are left untouched, so
        repeated calibration passes always observe the original weights.
        The input is only observed, not fake-quantized.
        """
        if hasattr(self, 'q_in'):
            self.q_in.update(x)

        weight = self.conv_module.weight.data
        self.q_weight.update(weight)

        # Fake-quantization node, applied to a copy of the weights.
        fq_weight = self.q_weight.dequantize_tensor(self.q_weight.quantize_tensor(weight))

        conv = self.conv_module
        # NOTE: F.conv2d always zero-pads, i.e. this assumes padding_mode='zeros'
        # (the nn.Conv2d default, and what this file's models use).
        x = F.conv2d(x, fq_weight, conv.bias, stride=conv.stride, padding=conv.padding,
                     dilation=conv.dilation, groups=conv.groups)

        if hasattr(self, 'q_out'):
            self.q_out.update(x)

        return x

    def quantize_inference(self, x):
        """Integer-domain conv (eq. 4): subtract Z_in, integer conv, rescale by M, add Z_out, clamp."""
        x = x - self.q_in.zero_point
        x = self.conv_module(x)
        x = self.M * x
        x.round_()
        x = x + self.q_out.zero_point
        x.clamp_(0., 2. ** self.num_bits - 1.).round_()
        return x

### 5、量化线性层类的定义

In [7]:
class QLinear(QModule):
    """Post-training-quantization wrapper around an ``nn.Linear``.

    ``forward`` (calibration) records input/output/weight ranges and runs the
    layer with fake-quantized weights. ``freeze`` converts the wrapped layer's
    weights and bias to integers. ``quantize_inference`` then evaluates
    eq. (4) in the integer domain, with integers stored as floats.
    """

    def __init__(self, fc_module, has_qin=True, has_qout=True, num_bits=8):
        super(QLinear, self).__init__(has_qin, has_qout, num_bits)
        self.num_bits = num_bits
        self.fc_module = fc_module
        self.q_weight = QParam(num_bits=num_bits)

    def freeze(self, q_in=None, q_out=None):
        """Fix the integer parameters after calibration.

        Args:
            q_in: input QParam to share (e.g. the previous layer's ``q_out``);
                required exactly when this layer was built with ``has_qin=False``.
            q_out: output QParam to share; required exactly when built with
                ``has_qout=False``.

        Raises:
            ValueError: if a QParam is both owned and passed in, or neither.
        """
        if hasattr(self, 'q_in') and q_in is not None:
            raise ValueError('q_in has been provided in init function.')
        if not hasattr(self, 'q_in') and q_in is None:
            raise ValueError('q_in is not existed, should be provided.')

        if hasattr(self, 'q_out') and q_out is not None:
            raise ValueError('q_out has been provided in init function.')
        if not hasattr(self, 'q_out') and q_out is None:
            raise ValueError('q_out is not existed, should be provided.')

        if q_in is not None:
            self.q_in = q_in
        if q_out is not None:
            self.q_out = q_out

        # M = S_w * S_in / S_out, the rescale factor of eq. (4).
        self.M = self.q_weight.scale * self.q_in.scale / self.q_out.scale

        # Store q_w - Z_w so inference can multiply by zero-centred integer weights.
        self.fc_module.weight.data = self.q_weight.quantize_tensor(self.fc_module.weight.data) \
                                        - self.q_weight.zero_point
        # Bias uses S_b = S_in * S_w and Z_b = 0 (32-bit signed), so it adds
        # directly to the integer accumulator. A layer built with bias=False has none.
        if self.fc_module.bias is not None:
            self.fc_module.bias.data = quantize_tensor(self.fc_module.bias.data,
                                                       scale=self.q_in.scale * self.q_weight.scale,
                                                       zero_point=0, num_bits=32, signed=True)

    def forward(self, x):
        """Calibration pass: update range statistics and return the linear output.

        The output is computed with fake-quantized (quantize-then-dequantize)
        weights. The wrapped module's float weights are left untouched, so
        repeated calibration passes always observe the original weights.
        The input is only observed, not fake-quantized.
        """
        if hasattr(self, 'q_in'):
            self.q_in.update(x)

        weight = self.fc_module.weight.data
        self.q_weight.update(weight)

        # Fake-quantization node, applied to a copy of the weights.
        fq_weight = self.q_weight.dequantize_tensor(self.q_weight.quantize_tensor(weight))

        x = F.linear(x, fq_weight, self.fc_module.bias)

        if hasattr(self, 'q_out'):
            self.q_out.update(x)

        return x

    def quantize_inference(self, x):
        """Integer-domain linear (eq. 4): subtract Z_in, integer matmul, rescale by M, add Z_out, clamp."""
        x = x - self.q_in.zero_point
        x = self.fc_module(x)
        x = self.M * x
        x.round_()
        x = x + self.q_out.zero_point
        x.clamp_(0., 2. ** self.num_bits - 1.).round_()
        return x

### 6、量化 ReLU 类的定义

In [8]:
class QReLU(QModule):
    """ReLU wrapper for post-training quantization.

    In the integer domain ReLU clamps below at the input's zero point, so the
    output stays in the input's scale / zero point. This layer therefore
    owns no ``q_out``. Its ``q_in`` is normally shared from the preceding
    layer via ``freeze``.
    """

    def __init__(self, has_qin=False, num_bits=None):
        # has_qout=False: the base class would otherwise create a q_out that
        # this layer never updates or reads. num_bits=None falls back to 8 bits
        # so has_qin=True yields a QParam that can actually be calibrated.
        super(QReLU, self).__init__(has_qin=has_qin, has_qout=False,
                                    num_bits=8 if num_bits is None else num_bits)

    def freeze(self, q_in=None):
        """Attach the input QParam after calibration.

        Raises:
            ValueError: if a q_in is both owned and passed in, or neither.
        """
        if hasattr(self, 'q_in') and q_in is not None:
            raise ValueError('q_in has been provided in init function.')
        if not hasattr(self, 'q_in') and q_in is None:
            raise ValueError('q_in is not existed, should be provided.')

        if q_in is not None:
            self.q_in = q_in

    def forward(self, x):
        """Calibration pass: observe the input range (if owned) and apply float ReLU."""
        if hasattr(self, 'q_in'):
            self.q_in.update(x)

        x = F.relu(x)

        return x

    def quantize_inference(self, x):
        """Integer-domain ReLU: real 0 maps to q_in.zero_point, so clamp below at it."""
        x = x.clone()
        x[x < self.q_in.zero_point] = self.q_in.zero_point
        return x

### 7、量化最大池化层类的定义

In [9]:
class QMaxPooling2d(QModule):
    """2-D max-pooling wrapper for post-training quantization.

    Max pooling only selects existing values, so its output stays in the
    input's scale / zero point. This layer therefore owns no ``q_out``. Its
    ``q_in`` is normally shared from the preceding layer via ``freeze``.
    """

    def __init__(self, kernel_size=3, stride=1, padding=0, has_qin=False, num_bits=None):
        # has_qout=False: the base class would otherwise create a q_out that
        # this layer never updates or reads. num_bits=None falls back to 8 bits
        # so has_qin=True yields a QParam that can actually be calibrated.
        super(QMaxPooling2d, self).__init__(has_qin=has_qin, has_qout=False,
                                            num_bits=8 if num_bits is None else num_bits)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def freeze(self, q_in=None):
        """Attach the input QParam after calibration.

        Raises:
            ValueError: if a q_in is both owned and passed in, or neither.
        """
        if hasattr(self, 'q_in') and q_in is not None:
            raise ValueError('q_in has been provided in init function.')
        if not hasattr(self, 'q_in') and q_in is None:
            raise ValueError('q_in is not existed, should be provided.')

        if q_in is not None:
            self.q_in = q_in

    def forward(self, x):
        """Calibration pass: observe the input range (if owned) and max-pool."""
        if hasattr(self, 'q_in'):
            self.q_in.update(x)

        x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding)

        return x

    def quantize_inference(self, x):
        """Max pooling works unchanged on quantized values (integers stored as floats)."""
        return F.max_pool2d(x, self.kernel_size, self.stride, self.padding)

In [10]:
class Net(nn.Module):
    """Small CNN: two (conv -> ReLU -> 2x2 max-pool) stages and a linear head.

    Post-training quantization workflow: ``quantize()`` wraps the layers,
    ``quantize_forward()`` collects min/max statistics on calibration data,
    ``freeze()`` fixes the integer parameters, and ``quantize_inference()``
    runs the integer-domain forward pass. The flatten size 5*5*40 presumably
    assumes 28x28 inputs (MNIST) — TODO confirm for other input sizes.
    """

    def __init__(self, num_channels=1):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.conv2 = nn.Conv2d(40, 40, 3, 1, groups=20)
        self.fc = nn.Linear(5 * 5 * 40, 10)

    def forward(self, x):
        """Float forward pass."""
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)
        flat = x.view(-1, 5 * 5 * 40)
        return self.fc(flat)

    def quantize(self, num_bits=8):
        """Wrap every layer in its quantization counterpart.

        Only ``qconv1`` observes its own input range. The other layers get
        their input QParam from the preceding conv's ``q_out`` in ``freeze``.
        """
        self.qconv1 = QConv2d(self.conv1, True, True, num_bits)
        self.qrelu1 = QReLU()
        self.qmaxpool2d_1 = QMaxPooling2d(2, 2, 0)
        self.qconv2 = QConv2d(self.conv2, False, True, num_bits)
        self.qrelu2 = QReLU()
        self.qmaxpool2d_2 = QMaxPooling2d(2, 2, 0)
        self.qfc = QLinear(self.fc, False, True, num_bits)

    def quantize_forward(self, x):
        """Calibration pass: update min/max statistics while simulating quantization error."""
        feature_stages = (self.qconv1, self.qrelu1, self.qmaxpool2d_1,
                          self.qconv2, self.qrelu2, self.qmaxpool2d_2)
        for stage in feature_stages:
            x = stage(x)
        return self.qfc(x.view(-1, 5 * 5 * 40))

    def freeze(self):
        """Fix quantization parameters after calibration, chaining each conv's q_out forward."""
        self.qconv1.freeze()
        conv1_out = self.qconv1.q_out
        self.qrelu1.freeze(conv1_out)
        self.qmaxpool2d_1.freeze(conv1_out)
        self.qconv2.freeze(q_in=conv1_out)
        conv2_out = self.qconv2.q_out
        self.qrelu2.freeze(conv2_out)
        self.qmaxpool2d_2.freeze(conv2_out)
        self.qfc.freeze(q_in=conv2_out)

    def quantize_inference(self, x):
        """Quantize the float input, run every layer in the integer domain, then dequantize the logits."""
        qx = self.qconv1.q_in.quantize_tensor(x)
        feature_stages = (self.qconv1, self.qrelu1, self.qmaxpool2d_1,
                          self.qconv2, self.qrelu2, self.qmaxpool2d_2)
        for stage in feature_stages:
            qx = stage.quantize_inference(qx)
        qx = self.qfc.quantize_inference(qx.view(-1, 5 * 5 * 40))
        return self.qfc.q_out.dequantize_tensor(qx)


In [11]:
class NetBN(nn.Module):
    """CNN with two (conv -> BatchNorm -> ReLU -> 2x2 max-pool) stages and a linear head.

    Same post-training quantization workflow as ``Net``. The conv/BN/ReLU
    triples are wrapped by ``QConvBNReLU``, which is not defined in the
    visible code; presumably it is defined elsewhere — TODO confirm.
    """

    def __init__(self, num_channels=1):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.bn1 = nn.BatchNorm2d(40)
        self.conv2 = nn.Conv2d(40, 40, 3, 1)
        self.bn2 = nn.BatchNorm2d(40)
        self.fc = nn.Linear(5 * 5 * 40, 10)

    def forward(self, x):
        """Float forward pass."""
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.max_pool2d(F.relu(bn(conv(x))), 2, 2)
        flat = x.view(-1, 5 * 5 * 40)
        return self.fc(flat)

    def quantize(self, num_bits=8):
        """Wrap the layers in quantization counterparts (only qconv1 observes its own input)."""
        self.qconv1 = QConvBNReLU(self.conv1, self.bn1, has_qin=True, has_qout=True, num_bits=num_bits)
        self.qmaxpool2d_1 = QMaxPooling2d(2, 2, 0)
        self.qconv2 = QConvBNReLU(self.conv2, self.bn2, has_qin=False, has_qout=True, num_bits=num_bits)
        self.qmaxpool2d_2 = QMaxPooling2d(2, 2, 0)
        self.qfc = QLinear(self.fc, False, True, num_bits)

    def quantize_forward(self, x):
        """Calibration pass: update min/max statistics while simulating quantization error."""
        for stage in (self.qconv1, self.qmaxpool2d_1, self.qconv2, self.qmaxpool2d_2):
            x = stage(x)
        return self.qfc(x.view(-1, 5 * 5 * 40))

    def freeze(self):
        """Fix quantization parameters after calibration, chaining each conv's q_out forward."""
        self.qconv1.freeze()
        conv1_out = self.qconv1.q_out
        self.qmaxpool2d_1.freeze(conv1_out)
        self.qconv2.freeze(q_in=conv1_out)
        conv2_out = self.qconv2.q_out
        self.qmaxpool2d_2.freeze(conv2_out)
        self.qfc.freeze(q_in=conv2_out)

    def quantize_inference(self, x):
        """Quantize the float input, run every layer in the integer domain, then dequantize the logits."""
        qx = self.qconv1.q_in.quantize_tensor(x)
        for stage in (self.qconv1, self.qmaxpool2d_1, self.qconv2, self.qmaxpool2d_2):
            qx = stage.quantize_inference(qx)
        qx = self.qfc.quantize_inference(qx.view(-1, 5 * 5 * 40))
        return self.qfc.q_out.dequantize_tensor(qx)


In [12]:
import torch
import torch.optim as optim
from torchvision import datasets, transforms
import os

In [13]:
def train(model, device, train_loader, criterion, optimizer, epoch):
    """Run one training epoch, printing progress and loss every 100 batches."""
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        if step % 100:
            continue
        seen = step * len(inputs)
        print('Train Epoch:{} [{}/{}] \t Loss: {:.4f}'.format(
            epoch, seen, len(train_loader.dataset), loss.item()
        ))

In [14]:
def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    for batch_idx, (datas, targets) in enumerate(test_loader):
        datas, targets = datas.to(device), targets.to(device)
        outputs = model(datas)
        loss = criterion(outputs, targets)
        
        test_loss += loss.item()
        pred = outputs.argmax(dim=1, keepdim=True)
        correct += pred.eq(targets.view_as(pred)).sum().item()
        
    test_loss /= len(test_loader.dataset)
    
    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.0f}%\n'.format(
        test_loss, 100. * correct / len(test_loader.dataset)
    ))

In [15]:
def dataset_loader(batch_size, test_batch_size,
                   data_root=r'C:\Users\xia\Documents\datasets', num_workers=1):
    """Build the MNIST train/test DataLoaders.

    Args:
        batch_size: batch size of the shuffled training loader.
        test_batch_size: batch size of the unshuffled test loader.
        data_root: directory MNIST is downloaded to / read from. Defaults to
            the previously hard-coded path, for backward compatibility.
        num_workers: DataLoader worker processes for both loaders.

    Returns:
        (train_loader, test_loader)
    """
    # Both splits use identical preprocessing: ToTensor, then Normalize with
    # mean 0.1307 and std 0.3081.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_dataset = datasets.MNIST(data_root, train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(data_root, train=False, download=True, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size,
                                              shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

In [16]:
def direct_quantize(model, test_loader, num_batches=500):
    """Calibrate *model*'s quantization statistics.

    Runs ``model.quantize_forward`` on at most *num_batches* batches of
    *test_loader*. ``num_batches=None`` means the whole loader. Targets are
    ignored.
    """
    import itertools

    # Calibration only observes activations, so skip building an autograd
    # graph for every batch.
    with torch.no_grad():
        for datas, _ in itertools.islice(test_loader, num_batches):
            model.quantize_forward(datas)
    print('direct quantization finish')

In [17]:
def full_inference(model, test_loader):
    """Evaluate the float model on *test_loader* and print its top-1 accuracy."""
    correct = 0
    # Inference only: disable autograd so no graph is built per batch.
    with torch.no_grad():
        for datas, targets in test_loader:
            output = model(datas)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(targets.view_as(pred)).sum().item()
    print('\nTest set: Full Model Accuracy: {:.0f}%\n'.format(100. * correct / len(test_loader.dataset)))


In [18]:
def quantize_inference(model, test_loader):
    """Evaluate the frozen quantized model on *test_loader* and print its top-1 accuracy."""
    correct = 0
    # Inference only: disable autograd so no graph is built per batch.
    with torch.no_grad():
        for datas, targets in test_loader:
            output = model.quantize_inference(datas)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(targets.view_as(pred)).sum().item()
    print('\nTest set: Quant Model Accuracy: {:.0f}%\n'.format(100. * correct / len(test_loader.dataset)))


In [19]:
def main():
    """End-to-end post-training quantization demo on MNIST.

    Loads a pretrained float checkpoint and reports its accuracy. It then
    calibrates quantization ranges on the training set, freezes the integer
    parameters, and reports the accuracy of integer-domain inference.
    """
    batch_size = 64
    test_batch_size = 64
    using_bn = False
    num_bits = 8

    train_loader, test_loader = dataset_loader(batch_size, test_batch_size)

    if using_bn:
        model, ckpt_path = NetBN(), 'ckpt/mnist_cnnbn.pt'
    else:
        model, ckpt_path = Net(), 'ckpt/mnist_cnn.pt'
    # The model is never moved off the CPU. map_location lets a checkpoint
    # saved from a CUDA run load on a CPU-only machine.
    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

    model.eval()
    full_inference(model, test_loader)

    model.quantize(num_bits=num_bits)
    model.eval()  # the newly created wrapper modules start in training mode
    print('Quantization bit: %d' % num_bits)

    # Calibrate observer ranges on training data, then fix integer parameters.
    direct_quantize(model, train_loader)
    model.freeze()

    quantize_inference(model, test_loader)

In [20]:
if __name__ == '__main__':
    # Run the end-to-end demo only when executed as a script, not on import.
    main()


Test set: Full Model Accuracy: 98%

Quantization bit: 8
direct quantization finish

Test set: Quant Model Accuracy: 98%

