# MobileNetV2 with CIFAR10


<img src="fig/pytorch_qat_flow.png">


## 0. Import necessary packages

In [1]:
import sys
import os

import torch
import torch.nn as nn

# Warning policy for the notebook
import warnings

# Hide DeprecationWarning coming from any module so cell outputs stay readable.
warnings.filterwarnings('ignore', category=DeprecationWarning, module=r'.*')
# Registered second, so this filter is consulted first: warnings raised from
# torch.ao.quantization keep Python's "default" display behaviour.
warnings.filterwarnings('default', module=r'torch.ao.quantization')

# Fix the PyTorch RNG seed so repeated runs give repeatable results
torch.manual_seed(2023)

print(f"PyTorch version: {torch.__version__}")

PyTorch version: 1.13.0+cu117


## 1. Model Modifications: Replace Functionals with Modules

- Replace addition `+` with the `nn.quantized.FloatFunctional().add()` module function.

- Insert `QuantStub` and `DeQuantStub` at the beginning and end of the network.

- Replace ReLU6 with ReLU

- Replace the last `torch.nn.Linear` layer with a `torch.nn.Conv2d` layer with a 1x1 kernel, as a 1x1 `Conv2d` has better performance than `Linear`.

- Define a `fuse_model()` method to specify how to fuse modules.

Please check [mobilenetv2.py](./mobilenetv2.py) for all the details.    

## 2. Model Training

### 2.1 Define dataset and data loaders

In [2]:
import torchvision
from torchvision import datasets, transforms

from torch.utils.data import (DataLoader, TensorDataset)

def prepare_data_loaders(data_path, train_batch_size, eval_batch_size, dry_run):
    """Build the training and evaluation data loaders.

    Args:
        data_path: Root directory in which CIFAR10 is stored / downloaded.
            Not used when ``dry_run`` is truthy.
        train_batch_size: Batch size of the training loader (real data only).
        eval_batch_size: Batch size of the evaluation loader (real data only).
        dry_run: If truthy, skip CIFAR10 entirely and return loaders over a
            single random sample, which is enough to smoke-test the flow.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    """
    IMAGE_HEIGHT, IMAGE_WIDTH = 224, 224

    if dry_run:
        # One random image with a random label in [0, 10); no torchvision
        # objects are needed on this path, so none are constructed.
        batch_size = 1
        dummy_dataset = TensorDataset(
            torch.rand(batch_size, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
            torch.randint(0, 10, (batch_size,)),
        )
        train_loader = DataLoader(dummy_dataset, batch_size=batch_size)
        test_loader = DataLoader(dummy_dataset, batch_size=1)
        return train_loader, test_loader

    # Per-channel mean / std used to normalize both splits.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # Training split: augment at the native 32x32 size, then upscale.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        normalize,
    ])

    # Evaluation split: no augmentation, only upscale and normalize.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        normalize,
    ])

    # Honour the caller-supplied data_path instead of a hard-coded './data'.
    trainset = torchvision.datasets.CIFAR10(
        root=data_path, train=True, download=True, transform=transform_train)
    train_loader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True)

    testset = torchvision.datasets.CIFAR10(
        root=data_path, train=False, download=True, transform=transform_test)
    test_loader = DataLoader(testset, batch_size=eval_batch_size, shuffle=False)

    return train_loader, test_loader

### 2.2 Define training functions

In [3]:
def train_one_epoch(model, device, train_loader, optimizer, criterion, ntrain_batches):
    """Train ``model`` for at most ``ntrain_batches`` batches of ``train_loader``.

    A '.' is printed every 10 batches as a progress indicator, followed by a
    one-line summary (running mean loss and accuracy) on its own line.

    Args:
        model: Module to train; switched to training mode.
        device: Device the input and target batches are moved to.
        train_loader: Iterable of ``(inputs, targets)`` batches supporting ``len()``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        ntrain_batches: Stop after this many batches. At least one batch is
            always processed when the loader is non-empty.

    Returns:
        None.
    """
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        if batch_idx % 10 == 0:
            print('.', end='')
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # Predicted class = index of the largest output along dim 1.
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches = batch_idx + 1
        if num_batches >= ntrain_batches:
            print('\nTraining: [%d/%d] Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (num_batches, len(train_loader), train_loss/num_batches, 100.*correct/total, correct, total))
            return

    if num_batches == 0:
        # Empty loader: there is nothing to average, so report it instead of
        # failing on an undefined batch index / division by zero.
        print('Training: no batches were processed (train_loader is empty)')
        return

    # Leading newline keeps the summary off the progress-dot line, matching
    # the 'Training:' summary above.
    print('\nFull training set: [%d/%d] Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (num_batches, len(train_loader), train_loss/num_batches, 100.*correct/total, correct, total))
    return

def eval(model, device, test_loader, criterion, neval_batches):
    """Evaluate ``model`` on at most ``neval_batches`` batches of ``test_loader``.

    Note: this function shadows Python's built-in ``eval`` within the notebook.

    A '.' is printed every 10 batches as a progress indicator, followed by a
    one-line summary. The summary is printed and the averages are returned
    whether the batch limit was reached or the loader ran out first.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        device: Device the input and target batches are moved to.
        test_loader: Iterable of ``(inputs, targets)`` batches supporting ``len()``.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        neval_batches: Stop after this many batches. At least one batch is
            always processed when the loader is non-empty.

    Returns:
        ``(loss, acc)``: mean per-batch loss as a float and accuracy in percent.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()

    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            if batch_idx % 10 == 0:
                print('.', end='')
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)

            test_loss += batch_loss.item()
            # Predicted class = index of the largest output along dim 1.
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches = batch_idx + 1
            if num_batches >= neval_batches:
                break

    if num_batches == 0:
        raise ValueError("test_loader produced no batches; nothing to evaluate")

    # Computed after the loop so both exit paths (limit reached or loader
    # exhausted) return the same well-defined averages.
    acc = 100.*correct/total
    loss = test_loss/num_batches

    print('\nEval: [%d/%d] Loss: %.3f | Acc: %.3f%% (%d/%d)'
        % (num_batches, len(test_loader), loss, acc, correct, total))

    return loss, acc

### 2.3 Train

- Training takes a long time with the current settings of `epochs`, `ntrain_batches` and `neval_batches`.

- To verify the whole flow, you do not have to train and evaluate for enough epochs for the model to converge. Instead, you can:

    - Define smaller epochs, or smaller ntrain_batches, or smaller neval_batches
    
    - Use `dry_run=1`, which trains the model on a single random-noise sample

#### 2.3.1 Define the same data loaders and criterion for float model and QAT 

In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the loaders once; float training and QAT fine-tuning both reuse them
data_path = './data'
train_loader, test_loader = prepare_data_loaders(
    data_path,
    train_batch_size=32,
    eval_batch_size=50,
    dry_run=0,
)

# Loss function shared by float training and QAT fine-tuning
criterion = nn.CrossEntropyLoss()

Files already downloaded and verified
Files already downloaded and verified


#### 2.3.2 Training

- If your float model is already trained, you can skip this step and load the pretrained model later.

In [7]:
from mobilenetv2 import mobilenet_v2

# Training schedule for the float model
epochs = 20
# Cap on batches per epoch. With batch size 32 the loader reports 1563
# batches, so the last few batches of every epoch are skipped.
ntrain_batches = 1560
neval_batches = 200  # 10000/50

# Construct the model and move it to the selected device
fp32_model = mobilenet_v2(pretrained=False, progress=True)
fp32_model.to(device)

# Plain SGD over all model parameters
optimizer = torch.optim.SGD(fp32_model.parameters(), lr=0.01)

# One training pass followed by one evaluation pass per epoch
for nepoch in range(epochs):
    print("\nEpoch: {}".format(nepoch))
    train_one_epoch(
        fp32_model, device, train_loader, optimizer, criterion,
        ntrain_batches=ntrain_batches,
    )
    eval(
        fp32_model, device, test_loader, criterion,
        neval_batches=neval_batches,
    )


Epoch: 0
............................................................................................................................................................
Training: [1560/1563] Loss: 1.790 | Acc: 35.296% (17620/49920)
....................
Eval: [200/200] Loss: 1.427 | Acc: 47.100% (4710/10000)

Epoch: 1
............................................................................................................................................................
Training: [1560/1563] Loss: 1.373 | Acc: 50.355% (25137/49920)
....................
Eval: [200/200] Loss: 1.172 | Acc: 58.250% (5825/10000)

Epoch: 2
............................................................................................................................................................
Training: [1560/1563] Loss: 1.157 | Acc: 58.808% (29357/49920)
....................
Eval: [200/200] Loss: 1.049 | Acc: 62.900% (6290/10000)

Epoch: 3
.....................................................................

### 2.4 Save the trained model

In [None]:
# Make sure the output directory exists: it is only created as a side effect
# of the CIFAR10 download, which a dry run skips.
os.makedirs("data", exist_ok=True)

# Weights only; this is the file the QAT section loads.
torch.save(fp32_model.state_dict(), "data/mobilenetv2_cifar10_fp_state_dict.pt")
# Whole-module pickle. NOTE: loading it back requires the model's class
# definition to be importable, and pickles must only be loaded from trusted sources.
torch.save(fp32_model, "data/mobilenetv2_cifar10_fp.pt")

## 3. Model Fusing for QAT

In [10]:
from mobilenetv2 import quantizable_mobilenet_v2

# Build the quantizable variant of the network and load the FP32 weights
# saved in section 2.4 into it.
qat_model = quantizable_mobilenet_v2(pretrained=False, progress=True).to(device)
loaded_dict_enc = torch.load(
    "data/mobilenetv2_cifar10_fp_state_dict.pt",
    map_location=device,
)
# strict=False: keys that do not match between the checkpoint and the
# quantizable model are skipped rather than raising.
qat_model.load_state_dict(loaded_dict_enc, strict=False)

# Fusion is performed with the model in eval mode; the printed structure in
# the next cell shows the fused ConvReLU2d modules with Identity placeholders.
qat_model.eval()
qat_model.fuse_model()

In [11]:
# Inspect the fused module tree (fused ConvReLU2d blocks, Identity placeholders)
print(qat_model)

QuantizableMobileNetV2(
  (features): Sequential(
    (0): ConvBNActivation(
      (0): ConvReLU2d(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (1): Identity()
      (2): Identity()
    )
    (1): QuantizableInvertedResidual(
      (conv): Sequential(
        (0): ConvBNActivation(
          (0): ConvReLU2d(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
            (1): ReLU(inplace=True)
          )
          (1): Identity()
          (2): Identity()
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
        (2): Identity()
      )
      (skip_add): FloatFunctional(
        (activation_post_process): Identity()
      )
    )
    (2): QuantizableInvertedResidual(
      (conv): Sequential(
        (0): ConvBNActivation(
          (0): ConvReLU2d(
            (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1))
            (1): ReLU(inpla

## 4. Model Preparation for QAT

In [12]:
# Fake-quantization settings used during QAT.
# The observers are taken from torch.ao.quantization: the name `quantizer`
# used previously is never defined or imported in this notebook, so the cell
# raised NameError on a fresh kernel.
# NOTE(review): this assumes `quantizer` was an alias of torch.ao.quantization — confirm.

# Activations: signed 8-bit, per-tensor affine, moving-average min/max observer.
activation_quant = torch.ao.quantization.FakeQuantize.with_args(
            observer=torch.ao.quantization.MovingAverageMinMaxObserver.with_args(dtype=torch.qint8),
            quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_affine, reduce_range=False)
# Weights: signed 8-bit, per-channel affine, moving-average per-channel min/max observer.
weight_quant = torch.ao.quantization.FakeQuantize.with_args(
            observer=torch.ao.quantization.MovingAveragePerChannelMinMaxObserver.with_args(dtype=torch.qint8),
            quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_affine, reduce_range=False)

# assign qconfig to model
qat_model.qconfig = torch.ao.quantization.QConfig(activation=activation_quant, weight=weight_quant)

# prepare qat model using qconfig settings; the model is put in training mode
# before preparation and is modified in place (the returned model is
# displayed as the cell output)
qat_model.train()
torch.ao.quantization.prepare_qat(qat_model, inplace=True)

QuantizableMobileNetV2(
  (features): Sequential(
    (0): ConvBNActivation(
      (0): ConvReLU2d(
        3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
        (weight_fake_quant): FakeQuantize(
          fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_affine, ch_axis=0, scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32)
          (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([], device='cuda:0'), max_val=tensor([], device='cuda:0'))
        )
        (activation_post_process): FakeQuantize(
          fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tens

## 5. Model Tuning with QAT

In [18]:
def qat_fine_tune(model, device, train_loader, test_loader,
                  optimizer, criterion,
                  ntrain_batches, neval_batches, epochs):
    """Fine-tune a QAT-prepared model, evaluating after every epoch.

    After the first training epoch the observers are disabled and the batch
    norm statistics are frozen; all later epochs train with those fixed.

    Args:
        model: Model already prepared for QAT; modified in place.
        device: Device the batches are moved to.
        train_loader: Training batches, passed to ``train_one_epoch``.
        test_loader: Evaluation batches, passed to ``eval``.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss callable.
        ntrain_batches: Per-epoch cap on training batches.
        neval_batches: Per-epoch cap on evaluation batches.
        epochs: Total number of epochs. The first epoch always runs, even
            when ``epochs`` is less than 1.

    Returns:
        None.
    """
    # QAT takes time and one needs to train over a few epochs.
    # Train and check accuracy after each epoch
    for epoch in range(max(1, epochs)):
        print("\nEpoch: {}".format(epoch))
        train_one_epoch(model, device, train_loader, optimizer, criterion, ntrain_batches=ntrain_batches)

        if epoch == 0:
            # Freeze quantizer parameters (torch.ao.quantization is the
            # namespace used by the rest of this notebook)
            model.apply(torch.ao.quantization.disable_observer)

            # Freeze batch norm mean and variance estimates
            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

        eval(model, device, test_loader, criterion, neval_batches=neval_batches)

In [19]:
# QAT schedule: a few short epochs (100 training batches each), full evaluation
epochs = 5
ntrain_batches = 100
neval_batches = 200  # 10000/50

# Fresh optimizer over the QAT model, with a smaller learning rate than the
# one used for float training
optimizer = torch.optim.SGD(qat_model.parameters(), lr=0.001)

# "device", "train_loader", "test_loader" and "criterion" are reused from
# the float-model training
qat_fine_tune(
    qat_model, device, train_loader, test_loader,
    optimizer, criterion,
    ntrain_batches, neval_batches, epochs,
)


Epoch: 0
..........
Training: [100/1563] Loss: 0.304 | Acc: 89.469% (2863/3200)
....................
Eval: [200/200] Loss: 0.455 | Acc: 85.320% (8532/10000)

Epoch: 1
..........
Training: [100/1563] Loss: 0.277 | Acc: 90.000% (2880/3200)
....................
Eval: [200/200] Loss: 0.451 | Acc: 85.480% (8548/10000)

Epoch: 2
..........
Training: [100/1563] Loss: 0.288 | Acc: 89.594% (2867/3200)
....................
Eval: [200/200] Loss: 0.436 | Acc: 85.840% (8584/10000)

Epoch: 3
..........
Training: [100/1563] Loss: 0.278 | Acc: 90.312% (2890/3200)
....................
Eval: [200/200] Loss: 0.429 | Acc: 85.990% (8599/10000)

Epoch: 4
..........
Training: [100/1563] Loss: 0.267 | Acc: 90.969% (2911/3200)
....................
Eval: [200/200] Loss: 0.427 | Acc: 85.880% (8588/10000)
