# MobileNetV2 with CIFAR10

## 0 Preparation

In [1]:
# Make sure bstnnx_training is in the system Python path.
# The checkout location can be overridden with the BSTNNX_TRAINING_PATH
# environment variable; when it is unset the original default location is used.
import sys
import os

BSTNNX_TRAINING_PATH = os.environ.get(
    "BSTNNX_TRAINING_PATH",
    "/home/hongbing/Projects/bst-study/bstnnx_training",
)

# Only append once so that re-running this cell does not grow sys.path.
if BSTNNX_TRAINING_PATH not in sys.path:
    sys.path.append(BSTNNX_TRAINING_PATH)

import bstnnx_training

print(bstnnx_training.__version__)

1.4.3.2


In [2]:

import re

import torch
import torch.nn as nn
print(f"PyTorch version: {torch.__version__}")


def _version_tuple(version, width=3):
    """Return the leading numeric components of a version string as ints.

    Anything after a ``+`` (e.g. the ``+cu102`` local build tag) is ignored and
    the result is truncated / zero-padded to ``width`` entries, so versions are
    compared numerically instead of character by character.
    """
    numbers = [int(part) for part in re.findall(r"\d+", version.split("+")[0])[:width]]
    return tuple(numbers + [0] * (width - len(numbers)))


# Compare numeric tuples: plain string comparison is lexicographic and a
# substring test would also accept unrelated versions that merely contain "1.9.1".
assert _version_tuple(bstnnx_training.__version__) >= (1, 0, 0), 'This notebook need to use bstnnx training >= 1.0.0 release'
assert _version_tuple(torch.__version__) == (1, 9, 1), 'This notebook need to use pytorch 1.9.1 release'

PyTorch version: 1.9.1+cu102


## 1. Modify Model: Replace Functionals with Modules

- Replace `+` with `Add()` module.
- Replace `adaptive_avg_pool2d` with `A1000A0AvgPool2d`, as `adaptive_avg_pool2d` is not supported by BST.
- Replace the last `torch.nn.Linear` layer with a `torch.nn.Conv2d` with a 1x1 kernel, as a 1x1 `Conv2d` performs better than `Linear` on BST hardware.

- QAT Only: 

    - Insert `QuantStub` and `DeQuantStub` at the beginning and end of the network.
    - Note: BST QAT fuses modules by auto-detection, so it is unnecessary to specify how to fuse them.

- PyTorch QAT Only:

    - Replace `ReLU6` with `ReLU`.
    - Define `fuse_module()` functions to specify how to fuse modules. PyTorch QAT has to fuse modules manually.

Please check `quantable_mobilenetv2.py` for all the details.    

## 2. Train Float Model

### 2.1 Define dataset and data loaders

In [3]:
import torchvision
from torchvision import datasets, transforms

from torch.utils.data import (DataLoader, TensorDataset)

def prepare_data_loaders(data_path, train_batch_size, eval_batch_size, dry_run):
    """Build the training and evaluation data loaders.

    Args:
        data_path: Root directory in which CIFAR10 is stored / downloaded.
        train_batch_size: Batch size of the (shuffled) training loader.
        eval_batch_size: Batch size of the (unshuffled) evaluation loader.
        dry_run: When truthy, CIFAR10 is not touched; both loaders iterate over
            one seeded random sample with batch size 1.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    IMAGE_HEIGHT, IMAGE_WIDTH = 224, 224

    if dry_run:
        # A single random sample is enough to exercise the whole flow.
        batch_size = 1
        torch.manual_seed(2022)
        dummy_dataset = TensorDataset(
            torch.rand(batch_size, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
            torch.randint(0, 10, (batch_size,)),
        )
        train_loader = DataLoader(dummy_dataset, batch_size=batch_size)
        test_loader = DataLoader(dummy_dataset, batch_size=1)
        return train_loader, test_loader

    # The transform pipelines are only needed for the real dataset.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        normalize,
    ])

    # Honour data_path instead of a hard-coded './data' root.
    trainset = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform_train)
    train_loader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True)

    testset = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform_test)
    test_loader = DataLoader(testset, batch_size=eval_batch_size, shuffle=False)

    return train_loader, test_loader

### 2.2 Define training functions

In [4]:
def train_one_epoch(model, device, train_loader, optimizer, criterion, ntrain_batches):
    """Train ``model`` for at most ``ntrain_batches`` batches of ``train_loader``.

    Prints a progress dot every 10 batches and a one-line summary (running
    average loss and accuracy) at the end.

    Args:
        model: Module to train; switched to train mode.
        device: Device the inputs and targets are moved to.
        train_loader: Sized iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, targets)``.
        ntrain_batches: Training stops once this many batches were processed.

    Returns:
        None.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    batches_seen = 0
    stopped_early = False
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        if batch_idx % 10 == 0:
            print('.', end='')
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        batches_seen = batch_idx + 1
        if batch_idx >= ntrain_batches - 1:
            stopped_early = True
            break

    if batches_seen == 0:
        # Without this guard the summary below fails on an unbound batch index.
        raise ValueError('train_loader yielded no batches')

    # Same two summary lines as before, built from one format string.
    header = '\nTraining' if stopped_early else 'Full training set'
    print('%s: [%d/%d] Loss: %.3f | Acc: %.3f%% (%d/%d)'
          % (header, batches_seen, len(train_loader), train_loss/batches_seen, 100.*correct/total, correct, total))
    return

def eval(model, device, test_loader, criterion, neval_batches):
    """Evaluate ``model`` on at most ``neval_batches`` batches of ``test_loader``.

    Prints a progress dot every 10 batches and a one-line summary at the end.

    Args:
        model: Module to evaluate; switched to eval mode.
        device: Device the inputs and targets are moved to.
        test_loader: Sized iterable of ``(inputs, targets)`` batches.
        criterion: Loss called as ``criterion(outputs, targets)``.
        neval_batches: Evaluation stops once this many batches were processed.

    Returns:
        Tuple ``(loss, acc)``: mean per-batch loss and accuracy in percent over
        the processed batches.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()

    test_loss = 0
    correct = 0
    total = 0
    batches_seen = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            if batch_idx % 10 == 0:
                print('.', end='')
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)

            test_loss += batch_loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            batches_seen = batch_idx + 1
            if batch_idx >= neval_batches - 1:
                break

    if batches_seen == 0:
        raise ValueError('test_loader yielded no batches')

    # Computed after the loop so a loader shorter than neval_batches also gets
    # averaged metrics (previously that path hit an unbound ``acc``).
    acc = 100.*correct/total
    loss = test_loss/batches_seen

    print('\nEval: [%d/%d] Loss: %.3f | Acc: %.3f%% (%d/%d)'
          % (batches_seen, len(test_loader), loss, acc, correct, total))

    return loss, acc

### 2.3 Train

- Training takes a long time with the current settings of `epochs`, `ntrain_batches` and `neval_batches`.
- To verify the whole flow, you do not need to train and evaluate until the model converges. You can either:
    - use a smaller `epochs`, `ntrain_batches`, or `neval_batches`, or
    - use `dry_run=1`, which trains the model on a single random-noise sample.

#### 2.3.1 Define the same data loaders and criterion for float model and QAT 

In [8]:
# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the training / evaluation data loaders
data_path = './data'
train_loader, test_loader = prepare_data_loaders(
    data_path,
    train_batch_size=32,
    eval_batch_size=50,
    dry_run=0,
)

# Loss function shared by float training and QAT fine-tuning
criterion = nn.CrossEntropyLoss()

Files already downloaded and verified
Files already downloaded and verified


#### 2.3.2 Training

- If your float model is already trained, you can skip this step and load the pretrained model later.

In [18]:
from mobilenetv2 import mobilenet_v2

epochs = 20
# Cover every batch of both loaders by default. The previously hard-coded 1560
# was smaller than the 1563 batches produced by 50000 samples at batch size 32.
# Lower these values for a quick smoke test.
ntrain_batches = len(train_loader)
neval_batches = len(test_loader)

# Construct the model
fp32_model = mobilenet_v2(pretrained=False, progress=True)
fp32_model.to(device)

# Define the optimizer
optimizer = torch.optim.SGD(fp32_model.parameters(), lr=0.01)

for nepoch in range(epochs):
    print("\nEpoch: {}".format(nepoch))
    train_one_epoch(fp32_model, device, train_loader, optimizer, criterion, ntrain_batches=ntrain_batches)

    # Report test loss / accuracy after every epoch
    eval(fp32_model, device, test_loader, criterion, neval_batches=neval_batches)

Files already downloaded and verified
Files already downloaded and verified

------------------------
Epoch: 0
.........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

### 2.4 Save the trained model

In [19]:
# The output directory may not exist yet (e.g. after a dry run, which never
# downloads CIFAR10 into ./data), so create it before saving.
os.makedirs("data", exist_ok=True)
# NOTE: the file name says "fashion_mnist" although this notebook trains on
# CIFAR10; it is kept unchanged because the QAT section loads the same path.
torch.save(fp32_model.state_dict(), "data/mobilenetv2_fashion_mnist_fp.pt")

## 3. Fuse Modules for QAT

In [9]:
# switch the quantization framework
import bstnnx_training.PyTorch.QAT.core as quantizer

In [10]:
from mobilenetv2 import quantizable_mobilenet_v2

# Build the quantizable MobileNetV2 variant and load the float weights saved
# in section 2.4 into it.
fp32_model = quantizable_mobilenet_v2(pretrained=False, progress=True, use_bstnn=True).to(device)
loaded_dict_enc = torch.load("data/mobilenetv2_fashion_mnist_fp.pt", map_location=device)
fp32_model.load_state_dict(loaded_dict_enc)

# Define one sample input (1 x 3 x 224 x 224) used for fusing the model;
# it is reused by the later prepare / align / export steps.
sample_data = torch.randn(1, 3, 224, 224, requires_grad=True)

# Fuse in eval mode and pass the sample input on the CPU.
# NOTE(review): the original comment indicates the graph-topology parsing
# backend forces the model onto the CPU - confirm against the bstnnx_training docs.
fp32_model.eval()
fused_model = quantizer.fuse_modules(fp32_model, auto_detect=True, input_tensor=sample_data.cpu())

## 4. Define QConfig and Prepare Model for QAT

In [11]:
def prepare_qat_model(model, device, backend='default', sample_data=None):
    """Attach a fake-quantization QConfig to ``model`` and prepare it for QAT.

    Args:
        model: Module to prepare; it is moved to ``device`` and switched to
            train mode before preparation.
        device: Device the model is moved to.
        backend: ``'default'`` (uses ``quantizer.default_observer``) or ``'bst'``
            (uses ``quantizer.MinMaxObserver`` plus the pre-bind and
            module-linking steps).
        sample_data: Example input tensor. Required for the ``'bst'`` backend,
            where it is passed (on CPU) to ``pre_bind`` and ``link_modules``.

    Returns:
        The prepared model, switched to eval mode.

    Raises:
        ValueError: If ``backend`` is not supported, or if ``backend == 'bst'``
            and ``sample_data`` is ``None``.
    """
    # Validate up front: an unknown backend used to surface only at the end as
    # an unbound ``prepared_model``.
    if backend not in ('default', 'bst'):
        raise ValueError("Unsupported backend %r; expected 'default' or 'bst'" % (backend,))
    if backend == 'bst' and sample_data is None:
        raise ValueError("sample_data is required when backend is 'bst'")

    model.to(device)
    model.train()

    def make_fake_quant(observer):
        # int8, per-tensor affine fake-quantizer; used for activations and weights.
        return quantizer.FakeQuantize.with_args(
            observer=observer.with_args(dtype=torch.qint8),
            quant_min=-128, quant_max=127, dtype=torch.qint8,
            qscheme=torch.per_tensor_affine, reduce_range=False)

    if backend == 'default':
        activation_quant = make_fake_quant(quantizer.default_observer)
        weight_quant = make_fake_quant(quantizer.default_observer)

        # assign qconfig to model
        model.qconfig = quantizer.QConfig(activation=activation_quant, weight=weight_quant)

        # prepare qat model using qconfig settings
        prepared_model = quantizer.prepare_qat(model, inplace=False)

    else:  # backend == 'bst'
        bst_activation_quant = make_fake_quant(quantizer.MinMaxObserver)
        bst_weight_quant = make_fake_quant(quantizer.MinMaxObserver)

        # 1) [bst_alignment] get b0 pre-bind qconfig adjusting Conv's activation quant scheme
        pre_bind_qconfig = quantizer.pre_bind(model, input_tensor=sample_data.to('cpu'))

        # 2) assign qconfig to model
        model.qconfig = quantizer.QConfig(activation=bst_activation_quant, weight=bst_weight_quant,
                                          qconfig_dict=pre_bind_qconfig)

        # 3) prepare qat model using qconfig settings
        prepared_model = quantizer.prepare_qat(model, inplace=False)

        # 4) [bst_alignment] link model observers
        prepared_model = quantizer.link_modules(prepared_model, auto_detect=True,
                                                input_tensor=sample_data.to('cpu'), inplace=False)

    prepared_model.eval()

    return prepared_model


In [12]:
# Prepare the fused model for QAT through the 'bst' branch of prepare_qat_model;
# sample_data is the example input that branch requires.
prepared_model = prepare_qat_model(fused_model, device, backend="bst", sample_data=sample_data)

  max_v = torch.tensor(max(abs(self.min_val), abs(self.max_val)), device=device)


## 5. Fine Tune with QAT

In [13]:
def qat_fine_tune(model, device, train_loader, test_loader,
                  optimizer, criterion,
                  ntrain_batches, neval_batches,
                  sample_data, epochs):
    """Fine-tune a prepared QAT model.

    Epoch 0 trains once, then disables the observers, calls
    ``quantizer.align_bst_hardware`` and freezes the batch-norm statistics
    before the first evaluation. Epochs ``1 .. epochs - 1`` then alternate one
    training pass and one evaluation pass.
    """

    def run_training():
        train_one_epoch(model, device, train_loader, optimizer, criterion, ntrain_batches=ntrain_batches)

    def run_evaluation():
        eval(model, device, test_loader, criterion, neval_batches=neval_batches)

    # --- Epoch 0: train first, then lock the quantization state ---
    print("\nEpoch: 0")
    run_training()

    # Stop the observers from updating the quantizer parameters
    model.apply(torch.quantization.disable_observer)

    # Hardware alignment step; per the original note it is only applied once,
    # to a model that has not been aligned yet
    quantizer.align_bst_hardware(model, sample_data)

    # Keep the batch-norm mean / variance estimates fixed from here on
    model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    run_evaluation()

    # --- Remaining epochs: train and check accuracy after each one ---
    for epoch_idx in range(1, epochs):
        print("\nEpoch: {}".format(epoch_idx))
        run_training()
        run_evaluation()

In [14]:
# Fine-tuning budget for QAT
epochs = 5
ntrain_batches = 100
neval_batches = 200  # 10000/50

# New optimizer over the prepared model's parameters, with a smaller learning rate
optimizer = torch.optim.SGD(prepared_model.parameters(), lr=0.001)

# "device", "train_loader", "test_loader" and "criterion" are reused from the float-model training
qat_fine_tune(
    model=prepared_model,
    device=device,
    train_loader=train_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    criterion=criterion,
    ntrain_batches=ntrain_batches,
    neval_batches=neval_batches,
    sample_data=sample_data,
    epochs=epochs,
)


Epoch: 0
..........
Training: [100/1563] Loss: 0.296 | Acc: 89.812% (2874/3200)
....................
Eval: [200/200] Loss: 0.437 | Acc: 85.810% (8581/10000)

Epoch: 1
..........
Training: [100/1563] Loss: 0.288 | Acc: 89.812% (2874/3200)
....................
Eval: [200/200] Loss: 0.440 | Acc: 85.170% (8517/10000)

Epoch: 2
..........
Training: [100/1563] Loss: 0.310 | Acc: 88.625% (2836/3200)
....................
Eval: [200/200] Loss: 0.435 | Acc: 85.810% (8581/10000)

Epoch: 3
..........
Training: [100/1563] Loss: 0.276 | Acc: 90.344% (2891/3200)
....................
Eval: [200/200] Loss: 0.430 | Acc: 85.520% (8552/10000)

Epoch: 4
..........
Training: [100/1563] Loss: 0.275 | Acc: 90.469% (2895/3200)
....................
Eval: [200/200] Loss: 0.427 | Acc: 86.100% (8610/10000)


## 6. Export to float ONNX model and quantization parameters with JSON

In [15]:
# Options handed to export_onnx (here only the 'simplify_onnx' flag is set)
stage_dict = {'simplify_onnx': True}

# export_onnx returns the ONNX model path and the quantization-parameter JSON path
onnx_model_path, quant_param_json_path = quantizer.export_onnx(
    prepared_model,
    sample_data,
    stage_dict=stage_dict,
    result_dir='./data',
)

  if self.observer_enabled[0] == 1:
  if self.fake_quant_enabled[0] == 1:
  X, float(self.scale), int(self.zero_point),
  X, float(self.scale), int(self.zero_point),
100%|██████████| 102/102 [00:00<00:00, 504503.55it/s]
