In [None]:
import copy
import os
import tempfile
import time

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

Dataset 준비 (CIFAR10)

In [None]:
# Mini-batch size shared by the training and evaluation loaders.
batch_size = 64

# Convert images to tensors, then normalize every RGB channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training split: downloaded into ./data on first use and reshuffled on every pass.
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# Held-out split: same preprocessing, iterated in a fixed order.
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


## 1. Baseline Model 만들기 (Fp32)


### 1.1 Model 정의

### * Mission
**자신만의 모델을 만들어 같은 과정을 반복하며 quantization의 효과를 비교해봅시다.**

In [None]:
# quantize 가능한 단순한 floating point Model을 정의합니다.
class ConvBNReLU(nn.Sequential):
    """Conv2d -> BatchNorm2d -> ReLU stack stored as a flat ``nn.Sequential``.

    The three layers live at indices 0, 1 and 2 of the container.

    Args:
        in_planes: number of input channels of the convolution.
        out_planes: number of output channels of the convolution and of the
            batch-norm layer.
        kernel_size: square kernel size; the padding is ``(kernel_size - 1) // 2``.
        stride: convolution stride.
        groups: number of convolution groups.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            bias=False,  # the convolution carries no bias; BatchNorm follows it
        )
        norm = nn.BatchNorm2d(out_planes, momentum=0.1)
        # Out-of-place ReLU
        act = nn.ReLU(inplace=False)
        super().__init__(conv, norm, act)

class Model(torch.nn.Module):
    """Small CNN classifier (10 outputs) wrapped in quantization stubs.

    The network is four ``ConvBNReLU`` blocks with two max-pool layers in
    between, a global average pool, and a linear classifier. ``QuantStub`` and
    ``DeQuantStub`` mark where tensors enter and leave the quantized region.
    """

    def __init__(self):
        super().__init__()
        # QuantStub converts the floating-point input into a quantized tensor.
        self.quant = torch.quantization.QuantStub()

        # ConvBNReLU(in channels, out channels, kernel size, stride)
        self.convbnrelu1 = ConvBNReLU(3, 32, 3, 2)
        self.convbnrelu2 = ConvBNReLU(32, 64, 3, 1)
        self.maxpool1 = nn.MaxPool2d(2)
        self.convbnrelu3 = ConvBNReLU(64, 128, 3, 1)
        self.maxpool2 = nn.MaxPool2d(2)
        self.convbnrelu4 = ConvBNReLU(128, 256, 3, 1)
        self.flatten = torch.nn.Flatten()

        # Global average pooling in forward() leaves one value per channel,
        # so the classifier consumes 256 features.
        self.linear = torch.nn.Linear(256, 10)

        # DeQuantStub converts the quantized tensor back to floating point.
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Return class scores for a batch of images.

        The shape comments below assume 32 x 32 RGB inputs.
        """
        x = self.quant(x)

        # |x| = (batch_size, 3, 32, 32)
        x = self.convbnrelu1(x)
        # |x| = (batch_size, 32, 16, 16)
        x = self.convbnrelu2(x)
        # |x| = (batch_size, 64, 16, 16)
        x = self.maxpool1(x)
        # |x| = (batch_size, 64, 8, 8)
        x = self.convbnrelu3(x)
        # |x| = (batch_size, 128, 8, 8)
        x = self.maxpool2(x)
        # |x| = (batch_size, 128, 4, 4)
        x = self.convbnrelu4(x)
        # |x| = (batch_size, 256, 4, 4)

        x = nn.functional.adaptive_avg_pool2d(x, 1)

        # |x| = (batch_size, 256, 1, 1)
        x = self.flatten(x)

        # |x| = (batch_size, 256)
        x = self.linear(x)

        # |x| = (batch_size, 10)
        x = self.dequant(x)
        return x

    # Fuse the Conv+BN+ReLU layers of every ConvBNReLU block prior to quantization.
    # This operation does not change the numerics.
    def fuse_model(self):
        """Fuse layers '0', '1', '2' of each ``ConvBNReLU`` submodule in place."""
        for m in self.modules():
            # isinstance (rather than an exact type comparison) also covers subclasses.
            if isinstance(m, ConvBNReLU):
                torch.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)

def print_size_of_model(model):
    """Print the on-disk size, in megabytes, of ``model``'s serialized state_dict.

    The state_dict is written to a throw-away temporary directory, so no file
    in the working directory is overwritten or deleted, and the scratch file is
    removed even if saving or measuring raises.

    Args:
        model: module whose ``state_dict()`` is serialized with ``torch.save``.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Same file name as before, only relocated into the scratch directory.
        model_path = os.path.join(tmp_dir, "model.p")
        torch.save(model.state_dict(), model_path)
        print('size(mb) : ', os.path.getsize(model_path) / 1e6)

@torch.no_grad()
def test_model(model, testloader, device, half=False):
    correct = 0
    total = 0
    model = model.to(device)
    start_time = time.time()
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        if half:
          images = images.half()
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(f'Accuracy of the network on the 10000 test images: {round(100.0 * correct / total, 2)}%')
    print(f'Elpased time: {round(time.time() - start_time, 3)}s, on {device}')
    
def load_model(model_file):
    """Build a fresh ``Model`` and load the weights stored in ``model_file``.

    Args:
        model_file: path to a state_dict saved with ``torch.save``.

    Returns:
        The ``Model`` instance with the loaded weights, on the CPU.
    """
    model = Model()
    # map_location='cpu' lets a checkpoint that was saved from a CUDA model be
    # loaded on a machine without a GPU; the freshly built model lives on the
    # CPU anyway, so the result is the same.
    state_dict = torch.load(model_file, map_location='cpu')
    model.load_state_dict(state_dict)
    return model


### 1.2 Baseline 모델 학습

In [None]:
# Train the fp32 baseline and save its weights for the quantization sections.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 20
saved_model_dir = 'data'
float_model_file = 'pretrained_float.pth'

model = Model()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
model = model.to(device)

for epoch in range(epochs):
    # Re-enter train mode every epoch because evaluation below switches to eval mode.
    model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):

        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / len(trainloader)))
    # Evaluate in eval mode: BatchNorm must use its running statistics here and
    # must not be updated from test batches.
    model.eval()
    test_model(model=model, testloader=testloader, device=device)

print('Finished Training')

# save model (create the target directory if it is missing)
os.makedirs(saved_model_dir, exist_ok=True)
fp32_path = os.path.join(saved_model_dir, float_model_file)
torch.save(model.state_dict(), fp32_path)

[1,   782] loss: 1.220
Accuracy of the network on the 10000 test images: 65.79%
Elpased time: 1.734s, on cuda
[2,   782] loss: 0.850
Accuracy of the network on the 10000 test images: 70.82%
Elpased time: 1.755s, on cuda
[3,   782] loss: 0.703
Accuracy of the network on the 10000 test images: 73.72%
Elpased time: 1.697s, on cuda
[4,   782] loss: 0.595
Accuracy of the network on the 10000 test images: 75.21%
Elpased time: 1.7s, on cuda
[5,   782] loss: 0.514
Accuracy of the network on the 10000 test images: 76.38%
Elpased time: 1.732s, on cuda
[6,   782] loss: 0.445
Accuracy of the network on the 10000 test images: 76.76%
Elpased time: 1.695s, on cuda
[7,   782] loss: 0.385
Accuracy of the network on the 10000 test images: 77.36%
Elpased time: 1.72s, on cuda
[8,   782] loss: 0.328
Accuracy of the network on the 10000 test images: 77.61%
Elpased time: 1.699s, on cuda
[9,   782] loss: 0.282
Accuracy of the network on the 10000 test images: 77.26%
Elpased time: 1.716s, on cuda
[10,   782] l

## 2. Fp16 quantization

### 2.1 Fp16 quantization

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# fp32 reference: reload the saved baseline, report its size and accuracy.
model = load_model(fp32_path).eval()
print("[fp32]")
print_size_of_model(model)
test_model(model, testloader, device, half=False)

# fp16 variant: reload the baseline again and cast its weights with .half();
# the inputs are cast to fp16 inside test_model via half=True.
model = load_model(fp32_path).eval()
fp16_model = model.half()
print("[fp16]")
print_size_of_model(fp16_model)
test_model(fp16_model, testloader, device, half=True)

[fp32]
size(mb) :  1.578521
Accuracy of the network on the 10000 test images: 77.26%
Elpased time: 1.695s, on cuda
[fp16]
size(mb) :  0.793689
Accuracy of the network on the 10000 test images: 77.26%
Elpased time: 1.694s, on cuda


## 3. Post Training Quantization (Static Quantization)

* Post Training Quantization(이하 PTQ)은 model의 weights와 activations를 fp32 -> qint8로 quantize 합니다. 
* PTQ는 activations를 이전 레이어에 fuse 시킵니다.
* Quantization 할 때, activation의 optimal quantization parameter를 찾기 위해 representative dataset이 필요합니다.

* model 인스턴스를 만들고, static quantization을 위해 eval mode로 세팅합니다.

In [None]:
# Reload the fp32 baseline in eval mode on the CPU and show its structure.
model = load_model(fp32_path).eval().to('cpu')
print(model)

Model(
  (quant): QuantStub()
  (convbnrelu1): ConvBNReLU(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (convbnrelu2): ConvBNReLU(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (convbnrelu3): ConvBNReLU(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (convbnrelu4): ConvBNReLU(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False

* pytorch에서는 사용하는 cpu의 사양에 따라  두 종류의 quantization backend를 지원합니다.
* x86 CPUs 경우 'fbgemm'
* ARM CPUs 경우 'qnnpack'


In [None]:
# Backend used for x86 CPUs; attach its default quantization config to the model.
backend = "fbgemm"
model.qconfig = torch.quantization.get_default_qconfig(backend=backend)

* activations를 이전 레이어에 fuse 시킵니다
* calibration을 위해 prepare된 model에 representative 한 dataset을 입력합니다.

In [None]:
# Fuse the Conv+BN+ReLU layers of the model in place.
model.fuse_model()
# prepare() returns a separate prepared model; `model` itself is left as is.
model_prepared = torch.quantization.prepare(model, inplace=False)

# Calibration: run the prepared model over the training data on the CPU.
test_model(model=model_prepared, testloader=trainloader, device='cpu')

  reduce_range will be deprecated in a future release of PyTorch."


Accuracy of the network on the 10000 test images: 98.39%
Elpased time: 46.7s, on cpu


* 모델을 convert 하여 quantization 된 model을 얻습니다.
* model의 크기를 출력하여 비교합니다.

In [None]:
# Convert the calibrated model into its int8 counterpart.
model_int8 = torch.quantization.convert(model_prepared, inplace=False)

# fp32 reference, reloaded from disk, measured on the CPU.
model = load_model(fp32_path).eval()
print("[fp32]")
print_size_of_model(model)
test_model(model=model, testloader=testloader, device='cpu')

# Post-training-quantized model, measured on the same data and device.
print("[int8(ptq)]")
print_size_of_model(model_int8)
test_model(model=model_int8, testloader=testloader, device='cpu')

[fp32]
size(mb) :  1.578521
Accuracy of the network on the 10000 test images: 77.26%
Elpased time: 7.367s, on cpu
[int8(ptq)]
size(mb) :  0.410703
Accuracy of the network on the 10000 test images: 77.4%
Elpased time: 4.433s, on cpu


## 4. Quantization Aware Training (QAT)
* Quantization Aware Training(QAT)는 위에서 설명한 방법과 다르게 training 과정에서 quantization error를 모델링하여 quantization 합니다.
* PTQ와 비슷하게 모델을 fuse 시키고 prepare 합니다.

In [None]:
# Start QAT from the saved fp32 weights, on the CPU and in train mode.
model = load_model(fp32_path).to('cpu')
model.train()

# Attach the default QAT config for the x86 backend, fuse the Conv+BN+ReLU
# blocks, and prepare the model in place (qat_model and model are the same object).
# NOTE(review): newer PyTorch releases may require a QAT-specific fusion call for
# train-mode models -- confirm against the installed version.
backend = "fbgemm"
model.qconfig = torch.quantization.get_default_qat_qconfig(backend=backend)
model.fuse_model()
qat_model = torch.quantization.prepare_qat(model, inplace=True)

In [None]:
device = "cpu"

# The optimizer created for the baseline is bound to the parameters of the old
# fp32 model, so stepping it would never update qat_model. Build a new one over
# qat_model's own parameters; a small learning rate is used for fine-tuning.
optimizer = optim.Adam(qat_model.parameters(), lr=1e-4)

# train
for epoch in range(8):
    # convert(qat_model.eval(), ...) at the end of each epoch leaves qat_model in
    # eval mode, so switch back to train mode before every epoch.
    qat_model.train()
    running_loss = 0.0
    if epoch > 3:
        # Freeze quantizer parameters
        qat_model.apply(torch.quantization.disable_observer)
    if epoch > 2:
        # Freeze batch norm mean and variance estimates
        qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = qat_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Evaluate a converted int8 copy; qat_model itself keeps training next epoch.
    quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)
    quantized_model.eval()
    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / len(trainloader)))
    test_model(model=quantized_model, testloader=testloader, device=device)

print('Finished Training')

* quantization 이전 모델과 이후 모델의 크기를 확인합니다.

In [None]:
# Convert the QAT model (switched to eval mode) into its int8 counterpart.
qat_int8 = torch.quantization.convert(qat_model.eval(), inplace=False)

# fp32 reference, reloaded from disk, measured on the CPU.
model = load_model(fp32_path).eval()
print("[fp32]")
print_size_of_model(model)
test_model(model=model, testloader=testloader, device='cpu')

# Quantization-aware-trained int8 model, measured on the same data and device.
print("[int8(qat)]")
print_size_of_model(qat_int8)
test_model(model=qat_int8, testloader=testloader, device='cpu')

## 5. Mixed precision
* model.half()를 사용하여 model의 dtype을 fp32 -> fp16으로 변환합니다.
* quantization을 위해 model을 eval mode로 설정합니다.

In [None]:
# Fresh (untrained) fp32 model; only its size is compared in this section.
model_fp32 = Model()
# train(False) is the same switch as eval(); the returned module is displayed as the cell output.
model_fp32.train(False)

Model(
  (quant): QuantStub()
  (convbnrelu1): ConvBNReLU(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (convbnrelu2): ConvBNReLU(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (convbnrelu3): ConvBNReLU(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (convbnrelu4): ConvBNReLU(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False

In [None]:
print("[fp32]")
print_size_of_model(model_fp32)


print("[fp16]")
# .half() converts a module in place, so cast a copy: model_fp32 stays fp32 and
# re-running this cell keeps printing the same two sizes.
model_fp16 = copy.deepcopy(model_fp32).half()
print_size_of_model(model_fp16)

[fp32]
size(mb) :  1.578521
[fp16
size(mb) :  0.793689


## 6. Conv-bn fuse 
Code mostly borrowed from https://learnml.today/speeding-up-model-with-fusing-batch-normalization-and-convolution-3

In [None]:
import torch
import torchvision

def fuse_convbn(conv, bn):
    """Fuse a Conv2d followed by an eval-mode BatchNorm2d into a single Conv2d.

    With ``scale = gamma / sqrt(running_var + eps)`` per output channel, the
    fused layer uses

        weight' = weight * scale
        bias'   = (bias - running_mean) * scale + beta

    so that ``fused(x)`` equals ``bn(conv(x))`` when ``bn`` uses its running
    statistics.

    Args:
        conv: the ``torch.nn.Conv2d`` that feeds the batch norm.
        bn: the ``torch.nn.BatchNorm2d`` applied to ``conv``'s output; it must
            track running statistics.

    Returns:
        A new ``torch.nn.Conv2d`` (with bias) on ``conv``'s device and dtype.
        ``conv`` and ``bn`` are not modified.

    Raises:
        ValueError: if ``bn`` has no running statistics to fold.
    """
    if bn.running_mean is None or bn.running_var is None:
        raise ValueError("fuse_convbn needs a BatchNorm layer with running statistics")

    # The parameter copies below are plain value assignments; running them under
    # no_grad makes the function work whether or not autograd is globally enabled.
    with torch.no_grad():
        fused = torch.nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
            padding_mode=conv.padding_mode,
        ).to(device=conv.weight.device, dtype=conv.weight.dtype)

        # Per-output-channel scale of the combined normalize / re-scale step.
        # A non-affine BatchNorm has no gamma/beta, i.e. gamma = 1 and beta = 0.
        scale = torch.rsqrt(bn.running_var + bn.eps)
        if bn.weight is not None:
            scale = scale * bn.weight

        # Setting weights: every output filter is multiplied by its channel's scale.
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))

        # Setting bias: the conv bias goes through the same shift and scale as the
        # conv output, so it must be scaled by gamma / sqrt(var + eps) as well.
        if conv.bias is not None:
            b_conv = conv.bias
        else:
            b_conv = torch.zeros_like(bn.running_mean)
        shift = (b_conv - bn.running_mean) * scale
        if bn.bias is not None:
            shift = shift + bn.bias
        fused.bias.copy_(shift)

    return fused

Check if it's equivalent, compare speed

In [None]:
import torch.autograd.profiler as profiler

# Gradient tracking is switched off globally from here on: the fused layer is
# built by copying values into parameters and no backward pass is needed.
torch.set_grad_enabled(False)

# Random batch of 16 RGB images, 256 x 256.
x = torch.randn(16, 3, 256, 256)

# Pretrained ResNet-18 in eval mode.
resnet18 = torchvision.models.resnet18(pretrained=True).eval()

# Take only the first conv and batch-norm layers: once as an unfused two-layer
# stack, once fused into a single convolution.
conv_bn = torch.nn.Sequential(resnet18.conv1, resnet18.bn1)
fused_conv = fuse_convbn(resnet18.conv1, resnet18.bn1)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




In [None]:
# Profile the unfused conv + bn pair on the CPU, recording shapes and memory use.
device = torch.device("cpu")
x = x.to(device)
conv_bn = conv_bn.to(device)
with profiler.profile(record_shapes=True, profile_memory=True) as prof, \
        profiler.record_function("model_inference"):
    f1 = conv_bn.forward(x)
# Show the ten most expensive entries by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                 model_inference         1.99%       2.634ms        99.98%     132.087ms     132.087ms      64.00 Mb     -64.00 Mb             1  
                    aten::conv2d         0.01%      11.000us        86.91%     114.826ms     114.826ms      64.00 Mb           0 b             1  
               aten::convolution         0.01%      17.192us        86.91%     114.815ms     114.815ms      64.00 Mb           0 b             1  
              aten::_convolution         0.03%      41.494us        86.89%     114.797ms     114.797ms      64.00 Mb  

In [None]:
# Profile the single fused convolution on the same input and device.
x = x.to(device)
fused_conv = fused_conv.to(device)
with profiler.profile(record_shapes=True, profile_memory=True) as prof, \
        profiler.record_function("model_inference"):
    f2 = fused_conv.forward(x)
# Show the ten most expensive entries by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
             model_inference         0.10%      93.046us        99.97%      95.145ms      95.145ms      64.00 Mb         -20 b             1  
                aten::conv2d         0.01%       9.681us        99.87%      95.050ms      95.050ms      64.00 Mb           0 b             1  
           aten::convolution         0.01%      10.625us        99.86%      95.040ms      95.040ms      64.00 Mb           0 b             1  
          aten::_convolution         0.02%      19.649us        99.85%      95.029ms      95.029ms      64.00 Mb           0 b             1  

In [None]:
# Largest absolute element-wise difference between the unfused and fused outputs.
# A plain mean of signed differences can cancel out to ~0 even when the two
# outputs disagree, so it is not a reliable equivalence check.
d = (f1 - f2).abs().max().item()
print("error:",d)

error: 9.24077741409901e-12


# Further studies

https://tutorials.pytorch.kr/advanced/static_quantization_tutorial.html