# MobileNetV2 Quantization Aware Training (QAT) on CIFAR10

```Note: As of today (01/09/2023), PyTorch quantization model conversion can only be done on CPU.```

<img src="fig/pytorch_qat_flow.png">


## 0. Import necessary packages

In [1]:
import sys
import os

import torch
import torch.nn as nn

# Set up warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.ao.quantization'
)

# Specify random seed for repeatable results
torch.manual_seed(2023)

print(f"PyTorch version: {torch.__version__}")

PyTorch version: 1.13.0+cu117


## 1. Model Modifications: Replace Functionals with Modules

- Replace addition `+` with the `nn.quantized.FloatFunctional().add()` module function

- Insert `QuantStub` and `DeQuantStub` at the beginning and end of the network.

- Replace ReLU6 with ReLU

- Replace last `torch.nn.Linear` layer with `torch.nn.Conv2d` 1x1 kernel, as 1x1 `Conv2d` has better performance than `Linear`.

- Define a `fuse_model()` function to specify how to fuse modules.

Please check [mobilenetv2.py](./mobilenetv2.py) for all the details.    
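
The modifications listed above can be sketched on a toy network. This is a minimal illustration, not the code in `mobilenetv2.py`: the class `TinyResidualNet` and its layer names are hypothetical, and `FloatFunctional` is imported from `torch.ao.nn.quantized`, which is assumed here to be equivalent to the `nn.quantized` path named above.

```python
import torch
import torch.nn as nn
from torch.ao.nn.quantized import FloatFunctional
from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules


class TinyResidualNet(nn.Module):
    """Hypothetical toy network, not the MobileNetV2 from mobilenetv2.py."""

    def __init__(self, channels=8, num_classes=10):
        super().__init__()
        self.quant = QuantStub()        # float -> quantized at the network input
        self.conv = nn.Conv2d(3, channels, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()           # ReLU instead of ReLU6
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.skip_add = FloatFunctional()  # module replacement for `+`
        self.pool = nn.AdaptiveAvgPool2d(1)
        # 1x1 convolution used as the classifier instead of nn.Linear
        self.classifier = nn.Conv2d(channels, num_classes, 1)
        self.dequant = DeQuantStub()    # quantized -> float at the network output

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        x = self.skip_add.add(x, self.bn2(self.conv2(x)))
        x = self.classifier(self.pool(x))
        x = self.dequant(x)
        return torch.flatten(x, 1)

    def fuse_model(self):
        # In eval mode this folds each BatchNorm into the preceding convolution
        fuse_modules(self, [["conv", "bn", "relu"], ["conv2", "bn2"]], inplace=True)


model = TinyResidualNet().eval()
model.fuse_model()
out = model(torch.rand(2, 3, 32, 32))
print(out.shape)
```

Before `prepare_qat()` or `convert()` is applied, the stubs and `FloatFunctional` behave like their float counterparts, so the modified model can still be trained and evaluated as a regular FP32 model.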

## 2. Model Training

### 2.1 Define dataset and data loaders

In [2]:
import torchvision
from torchvision import datasets, transforms

from torch.utils.data import (DataLoader, TensorDataset)

def prepare_data_loaders(data_path, train_batch_size, eval_batch_size, dry_run):
    IMAGE_HEIGHT, IMAGE_WIDTH = 224, 224    

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    if dry_run:
        batch_size = 1        
        dummy_dataset = TensorDataset(torch.rand(batch_size, 3, 224, 224), torch.randint(0, 10, (batch_size,)))
        train_loader = DataLoader(dummy_dataset,
                                batch_size=batch_size)
        test_loader = DataLoader(dummy_dataset,
                                batch_size=1)
    else:
        # Honor the data_path argument instead of a hard-coded './data'
        trainset = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=True)

        testset = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=eval_batch_size, shuffle=False)

    return train_loader, test_loader

### 2.2 Define training functions

In [3]:
def train_one_epoch(model, device, train_loader, optimizer, criterion, ntrain_batches):    
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        if batch_idx % 10 == 0:            
            print('.', end='')
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx >= ntrain_batches - 1:
            print('\nTraining: [%d/%d] Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (batch_idx+1, len(train_loader), train_loss/(batch_idx+1), 100.*correct/total, correct, total))
            return
    
    print('Full training set: [%d/%d] Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (batch_idx+1, len(train_loader), train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    return

def eval(model, device, test_loader, criterion, neval_batches):
    model.eval()

    test_loss = 0
    correct = 0
    total = 0    
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            if batch_idx % 10 == 0:
                print('.', end='')
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if batch_idx >= neval_batches - 1:
                acc = 100.*correct/total
                loss = test_loss/(batch_idx+1)

                print('\nEval: [%d/%d] Loss: %.3f | Acc: %.3f%% (%d/%d)'
                    % (batch_idx+1, len(test_loader), loss, acc, correct, total))
                
                return loss, acc
    
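    # The loader ran out before neval_batches: report metrics over the batches seen
    acc = 100.*correct/total
    loss = test_loss/(batch_idx+1)
    return loss, acc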

### 2.3 Train

- Training takes a long time with the current settings of `epochs`, `ntrain_batches`, and `neval_batches`.

- To verify the whole flow, you don't need to train until the model converges. Either:

    - Use fewer epochs, a smaller `ntrain_batches`, or a smaller `neval_batches`, or

    - Set `dry_run=1`, which trains the model on a single random-noise sample.

#### 2.3.1 Define the same data loaders and criterion for the float model and QAT

In [4]:
# Check if GPU is available or not
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Prepare data loader
data_path = './data'
train_loader, test_loader = prepare_data_loaders(data_path, train_batch_size=32, eval_batch_size=50, dry_run=0)

# Define a loss function
criterion = nn.CrossEntropyLoss()

Files already downloaded and verified
Files already downloaded and verified


#### 2.3.2 Training

- If your float model is already trained, you can skip this step and load the pretrained model later.

In [7]:
from mobilenetv2 import mobilenet_v2

epochs = 20
ntrain_batches = 1560  # ~50000/32 (a full epoch is 1563 batches)
neval_batches = 200  # 10000/50

# Construct the model
fp32_model = mobilenet_v2(pretrained=False, progress=True)
fp32_model.to(device)

# define the optimizer
optimizer = torch.optim.SGD(fp32_model.parameters(), lr=0.01)

for nepoch in range(epochs):    
    print("\nEpoch: {}".format(nepoch))
    train_one_epoch(fp32_model, device, train_loader, optimizer, criterion, ntrain_batches=ntrain_batches)

    eval(fp32_model, device, test_loader, criterion, neval_batches=neval_batches)


Epoch: 0
............................................................................................................................................................
Training: [1560/1563] Loss: 1.790 | Acc: 35.296% (17620/49920)
....................
Eval: [200/200] Loss: 1.427 | Acc: 47.100% (4710/10000)

Epoch: 1
............................................................................................................................................................
Training: [1560/1563] Loss: 1.373 | Acc: 50.355% (25137/49920)
....................
Eval: [200/200] Loss: 1.172 | Acc: 58.250% (5825/10000)

Epoch: 2
............................................................................................................................................................
Training: [1560/1563] Loss: 1.157 | Acc: 58.808% (29357/49920)
....................
Eval: [200/200] Loss: 1.049 | Acc: 62.900% (6290/10000)

Epoch: 3
.....................................................................

### 2.4 Save the trained model

In [None]:
torch.save(fp32_model.state_dict(), "data/mobilenetv2_cifar10_fp_state_dict.pt")
torch.save(fp32_model, "data/mobilenetv2_cifar10_fp.pt")

## 3. Model Fusing for QAT

In [10]:
from mobilenetv2 import quantizable_mobilenet_v2

# Load pretrained FP32 model
qat_model = quantizable_mobilenet_v2(pretrained=False, progress=True).to(device)
loaded_dict_enc = torch.load("data/mobilenetv2_cifar10_fp_state_dict.pt", map_location=device)
qat_model.load_state_dict(loaded_dict_enc, strict=False)

# Print one inverted residual block before fusion, then switch to eval mode so that fusion can fold BatchNorm into the preceding Conv
print('\n Inverted Residual Block: Before fusion \n\n', qat_model.features[1].conv)
qat_model.eval()

# Fuses modules
qat_model.fuse_model()

# Note the fusion of Conv+BN+ReLU and Conv+BN
print('\n Inverted Residual Block: After fusion\n\n',qat_model.features[1].conv)


 Inverted Residual Block: Before fusion 

 Sequential(
  (0): ConvBNActivation(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

 Inverted Residual Block: After fusion

 Sequential(
  (0): ConvBNActivation(
    (0): ConvReLU2d(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
      (1): ReLU(inplace=True)
    )
    (1): Identity()
    (2): Identity()
  )
  (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
  (2): Identity()
)


The full fused model structure is listed in [mobilenetv2_cifar10_fused_model.txt](./mobilenetv2_cifar10_fused_model.txt).
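
As a sanity check of what fusion does, the sketch below fuses a stand-alone Conv+BN+ReLU block with `torch.ao.quantization.fuse_modules` and compares the outputs before and after. It assumes eval-mode fusion, as in the cell above; the toy `block` is hypothetical and unrelated to `mobilenetv2.py`.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import fuse_modules

torch.manual_seed(0)

# Hypothetical toy block: Conv -> BatchNorm -> ReLU
block = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# Give BatchNorm non-trivial running statistics so the folding is a real test
block.train()
with torch.no_grad():
    for _ in range(3):
        block(torch.randn(4, 3, 16, 16))
block.eval()

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    reference = block(x)
    # inplace=False returns a fused copy and leaves `block` unchanged
    fused = fuse_modules(block, [["0", "1", "2"]], inplace=False)
    fused_out = fused(x)

max_abs_diff = (reference - fused_out).abs().max().item()
print(type(fused[0]).__name__, type(fused[1]).__name__, max_abs_diff)
```

The fused block keeps the same three slots, but the BatchNorm and ReLU positions become `Identity`, matching the printed structure above. The outputs should agree up to floating-point rounding.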

## 4. Model Preparation for QAT

- Try default QConfig

In [8]:
import torch.ao.quantization as quantizer

qat_model.qconfig = quantizer.get_default_qat_qconfig('fbgemm')

# prepare qat model using qconfig settings
qat_model.train()
quantizer.prepare_qat(qat_model, inplace=True)
print('Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n',qat_model.features[1].conv)

Inverted Residual Block: After preparation for QAT, note fake-quantization modules 
 Sequential(
  (0): ConvBNActivation(
    (0): ConvReLU2d(
      32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32
      (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1], device='cuda:0'), observer_enabled=tensor([1], device='cuda:0'), scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False
        (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([], device='cuda:0'), max_val=tensor([], device='cuda:0'))
      )
      (activation_post_process): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1], device='cuda:0'), observer_enabled=tensor([1], device='cuda:0'), scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32

- Try a customized QConfig

    ```Note: qat_model must be reset (for example by re-running section 3) before this step. Calling prepare_qat() twice on the same model may lead to unpredictable behaviour.```
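
One way to avoid preparing the same model twice is to keep the fused float model untouched and prepare a deep copy for each QConfig. The sketch below is an assumption about workflow, not part of the original notebook: `fresh_qat_copy` is a hypothetical helper and `fused_float` is a toy stand-in for the fused MobileNetV2.

```python
import copy

import torch
import torch.nn as nn
import torch.ao.quantization as quantizer

# Hypothetical stand-in for the fused float model
fused_float = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())


def fresh_qat_copy(float_model, qconfig):
    """Return a QAT-prepared deep copy, leaving `float_model` untouched."""
    candidate = copy.deepcopy(float_model)
    candidate.qconfig = qconfig
    candidate.train()  # prepare_qat expects a model in training mode
    return quantizer.prepare_qat(candidate, inplace=True)


default_qat = fresh_qat_copy(fused_float, quantizer.get_default_qat_qconfig("fbgemm"))

custom_qconfig = quantizer.QConfig(
    activation=quantizer.FakeQuantize.with_args(
        observer=quantizer.MovingAverageMinMaxObserver,
        quant_min=-128, quant_max=127, dtype=torch.qint8,
        qscheme=torch.per_tensor_affine, reduce_range=False),
    weight=quantizer.FakeQuantize.with_args(
        observer=quantizer.MovingAveragePerChannelMinMaxObserver,
        quant_min=-128, quant_max=127, dtype=torch.qint8,
        qscheme=torch.per_channel_symmetric, reduce_range=False),
)
custom_qat = fresh_qat_copy(fused_float, custom_qconfig)


def has_fake_quant(model):
    # QAT modules carry a `weight_fake_quant` submodule after prepare_qat
    return any(hasattr(m, "weight_fake_quant") for m in model.modules())


print(has_fake_quant(fused_float), has_fake_quant(default_qat), has_fake_quant(custom_qat))
```

Each prepared copy starts from the same float weights, so the two configurations can be compared without re-running the fusion step.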

In [11]:
activation_quant = quantizer.FakeQuantize.with_args(
            observer=quantizer.MovingAverageMinMaxObserver.with_args(dtype=torch.qint8), 
            quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_affine, reduce_range=False)
weight_quant = quantizer.FakeQuantize.with_args(
            observer=quantizer.MovingAveragePerChannelMinMaxObserver.with_args(dtype=torch.qint8), 
            quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False)

# assign qconfig to model
qat_model.qconfig = torch.ao.quantization.QConfig(activation=activation_quant, weight=weight_quant)

# prepare qat model using qconfig settings
qat_model.train()
torch.ao.quantization.prepare_qat(qat_model, inplace=True)
print('Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n',qat_model.features[1].conv)

Inverted Residual Block: After preparation for QAT, note fake-quantization modules 
 Sequential(
  (0): ConvBNActivation(
    (0): ConvReLU2d(
      32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32
      (weight_fake_quant): FakeQuantize(
        fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32)
        (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([], device='cuda:0'), max_val=tensor([], device='cuda:0'))
      )
      (activation_post_process): FakeQuantize(
        fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=tor

The full prepared model structure is listed in [mobilenetv2_cifar10_qat_prepared.txt](./mobilenetv2_cifar10_qat_prepared.txt).

## 5. Model Tuning with QAT
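
The cell below calls `qat_fine_tune`, whose definition does not appear in this notebook. A minimal sketch of what such a helper might look like, assuming it simply alternates a capped number of training batches with a capped evaluation pass in each epoch, follows. The body is an assumption, not the author's implementation, and the tiny model and random data exist only to make the sketch runnable.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def qat_fine_tune(model, device, train_loader, test_loader,
                  optimizer, criterion,
                  ntrain_batches, neval_batches, epochs):
    """Hypothetical helper: fine-tune `model` and evaluate after every epoch."""
    model.to(device)
    history = []
    for epoch in range(epochs):
        model.train()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            if batch_idx >= ntrain_batches:
                break
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(test_loader):
                if batch_idx >= neval_batches:
                    break
                inputs, targets = inputs.to(device), targets.to(device)
                predicted = model(inputs).argmax(dim=1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        acc = 100.0 * correct / max(total, 1)
        print(f"Epoch {epoch}: eval acc {acc:.3f}% ({correct}/{total})")
        history.append(acc)
    return history


# Smoke test on random data with a toy float model (stand-ins, not CIFAR10)
torch.manual_seed(0)
toy_data = TensorDataset(torch.rand(16, 3, 8, 8), torch.randint(0, 10, (16,)))
toy_loader = DataLoader(toy_data, batch_size=4)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.001)
history = qat_fine_tune(toy_model, torch.device("cpu"), toy_loader, toy_loader,
                        toy_optimizer, nn.CrossEntropyLoss(),
                        ntrain_batches=2, neval_batches=2, epochs=2)
```

The argument order matches the call in the next cell, so a helper of this shape could be dropped in without changing it.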

In [14]:
epochs = 5
ntrain_batches = 100
neval_batches = 200  # 10000/50

# Redefine optimizer by using smaller learning rate here
optimizer = torch.optim.SGD(qat_model.parameters(), lr=0.001)

# The "device", "train_loader", "test_loader", "criterion" are the same as training the float model
qat_fine_tune(qat_model, device, train_loader, test_loader, 
            optimizer, criterion, 
            ntrain_batches, neval_batches, epochs)


Epoch: 0
..........
Training: [100/1563] Loss: 0.306 | Acc: 89.531% (2865/3200)
....................
Eval: [200/200] Loss: 0.457 | Acc: 85.330% (8533/10000)

Epoch: 1
..........
Training: [100/1563] Loss: 0.279 | Acc: 90.219% (2887/3200)
....................
Eval: [200/200] Loss: 0.460 | Acc: 85.210% (8521/10000)

Epoch: 2
..........
Training: [100/1563] Loss: 0.286 | Acc: 89.625% (2868/3200)
....................
Eval: [200/200] Loss: 0.432 | Acc: 85.970% (8597/10000)

Epoch: 3
..........
Training: [100/1563] Loss: 0.277 | Acc: 90.656% (2901/3200)
....................
Eval: [200/200] Loss: 0.429 | Acc: 85.920% (8592/10000)

Epoch: 4
..........
Training: [100/1563] Loss: 0.265 | Acc: 90.906% (2909/3200)
....................
Eval: [200/200] Loss: 0.423 | Acc: 86.100% (8610/10000)


In [15]:
# Save the model state_dict()
torch.save(qat_model.state_dict(), "data/mobilenetv2_cifar10_qat_state_dict.pt")

# Please note we can't save the model structure, since the local observer object can't be serialized
# torch.save(qat_model.eval(), "data/mobilenetv2_cifar10_qat.pt")

## 6. Model Conversion

```Note that Model Conversion is currently only supported on CPUs.```

In [16]:
quantized_model = torch.ao.quantization.convert(qat_model.to('cpu').eval(), inplace=False)
print('Inverted Residual Block: After quantization and conversion done \n',quantized_model.features[1].conv)

Inverted Residual Block: After quantization and conversion done 
 Sequential(
  (0): ConvBNActivation(
    (0): QuantizedConvReLU2d(32, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.043570488691329956, zero_point=-128, padding=(1, 1), groups=32)
    (1): Identity()
    (2): Identity()
  )
  (1): QuantizedConv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), scale=0.06476081162691116, zero_point=7)
  (2): Identity()
)
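
The `scale` and `zero_point` values printed above define an affine mapping between real values and 8-bit integer codes. The plain-Python sketch below illustrates that arithmetic for the `qint8` range used in this notebook. The function names are hypothetical, and it mirrors the formula only, not PyTorch's kernels.

```python
def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Map a real value to its integer code: clamp(round(x / scale) + zero_point)."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Map an integer code back to the real value it represents."""
    return (q - zero_point) * scale


# Values copied from the QuantizedConvReLU2d line in the output above
scale, zero_point = 0.043570488691329956, -128

for x in (-1.0, 0.0, 1.0, 20.0):
    q = quantize(x, scale, zero_point)
    x_hat = dequantize(q, scale, zero_point)
    print(f"x={x:6.2f} -> q={q:4d} -> x_hat={x_hat:.4f}")
```

With `zero_point=-128`, the real value 0.0 maps to the smallest code, so the whole 8-bit range is spent on non-negative values. That fits a layer that ends in ReLU. Negative inputs and values beyond `255 * scale` are clamped.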




## 7. Model Save

In [17]:
torch.save(quantized_model, "./data/mobilenetv2_cifar10_quantized.pth")
torch.jit.save(torch.jit.script(quantized_model), "./data/mobilenetv2_cifar10_quantized_jit.pth")
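
The TorchScript file saved above can later be reloaded with `torch.jit.load` without importing `mobilenetv2.py`. The sketch below shows the round trip. `load_scripted_and_predict` is a hypothetical helper, and a small float module stands in for the quantized model so that the example runs without the trained file.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_scripted_and_predict(path, inputs):
    """Load a TorchScript module on CPU and return predicted class indices."""
    model = torch.jit.load(path, map_location="cpu")
    model.eval()
    with torch.no_grad():
        logits = model(inputs.to("cpu"))
    return logits.argmax(dim=1)


# Stand-in for the quantized model (assumption: any scripted module works the same way)
stand_in = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10)).eval()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "stand_in_jit.pth")
    torch.jit.save(torch.jit.script(stand_in), path)
    preds = load_scripted_and_predict(path, torch.rand(4, 3, 8, 8))

print(preds.shape)
```

For the real model, the path would be `./data/mobilenetv2_cifar10_quantized_jit.pth` and the inputs would need the same preprocessing as `transform_test`.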