In [None]:
from net import CifarNet

In [2]:
import torch

torch.backends.quantized.engine = 'qnnpack'

import torchvision
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
tf = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
testloader = torch.utils.data.DataLoader(torchvision.datasets.CIFAR10('data/', train=False, download=True, transform=tf), batch_size=32)

In [None]:
import time 

def net_time(model_class, testloader):
    
    model = model_class()
    t_start = time.time()
    x, _ = next(iter(testloader))
    model(x)
    t_end = time.time()
    t = t_end - t_start
    return t

def net_acc(model_class, state_dict, testloader):

    model = model_class()
    model.load_state_dict(state_dict)
    num_correct = 0
    inputs, targets = next(iter(testloader))
    outputs = model(inputs)
    predicted = torch.argmax(outputs, dim=1) 
    num_correct += (predicted == targets).sum().item()
    
    accuracy = num_correct/32
    return accuracy
    

In [4]:
print(f'Time unquantized: {net_time(CifarNet, testloader)} s')
print(f"Accuracy unquantized: {net_acc(CifarNet, torch.load('state_dict.pt'), testloader):.4%}")

Time unquantized: 0.028605937957763672 s
Accuracy unquantized: 84.3750%


In [9]:
import torch.nn as nn
import torch.nn.functional as F

def f_sd(sd, endswith_key_string):
    keys = [i for i in sd.keys() if i.endswith(endswith_key_string)]
    if not keys:
        raise KeyError(endswith_key_string)
    return sd[keys[0]]

#Quantized Conv2dReLU Module
class QConv2dReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(QConv2dReLU, self).__init__()

        self.weight = torch.nn.Parameter(torch.quantize_per_tensor(torch.Tensor(
                out_channels, in_channels // 1, *(kernel_size, kernel_size)), scale=0.1, zero_point = 0, dtype=torch.qint8), requires_grad=False)
        self.bias = torch.nn.Parameter(torch.Tensor(out_channels), requires_grad=False)

        self.register_buffer('scale', torch.tensor(0.1))

        self.stride = stride
        self.padding = padding
        
        self._prepack = self._prepare_prepack(self.weight, self.bias, stride, padding)
        self._register_load_state_dict_pre_hook(self._sd_hook)

    def _prepare_prepack(self, qweight, bias, stride, padding):
        assert qweight.is_quantized, "QConv2dReLU requires a quantized weight."
        assert not bias.is_quantized, "QConv2dReLU requires a float bias."
        return torch.ops.quantized.conv2d_prepack(qweight, bias, stride=[stride, stride], dilation=[1,1], padding=[padding, padding], groups=1)

    
    def _sd_hook(self, state_dict, prefix, *_):
        self._prepack = self._prepare_prepack(f_sd(state_dict, prefix + 'weight'), f_sd(state_dict, prefix + 'bias'),
                                             self.stride, self.padding)
    
    def forward(self, x):
        return torch.ops.quantized.conv2d_relu(x, self._prepack, self.scale, 64)

    
#Quantized Linear Module
class QLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(QLinear, self).__init__()

        self.weight = torch.nn.Parameter(torch.quantize_per_tensor(torch.Tensor(out_features, in_features), scale=0.1, zero_point = 0, dtype=torch.qint8), requires_grad=False)
        self.bias = torch.nn.Parameter(torch.Tensor(out_features))

        self.register_buffer('scale', torch.tensor(0.1))
        
        self._prepack = self._prepare_prepack(self.weight, self.bias)
        
        self._register_load_state_dict_pre_hook(self._sd_hook)
        
    def _prepare_prepack(self, qweight, bias):
        assert qweight.is_quantized, "QConv2dReLU requires a quantized weight."
        assert not bias.is_quantized, "QConv2dReLU requires a float bias."
        return torch.ops.quantized.linear_prepack(qweight, bias)
    
    def _sd_hook(self, state_dict, prefix, *_):
        self._prepack = self._prepare_prepack(f_sd(state_dict, prefix + 'weight'), f_sd(state_dict, prefix + 'bias'))
        return

    def forward(self, x):
        return torch.ops.quantized.linear(x, self._prepack, self.scale, 64)

In [6]:
print('state_dict of QConv2dReLU')
qconv2drelu = QConv2dReLU(3, 16)
for key in qconv2drelu.state_dict(): print(key, qconv2drelu.state_dict()[key].dtype)
print('\nstate_dict of QLinear')
qlinear = QLinear(10, 10)
for key in qlinear.state_dict(): print(key, qlinear.state_dict()[key].dtype)

state_dict of QConv2dReLU
weight torch.qint8
bias torch.float32
scale torch.float32

state_dict of QLinear
weight torch.qint8
bias torch.float32
scale torch.float32


In [None]:
class QCifarNet(nn.Module):
    def __init__(self):
        super(QCifarNet, self).__init__()
        
        self.register_buffer("scale", torch.tensor(0.1))

        self.conv1 = QConv2dReLU(3, 16, 3, 1, padding=1)
        self.conv2 = QConv2dReLU(16,16, 3, 1, padding=1)

        self.conv3 = QConv2dReLU(16, 32, 3, 1, padding=1)
        self.conv4 = QConv2dReLU(32, 32, 3, 1, padding=1)

        self.conv5 = QConv2dReLU(32, 64, 3, 1, padding=1)
        self.conv6 = QConv2dReLU(64, 64, 3, 1, padding=1)

        self.fc = QLinear(1024, 10)
        
    def forward(self, x):
        x = torch.quantize_per_tensor(x, scale=0.1, zero_point=64, dtype=torch.quint8)
        
        x = self.conv1(x)
        x = self.conv2(x)
        x = torch.nn.quantized.functional.max_pool2d(x, 2)

        x = self.conv3(x)
        x = self.conv4(x)
        x = torch.nn.quantized.functional.max_pool2d(x, 2)

        x = self.conv5(x)
        x = self.conv6(x)
        x = torch.nn.quantized.functional.max_pool2d(x, 2)

        x = torch.flatten(x, 1)
        x = self.fc(x)

        x = torch.dequantize(x)
        
        return x

In [8]:
#We evaulate how fast the quantized verions of CifarNet is
print(f"Time quantized: {net_time(QCifarNet, testloader)} s")

Time quantized: 0.015759706497192383 s


In [None]:
def tensor_scale(input):
    return float(2*torch.max(torch.abs(torch.max(input)), torch.abs(torch.min(input))))/127.0

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_w, bn_b):
    """
    Input:
        conv_w: shape=(output_channels, in_channels, kernel_size, kernel_size)
        conv_b: shape=(output_channels)
        bn_rm:  shape=(output_channels)
        bn_rv:  shape=(output_channels)
        bn_w:   shape=(output_channels)
        bn_b:   shape=(output_channels)
    
    Output:
        fused_conv_w = shape=conv_w
        fused_conv_b = shape=conv_b
    """
    bn_eps = 1e-05

    fused_conv = torch.zeros(conv_w.shape)
    fused_bias = torch.zeros(conv_b.shape)
    
    std = torch.sqrt(bn_rv + bn_eps)
    scale = bn_w / std

    fused_conv = conv_w * scale.reshape()
    fused_bias = scale * (conv_b - bn_rm) + bn_b

    return fused_conv, fused_bias

In [None]:
#prints keys from quantized net
qnet = QCifarNet()
qsd = qnet.state_dict()
for key in qsd: print(key, qsd[key].dtype)

sd = torch.load('state_dict.pt')
