# Import libs

In [1]:
import os
import numpy as np
from tqdm import tqdm
from datetime import datetime
import copy
# torch libs
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import pickle
import utils
from quant_utils import *
device = torch.device('cpu')

In [2]:

class BasicBlock(nn.Module):

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        #residual function
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        )

        #shortcut
        self.shortcut = nn.Sequential()
        self.skip_add = nn.quantized.FloatFunctional()
        self.relu = nn.ReLU(inplace=True)
        #the shortcut output dimension is not the same with residual function
        #use 1*1 convolution to match the dimension
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        fwd = self.residual_function(x)
        fwd = self.skip_add.add(fwd, self.shortcut(x))
        return self.relu(fwd)

class ResNet(nn.Module):

    def __init__(self, block, num_block, num_classes=100):
        super().__init__()

        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        #we use a different inputsize than the original paper
        #so conv2_x's stride is 1
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):

        # we have num_block blocks per layer, the first block
        # could be 1 or 2, other blocks would always be 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.avg_pool(output)
        output = torch.flatten(output, 1)#output.view(output.size(0), -1)
        output = self.fc(output)

        return output

def resnet18(num_classes=100):

    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)


In [3]:
train_loader, mean, std = utils.get_subtraining_dataloader_cifar10_intersect(
    propor=1.0, 
    batch_size=128, 
    num_workers=8, 
    shuffle=True, 
    sub_idx=1)
test_loader = utils.get_test_dataloader_cifar10(
    mean, std, 
    batch_size=128, num_workers=8, shuffle=False, pin_memory=False)

Files already downloaded and verified


In [4]:
model = resnet18(num_classes=10)
model.load_state_dict(
    torch.load('/data1/checkpoint/hash/cifar10/resnet18_0.pth', map_location=device))
model.eval()
model.to(device)
print("Loaded model.")

Loaded model.


# Quantization

In [5]:
fused_model= copy.deepcopy(model)
model.to('cpu')
model.eval()
# The model has to be switched to evaluation mode before any layer fusion.
# Otherwise the quantization will not work correctly.
fused_model.eval()
fused_model = torch.quantization.fuse_modules(
    fused_model, 
    [["conv1.0",
     "conv1.1",
    "conv1.2"]], 
    inplace=True)


In [6]:
for module_name, module in fused_model.named_children():
    print(module_name, "!")
    if '_x' in module_name:
        for basic_block_name, basic_block in module.named_children():
            print(basic_block_name, '$')
            for basic_subblock_name, basic_subblock in basic_block.named_children():
                if 'residual' in basic_subblock_name:
                    torch.quantization.fuse_modules(
                        basic_subblock, [["0", "1", "2"], ["3", "4"]], inplace=True)
                
                if 'shortcut' in basic_subblock_name and len(list(basic_subblock.named_children())) == 2:
                    torch.quantization.fuse_modules(
                        basic_subblock, [["0", "1"]], inplace=True)

conv1 !
conv2_x !
0 $
1 $
conv3_x !
0 $
1 $
conv4_x !
0 $
1 $
conv5_x !
0 $
1 $
avg_pool !
fc !


In [7]:
quantized_model = QuantizedNetwork(fused_model)
quantized_model.eval()
quantization_config = torch.quantization.get_default_qconfig("fbgemm")
quantized_model.qconfig = quantization_config
print(quantized_model.qconfig)
torch.quantization.prepare(quantized_model, inplace=True)

QConfig(activation=functools.partial(<class 'torch.quantization.observer.HistogramObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric))




QuantizedNetwork(
  (quant): QuantStub(
    (activation_post_process): HistogramObserver()
  )
  (dequant): DeQuantStub()
  (model): ResNet(
    (conv1): Sequential(
      (0): ConvReLU2d(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (activation_post_process): HistogramObserver()
      )
      (1): Identity()
      (2): Identity()
    )
    (conv2_x): Sequential(
      (0): BasicBlock(
        (residual_function): Sequential(
          (0): ConvReLU2d(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU(inplace=True)
            (activation_post_process): HistogramObserver()
          )
          (1): Identity()
          (2): Identity()
          (3): Conv2d(
            64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
            (activation_post_process): HistogramObserver()
          )
          (4): Identity()
        )
        (shortcut): Sequenti

In [8]:
%%time
calibrate_model(model=quantized_model, loader=train_loader, device='cpu')
quantized_model = torch.quantization.convert(quantized_model, inplace=True)

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /opt/conda/conda-bld/pytorch_1623448255797/work/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)


CPU times: user 27min 51s, sys: 26 s, total: 28min 17s
Wall time: 1min 38s


In [9]:
quantized_model.eval()
# Print quantized model.
# print(quantized_model)
# Save quantized model.
save_torchscript_model(model=quantized_model, model_dir='/data1/checkpoint/hash/cifar10/', model_filename="resnet18_0_quant.pth")

In [10]:
print_size_of_model(model)

model   	 Size (KB): 44776.141


44776141

In [11]:
print_size_of_model(quantized_model)

model   	 Size (KB): 11308.065


11308065

In [12]:
_, int8_eval_accuracy = evaluate_model(model=quantized_model, test_loader=test_loader, device=device, criterion=None)
print("INT8 evaluation accuracy: {:.3f}".format(int8_eval_accuracy))


100%|██████████| 79/79 [00:08<00:00,  9.38it/s]

INT8 evaluation accuracy: 0.914





In [13]:
_, fp32_eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=None)
print("FP32 evaluation accuracy: {:.3f}".format(fp32_eval_accuracy))


100%|██████████| 79/79 [00:17<00:00,  4.45it/s]

FP32 evaluation accuracy: 0.916





In [14]:
rand_input = torch.clamp(torch.rand(500, 3, 32, 32), 0, 1)
rand_output  = quantized_model(rand_input).softmax(dim=1)
import pickle
pickle.dump(rand_output, open("../results/hash/cifar10/resnet18_0_quant.pkl", "wb"))